diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 62ea3a97..eebadb74 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -55,11 +55,12 @@ jobs: - TestPreAuthKeyCorrectUserLoggedInCommand - TestApiKeyCommand - TestNodeTagCommand + - TestTaggedNodeRegistration + - TestTagPersistenceAcrossRestart - TestNodeAdvertiseTagCommand - TestNodeCommand - TestNodeExpireCommand - TestNodeRenameCommand - - TestNodeMoveCommand - TestPolicyCommand - TestPolicyBrokenConfigCommand - TestDERPVerifyEndpoint diff --git a/AGENTS.md b/AGENTS.md index e5dd1b01..42da654c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -237,6 +237,21 @@ headscale/ - `policy.go`: Policy storage and retrieval - Schema migrations in `schema.sql` with extensive test data coverage +**CRITICAL DATABASE MIGRATION RULES**: + +1. **NEVER reorder existing migrations** - Migration order is immutable once committed +2. **ONLY add new migrations to the END** of the migrations array +3. **NEVER disable foreign keys** in new migrations - no new migrations should be added to `migrationsRequiringFKDisabled` +4. **Migration ID format**: `YYYYMMDDHHSS-short-description` (timestamp + descriptive suffix) + - Example: `202511131500-add-user-roles` + - The timestamp must be chronologically ordered +5. **New migrations go after the comment** "As of 2025-07-02, no new IDs should be added here" +6. If you need to rename a column that other migrations depend on: + - Accept that the old column name will exist in intermediate migration states + - Update code to work with the new column name + - Let AutoMigrate create the new column if needed + - Do NOT try to rename columns that later migrations reference + **Policy Engine (`hscontrol/policy/`)** - `policy.go`: Core ACL evaluation logic, HuJSON parsing @@ -687,6 +702,326 @@ assert.EventuallyWithT(t, func(c *assert.CollectT) { }, 10*time.Second, 500*time.Millisecond, "mixed operations") ``` +## Tags-as-Identity Architecture + +### Overview + +Headscale implements a **tags-as-identity** model where tags and user ownership are mutually exclusive ways to identify nodes. This is a fundamental architectural principle that affects node registration, ownership, ACL evaluation, and API behavior. + +### Core Principle: Tags XOR User Ownership + +Every node in Headscale is **either** tagged **or** user-owned, never both: + +- **Tagged Nodes**: Ownership is defined by tags (e.g., `tag:server`, `tag:database`) + - Tags are set during registration via tagged PreAuthKey + - Tags are immutable after registration (cannot be changed via API) + - May have `UserID` set for "created by" tracking, but ownership is via tags + - Identified by: `node.IsTagged()` returns `true` + +- **User-Owned Nodes**: Ownership is defined by user assignment + - Registered via OIDC, web auth, or untagged PreAuthKey + - Node belongs to a specific user's namespace + - No tags (empty tags array) + - Identified by: `node.UserID().Valid() && !node.IsTagged()` + +### Critical Implementation Details + +#### Node Identification Methods + +```go +// Primary methods for determining node ownership +node.IsTagged() // Returns true if node has tags OR AuthKey.Tags +node.HasTag(tag) // Returns true if node has specific tag +node.IsUserOwned() // Returns true if UserID set AND not tagged + +// IMPORTANT: UserID can be set on tagged nodes for tracking! +// Always use IsTagged() to determine actual ownership, not just UserID.Valid() +``` + +#### UserID Field Semantics + +**Critical distinction**: `UserID` has different meanings depending on node type: + +- **Tagged nodes**: `UserID` is optional "created by" tracking + - Indicates which user created the tagged PreAuthKey + - Does NOT define ownership (tags define ownership) + - Example: User "alice" creates tagged PreAuthKey with `tag:server`, node gets `UserID=alice.ID` + `Tags=["tag:server"]` + +- **User-owned nodes**: `UserID` defines ownership + - Required field for non-tagged nodes + - Defines which user namespace the node belongs to + - Example: User "bob" registers via OIDC, node gets `UserID=bob.ID` + `Tags=[]` + +#### Mapper Behavior (mapper/tail.go) + +The mapper converts internal nodes to Tailscale protocol format, handling the TaggedDevices special user: + +```go +// From mapper/tail.go:102-116 +User: func() tailcfg.UserID { + // IMPORTANT: Tags-as-identity model + // Tagged nodes ALWAYS use TaggedDevices user, even if UserID is set + if node.IsTagged() { + return tailcfg.UserID(int64(types.TaggedDevices.ID)) + } + // User-owned nodes: use the actual user ID + return tailcfg.UserID(int64(node.UserID().Get())) +}() +``` + +**TaggedDevices constant** (`types.TaggedDevices.ID = 2147455555`): Special user ID for all tagged nodes in MapResponse protocol. + +#### Registration Flow + +**Tagged Node Registration** (via tagged PreAuthKey): + +1. User creates PreAuthKey with tags: `pak.Tags = ["tag:server"]` +2. Node registers with PreAuthKey +3. Node gets: `Tags = ["tag:server"]`, `UserID = user.ID` (optional tracking), `AuthKeyID = pak.ID` +4. `IsTagged()` returns `true` (ownership via tags) +5. MapResponse sends `User = TaggedDevices.ID` + +**User-Owned Node Registration** (via OIDC/web/untagged PreAuthKey): + +1. User authenticates or uses untagged PreAuthKey +2. Node registers +3. Node gets: `Tags = []`, `UserID = user.ID` (required) +4. `IsTagged()` returns `false` (ownership via user) +5. MapResponse sends `User = user.ID` + +#### API Validation (SetTags) + +The SetTags gRPC API enforces tags-as-identity rules: + +```go +// From grpcv1.go:340-347 +// User-owned nodes are nodes with UserID that are NOT tagged +isUserOwned := nodeView.UserID().Valid() && !nodeView.IsTagged() +if isUserOwned && len(request.GetTags()) > 0 { + return error("cannot set tags on user-owned nodes") +} +``` + +**Key validation rules**: + +- ✅ Can call SetTags on tagged nodes (tags already define ownership) +- ❌ Cannot set tags on user-owned nodes (would violate XOR rule) +- ❌ Cannot remove all tags from tagged nodes (would orphan the node) + +#### Database Layer (db/node.go) + +**Tag storage**: Tags are stored in PostgreSQL ARRAY column and SQLite JSON column: + +```sql +-- From schema.sql +tags TEXT[] DEFAULT '{}' NOT NULL, -- PostgreSQL +tags TEXT DEFAULT '[]' NOT NULL, -- SQLite (JSON array) +``` + +**Validation** (`state/tags.go`): + +- `validateNodeOwnership()`: Enforces tags XOR user rule +- `validateAndNormalizeTags()`: Validates tag format (`tag:name`) and uniqueness + +#### Policy Layer + +**Tag Ownership** (policy/v2/policy.go): + +```go +func NodeCanHaveTag(node types.NodeView, tag string) bool { + // Checks if node's IP is in the tagOwnerMap IP set + // This is IP-based authorization, not UserID-based + if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok { + if slices.ContainsFunc(node.IPs(), ips.Contains) { + return true + } + } + return false +} +``` + +**Important**: Tag authorization is based on IP ranges in ACL, not UserID. Tags define identity, ACL authorizes that identity. + +### Testing Tags-as-Identity + +**Unit Tests** (`hscontrol/types/node_tags_test.go`): + +- `TestNodeIsTagged`: Validates IsTagged() for various scenarios +- `TestNodeOwnershipModel`: Tests tags XOR user ownership +- `TestUserTypedID`: Helper method validation + +**API Tests** (`hscontrol/grpcv1_test.go`): + +- `TestSetTags_UserXORTags`: Validates rejection of setting tags on user-owned nodes +- `TestSetTags_TaggedNode`: Validates that tagged nodes (even with UserID) are not rejected + +**Auth Tests** (`hscontrol/auth_test.go:890-928`): + +- Tests node registration with tagged PreAuthKey +- Validates tags are applied during registration + +### Common Pitfalls + +1. **Don't check only `UserID.Valid()` to determine user ownership** + - ❌ Wrong: `if node.UserID().Valid() { /* user-owned */ }` + - ✅ Correct: `if node.UserID().Valid() && !node.IsTagged() { /* user-owned */ }` + +2. **Don't assume tagged nodes never have UserID set** + - Tagged nodes MAY have UserID for "created by" tracking + - Always use `IsTagged()` to determine ownership type + +3. **Don't allow setting tags on user-owned nodes** + - This violates the tags XOR user principle + - Use API validation to prevent this + +4. **Don't forget TaggedDevices in mapper** + - All tagged nodes MUST use `TaggedDevices.ID` in MapResponse + - User ID is only for actual user-owned nodes + +### Migration Considerations + +When nodes transition between ownership models: + +- **No automatic migration**: Tags-as-identity is set at registration and immutable +- **Re-registration required**: To change from user-owned to tagged (or vice versa), node must be deleted and re-registered +- **UserID persistence**: UserID on tagged nodes is informational and not cleared + +### Architecture Benefits + +The tags-as-identity model provides: + +1. **Clear ownership semantics**: No ambiguity about who/what owns a node +2. **ACL simplicity**: Tag-based access control without user conflicts +3. **API safety**: Validation prevents invalid ownership states +4. **Protocol compatibility**: TaggedDevices special user aligns with Tailscale's model + +## Logging Patterns + +### Incremental Log Event Building + +When building log statements with multiple fields, especially with conditional fields, use the **incremental log event pattern** instead of long single-line chains. This improves readability and allows conditional field addition. + +**Pattern:** + +```go +// GOOD: Incremental building with conditional fields +logEvent := log.Debug(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Str("node_key", node.NodeKey.ShortString()) + +if node.User != nil { + logEvent = logEvent.Str("user", node.User.Username()) +} else if node.UserID != nil { + logEvent = logEvent.Uint("user_id", *node.UserID) +} else { + logEvent = logEvent.Str("user", "none") +} + +logEvent.Msg("Registering node") +``` + +**Key rules:** + +1. **Assign chained calls back to the variable**: `logEvent = logEvent.Str(...)` - zerolog methods return a new event, so you must capture the return value +2. **Use for conditional fields**: When fields depend on runtime conditions, build incrementally +3. **Use for long log lines**: When a log line exceeds ~100 characters, split it for readability +4. **Call `.Msg()` at the end**: The final `.Msg()` or `.Msgf()` sends the log event + +**Anti-pattern to avoid:** + +```go +// BAD: Long single-line chains are hard to read and can't have conditional fields +log.Debug().Caller().Str("node", node.Hostname).Str("machine_key", node.MachineKey.ShortString()).Str("node_key", node.NodeKey.ShortString()).Str("user", node.User.Username()).Msg("Registering node") + +// BAD: Forgetting to assign the return value (field is lost!) +logEvent := log.Debug().Str("node", node.Hostname) +logEvent.Str("user", username) // This field is LOST - not assigned back +logEvent.Msg("message") // Only has "node" field +``` + +**When to use this pattern:** + +- Log statements with 4+ fields +- Any log with conditional fields +- Complex logging in loops or error handling +- When you need to add context incrementally + +**Example from codebase** (`hscontrol/db/node.go`): + +```go +logEvent := log.Debug(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Str("node_key", node.NodeKey.ShortString()) + +if node.User != nil { + logEvent = logEvent.Str("user", node.User.Username()) +} else if node.UserID != nil { + logEvent = logEvent.Uint("user_id", *node.UserID) +} else { + logEvent = logEvent.Str("user", "none") +} + +logEvent.Msg("Registering test node") +``` + +### Avoiding Log Helper Functions + +Prefer the incremental log event pattern over creating helper functions that return multiple logging closures. Helper functions like `logPollFunc` create unnecessary indirection and allocate closures. + +**Instead of:** + +```go +// AVOID: Helper function returning closures +func logPollFunc(req tailcfg.MapRequest, node *types.Node) ( + func(string, ...any), // warnf + func(string, ...any), // infof + func(string, ...any), // tracef + func(error, string, ...any), // errf +) { + return func(msg string, a ...any) { + log.Warn(). + Caller(). + Bool("omitPeers", req.OmitPeers). + Bool("stream", req.Stream). + Uint64("node.id", node.ID.Uint64()). + Str("node.name", node.Hostname). + Msgf(msg, a...) + }, + // ... more closures +} +``` + +**Prefer:** + +```go +// BETTER: Build log events inline with shared context +func (m *mapSession) logTrace(msg string) { + log.Trace(). + Caller(). + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname). + Msg(msg) +} + +// Or use incremental building for complex cases +logEvent := log.Trace(). + Caller(). + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname) + +if additionalContext { + logEvent = logEvent.Str("extra", value) +} + +logEvent.Msg("Operation completed") +``` + ## Important Notes - **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting) @@ -697,3 +1032,4 @@ assert.EventuallyWithT(t, func(c *assert.CollectT) { - **Integration Tests**: Require Docker and can consume significant disk space - use headscale-integration-tester agent - **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management - **Quality Assurance**: Always use appropriate specialized agents for testing and validation tasks +- **Tags-as-Identity**: Tags and user ownership are mutually exclusive - always use `IsTagged()` to determine ownership diff --git a/CHANGELOG.md b/CHANGELOG.md index d812bcef..e5a4163b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,10 @@ at creation time. When listing keys, only the prefix is shown (e.g., `hskey-auth-{prefix}-{secret}`. Legacy plaintext keys continue to work for backwards compatibility. +### Tags + +Tags are now implemented following the Tailscale model where tags and user ownership are mutually exclusive. Devices can be either user-owned (authenticated via web/OIDC) or tagged (authenticated via tagged PreAuthKeys). Tagged devices receive their identity from tags rather than users, making them suitable for servers and infrastructure. Applying a tag to a device removes user-based authentication. See the [Tailscale tags documentation](https://tailscale.com/kb/1068/tags) for details on how tags work. + ### Database migration support removed for pre-0.25.0 databases Headscale no longer supports direct upgrades from databases created before @@ -30,6 +34,8 @@ release. ### BREAKING +- **Tags**: The gRPC `SetTags` endpoint now allows converting user-owned nodes to tagged nodes by setting tags. Once a node is tagged, it cannot be converted back to a user-owned node. + - Database migration support removed for pre-0.25.0 databases [#2883](https://github.com/juanfont/headscale/pull/2883) - If you are running a version older than 0.25.0, you must upgrade to 0.25.1 first, then upgrade to this release - See the [upgrade path documentation](https://headscale.net/stable/about/faq/#what-is-the-recommended-update-path-can-i-skip-multiple-versions-while-updating) for detailed guidance diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 1ab22709..5c72ea81 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -233,11 +233,7 @@ func isAuthKey(req tailcfg.RegisterRequest) bool { } func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse { - return &tailcfg.RegisterResponse{ - // TODO(kradalby): Only send for user-owned nodes - // and not tagged nodes when tags is working. - User: node.UserView().TailscaleUser(), - Login: node.UserView().TailscaleLogin(), + resp := &tailcfg.RegisterResponse{ NodeKeyExpired: node.IsExpired(), // Headscale does not implement the concept of machine authorization @@ -245,6 +241,18 @@ func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse { // Revisit this if #2176 gets implemented. MachineAuthorized: true, } + + // For tagged nodes, use the TaggedDevices special user + // For user-owned nodes, include User and Login information from the actual user + if node.IsTagged() { + resp.User = types.TaggedDevices.View().TailscaleUser() + resp.Login = types.TaggedDevices.View().TailscaleLogin() + } else if node.UserView().Valid() { + resp.User = node.UserView().TailscaleUser() + resp.Login = node.UserView().TailscaleLogin() + } + + return resp } func (h *Headscale) waitForFollowup( diff --git a/hscontrol/auth_tags_test.go b/hscontrol/auth_tags_test.go new file mode 100644 index 00000000..4bc32a6a --- /dev/null +++ b/hscontrol/auth_tags_test.go @@ -0,0 +1,535 @@ +package hscontrol + +import ( + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// TestTaggedPreAuthKeyCreatesTaggedNode tests that a PreAuthKey with tags creates +// a tagged node with: +// - Tags from the PreAuthKey +// - UserID tracking who created the key (informational "created by") +// - IsTagged() returns true. +func TestTaggedPreAuthKeyCreatesTaggedNode(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server", "tag:prod"} + + // Create a tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + require.NotEmpty(t, pak.Tags, "PreAuthKey should have tags") + require.ElementsMatch(t, tags, pak.Tags, "PreAuthKey should have specified tags") + + // Register a node using the tagged key + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify the node was created with tags + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertions for tags-as-identity model + assert.True(t, node.IsTagged(), "Node should be tagged") + assert.ElementsMatch(t, tags, node.Tags().AsSlice(), "Node should have tags from PreAuthKey") + assert.True(t, node.UserID().Valid(), "Node should have UserID tracking creator") + assert.Equal(t, user.ID, node.UserID().Get(), "UserID should track PreAuthKey creator") + + // Verify node is identified correctly + assert.True(t, node.IsTagged(), "Tagged node is not user-owned") + assert.True(t, node.HasTag("tag:server"), "Node should have tag:server") + assert.True(t, node.HasTag("tag:prod"), "Node should have tag:prod") + assert.False(t, node.HasTag("tag:other"), "Node should not have tag:other") +} + +// TestReAuthDoesNotReapplyTags tests that when a node re-authenticates using the +// same PreAuthKey, the tags are NOT re-applied. Tags are only set during initial +// authentication. This is critical for the container restart scenario (#2830). +// +// NOTE: This test verifies that re-authentication preserves the node's current tags +// without testing tag modification via SetNodeTags (which requires ACL policy setup). +func TestReAuthDoesNotReapplyTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + initialTags := []string{"tag:server", "tag:dev"} + + // Create a tagged PreAuthKey with reusable=true for re-auth + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, initialTags) + require.NoError(t, err) + + // Initial registration + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify initial tags + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + require.True(t, node.IsTagged()) + require.ElementsMatch(t, initialTags, node.Tags().AsSlice()) + + // Re-authenticate with the SAME PreAuthKey (container restart scenario) + // Key behavior: Tags should NOT be re-applied during re-auth + reAuthReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same key + }, + NodeKey: nodeKey.Public(), // Same node key + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + reAuthResp, err := app.handleRegisterWithAuthKey(reAuthReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, reAuthResp.MachineAuthorized) + + // CRITICAL: Tags should remain unchanged after re-auth + // They should match the original tags, proving they weren't re-applied + nodeAfterReauth, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, nodeAfterReauth.IsTagged(), "Node should still be tagged") + assert.ElementsMatch(t, initialTags, nodeAfterReauth.Tags().AsSlice(), "Tags should remain unchanged on re-auth") + + // Verify only one node was created (no duplicates) + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 1, nodes.Len(), "Should have exactly one node") +} + +// NOTE: TestSetTagsOnUserOwnedNode functionality is covered by gRPC tests in grpcv1_test.go +// which properly handle ACL policy setup. The test verifies that SetTags can convert +// user-owned nodes to tagged nodes while preserving UserID. + +// TestCannotRemoveAllTags tests that attempting to remove all tags from a +// tagged node fails with ErrCannotRemoveAllTags. Once a node is tagged, +// it must always have at least one tag (Tailscale requirement). +func TestCannotRemoveAllTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a tagged node + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify node is tagged + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + require.True(t, node.IsTagged()) + + // Attempt to remove all tags by setting empty array + _, _, err = app.state.SetNodeTags(node.ID(), []string{}) + require.Error(t, err, "Should not be able to remove all tags") + require.ErrorIs(t, err, types.ErrCannotRemoveAllTags, "Error should be ErrCannotRemoveAllTags") + + // Verify node still has original tags + nodeAfter, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, nodeAfter.IsTagged(), "Node should still be tagged") + assert.ElementsMatch(t, tags, nodeAfter.Tags().AsSlice(), "Tags should be unchanged") +} + +// TestUserOwnedNodeCreatedWithUntaggedPreAuthKey tests that using a PreAuthKey +// without tags creates a user-owned node (no tags, UserID is the owner). +func TestUserOwnedNodeCreatedWithUntaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("node-owner") + + // Create an untagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + require.NoError(t, err) + require.Empty(t, pak.Tags, "PreAuthKey should not be tagged") + require.Empty(t, pak.Tags, "PreAuthKey should have no tags") + + // Register a node + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "user-owned-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify node is user-owned + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertions for user-owned node + assert.False(t, node.IsTagged(), "Node should not be tagged") + assert.False(t, node.IsTagged(), "Node should be user-owned (not tagged)") + assert.Empty(t, node.Tags().AsSlice(), "Node should have no tags") + assert.True(t, node.UserID().Valid(), "Node should have UserID") + assert.Equal(t, user.ID, node.UserID().Get(), "UserID should be the PreAuthKey owner") +} + +// TestMultipleNodesWithSameReusableTaggedPreAuthKey tests that a reusable +// PreAuthKey with tags can be used to register multiple nodes, and all nodes +// receive the same tags from the key. +func TestMultipleNodesWithSameReusableTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server", "tag:prod"} + + // Create a REUSABLE tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Register first node + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Register second node with SAME PreAuthKey + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same key + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + // Verify both nodes exist and have the same tags + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found) + + // Both nodes should be tagged with the same tags + assert.True(t, node1.IsTagged(), "First node should be tagged") + assert.True(t, node2.IsTagged(), "Second node should be tagged") + assert.ElementsMatch(t, tags, node1.Tags().AsSlice(), "First node should have PreAuthKey tags") + assert.ElementsMatch(t, tags, node2.Tags().AsSlice(), "Second node should have PreAuthKey tags") + + // Both nodes should track the same creator + assert.Equal(t, user.ID, node1.UserID().Get(), "First node should track creator") + assert.Equal(t, user.ID, node2.UserID().Get(), "Second node should track creator") + + // Verify we have exactly 2 nodes + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 2, nodes.Len(), "Should have exactly two nodes") +} + +// TestNonReusableTaggedPreAuthKey tests that a non-reusable PreAuthKey with tags +// can only be used once. The second attempt should fail. +func TestNonReusableTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a NON-REUSABLE tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Register first node - should succeed + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Verify first node was created with tags + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + assert.ElementsMatch(t, tags, node1.Tags().AsSlice()) + + // Attempt to register second node with SAME non-reusable key - should fail + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same non-reusable key + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.Error(t, err, "Should not be able to reuse non-reusable PreAuthKey") + + // Verify only one node was created + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 1, nodes.Len(), "Should have exactly one node") +} + +// TestExpiredTaggedPreAuthKey tests that an expired PreAuthKey with tags +// cannot be used to register a node. +func TestExpiredTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a PreAuthKey that expires immediately + expiration := time.Now().Add(-1 * time.Hour) // Already expired + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiration, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Attempt to register with expired key + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.Error(t, err, "Should not be able to use expired PreAuthKey") + + // Verify no node was created + _, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + assert.False(t, found, "No node should be created with expired key") +} + +// TestSingleVsMultipleTags tests that PreAuthKeys work correctly with both +// a single tag and multiple tags. +func TestSingleVsMultipleTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + + // Test with single tag + singleTag := []string{"tag:server"} + pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, singleTag) + require.NoError(t, err) + + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "single-tag-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + assert.ElementsMatch(t, singleTag, node1.Tags().AsSlice()) + + // Test with multiple tags + multipleTags := []string{"tag:server", "tag:prod", "tag:database"} + pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, multipleTags) + require.NoError(t, err) + + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak2.Key, + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "multi-tag-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found) + assert.True(t, node2.IsTagged()) + assert.ElementsMatch(t, multipleTags, node2.Tags().AsSlice()) + + // Verify HasTag works for all tags + assert.True(t, node2.HasTag("tag:server")) + assert.True(t, node2.HasTag("tag:prod")) + assert.True(t, node2.HasTag("tag:database")) + assert.False(t, node2.HasTag("tag:other")) +} + +// TestReAuthWithDifferentMachineKey tests the edge case where a node attempts +// to re-authenticate with the same NodeKey but a DIFFERENT MachineKey. +// This scenario should be handled gracefully (currently creates a new node). +func TestReAuthWithDifferentMachineKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a reusable tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + + // Initial registration + machineKey1 := key.NewMachine() + nodeKey := key.NewNode() // Same NodeKey for both attempts + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Verify initial node + node1, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + + // Re-authenticate with DIFFERENT MachineKey but SAME NodeKey + machineKey2 := key.NewMachine() // Different machine key + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), // Same NodeKey + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + // Verify the node still exists and has tags + // Note: Depending on implementation, this might be the same node or a new node + node2, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, node2.IsTagged()) + assert.ElementsMatch(t, tags, node2.Tags().AsSlice()) +} diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index b4ab8f16..23e1f226 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -70,7 +70,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_valid_new_node", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("preauth-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -111,7 +112,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_reusable_multiple_nodes", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("reusable-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -177,7 +179,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_single_use_exhausted", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("single-use-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) if err != nil { return "", err } @@ -264,7 +267,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_ephemeral_node", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("ephemeral-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) if err != nil { return "", err } @@ -370,7 +374,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_logout", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("logout-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -429,7 +434,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_machine_key_mismatch", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("mismatch-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -477,7 +483,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_key_extension_not_allowed", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("extend-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -525,7 +532,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_expired_forces_reauth", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("reauth-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -585,7 +593,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "ephemeral_node_logout_deletion", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("ephemeral-logout-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) if err != nil { return "", err } @@ -767,7 +776,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "empty_hostname", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("empty-hostname-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -805,7 +815,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "nil_hostinfo", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("nil-hostinfo-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -848,7 +859,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("expired-pak-user") expiry := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) if err != nil { return "", err } @@ -880,7 +892,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("tagged-pak-user") tags := []string{"tag:server", "tag:database"} - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, tags) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) if err != nil { return "", err } @@ -926,7 +939,7 @@ func TestAuthenticationFlows(t *testing.T) { user := app.state.CreateUserForTest("reauth-user") // First, register with initial auth key - pak1, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -953,7 +966,7 @@ func TestAuthenticationFlows(t *testing.T) { }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") // Create new auth key for re-authentication - pak2, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -992,7 +1005,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_reauth_interactive_flow", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("interactive-reauth-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1053,7 +1067,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "node_key_rotation_same_machine", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("rotation-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1081,7 +1096,7 @@ func TestAuthenticationFlows(t *testing.T) { }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") // Create new auth key for rotation - pakRotation, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pakRotation, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1129,7 +1144,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "malformed_expiry_zero_time", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("zero-expiry-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1167,7 +1183,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "malformed_hostinfo_invalid_data", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("invalid-hostinfo-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1353,7 +1370,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_usage_count_tracking", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("usage-count-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // Single use + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // Single use if err != nil { return "", err } @@ -1432,7 +1450,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "concurrent_registration_same_node_key", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("concurrent-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1473,7 +1492,8 @@ func TestAuthenticationFlows(t *testing.T) { user := app.state.CreateUserForTest("future-expiry-user") // Auth key expires in the future expiry := time.Now().Add(48 * time.Hour) - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) if err != nil { return "", err } @@ -1517,7 +1537,7 @@ func TestAuthenticationFlows(t *testing.T) { user2 := app.state.CreateUserForTest("user2-context") // Register node with user1's auth key - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1544,7 +1564,7 @@ func TestAuthenticationFlows(t *testing.T) { }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") // Return user2's auth key for re-authentication - pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil) + pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1571,15 +1591,15 @@ func TestAuthenticationFlows(t *testing.T) { // Verify NEW node was created for user2 node2, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2)) require.True(t, found, "new node should exist for user2") - assert.Equal(t, uint(2), node2.UserID(), "new node should belong to user2") + assert.Equal(t, uint(2), node2.UserID().Get(), "new node should belong to user2") user := node2.User() - assert.Equal(t, "user2-context", user.Username(), "new node should show user2 username") + assert.Equal(t, "user2-context", user.Name(), "new node should show user2 username") // Verify original node still exists for user1 node1, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1)) require.True(t, found, "original node should still exist for user1") - assert.Equal(t, uint(1), node1.UserID(), "original node should still belong to user1") + assert.Equal(t, uint(1), node1.UserID().Get(), "original node should still belong to user1") // Verify they are different nodes (different IDs) assert.NotEqual(t, node1.ID(), node2.ID(), "should be different node IDs") @@ -1595,7 +1615,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { // Create user1 and register a node with auth key user1 := app.state.CreateUserForTest("interactive-user-1") - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1645,16 +1666,16 @@ func TestAuthenticationFlows(t *testing.T) { // User1's original node should STILL exist (not transferred) node1, found1 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1)) require.True(t, found1, "user1's original node should still exist") - assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1") + assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1") assert.Equal(t, nodeKey1.Public(), node1.NodeKey(), "user1's node should have original node key") // User2 should have a NEW node created node2, found2 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2)) require.True(t, found2, "user2 should have new node created") - assert.Equal(t, uint(2), node2.UserID(), "user2's node should belong to user2") + assert.Equal(t, uint(2), node2.UserID().Get(), "user2's node should belong to user2") user := node2.User() - assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should show correct username") + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should show correct username") // Both nodes should have the same machine key but different IDs assert.NotEqual(t, node1.ID(), node2.ID(), "should be different nodes (different IDs)") @@ -1720,7 +1741,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "logout_with_exactly_now_expiry", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("exact-now-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1813,7 +1835,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { // First create a node under user1 user1 := app.state.CreateUserForTest("existing-user-1") - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1863,7 +1886,7 @@ func TestAuthenticationFlows(t *testing.T) { // User1's original node with nodeKey1 should STILL exist node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) require.True(t, found1, "user1's original node with nodeKey1 should still exist") - assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1") + assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1") assert.Equal(t, uint64(1), node1.ID().Uint64(), "user1's node should be ID=1") // User2 should have a NEW node with nodeKey2 @@ -1872,7 +1895,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.Equal(t, "existing-node-user2", node2.Hostname(), "hostname should be from new registration") user := node2.User() - assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2") + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2") assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "machine key should be the same") // Verify it's a NEW node, not transferred @@ -2022,7 +2045,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { // Register initial node user := app.state.CreateUserForTest("rotation-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -2072,7 +2096,7 @@ func TestAuthenticationFlows(t *testing.T) { // User1's original node with nodeKey1 should STILL exist oldNode, foundOld := app.state.GetNodeByNodeKey(nodeKey1.Public()) require.True(t, foundOld, "user1's original node with nodeKey1 should still exist") - assert.Equal(t, uint(1), oldNode.UserID(), "user1's node should still belong to user1") + assert.Equal(t, uint(1), oldNode.UserID().Get(), "user1's node should still belong to user1") assert.Equal(t, uint64(1), oldNode.ID().Uint64(), "user1's node should be ID=1") // User2 should have a NEW node with nodeKey2 @@ -2082,7 +2106,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.Equal(t, machineKey1.Public(), newNode.MachineKey()) user := newNode.User() - assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2") + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2") // Verify it's a NEW node, not transferred assert.NotEqual(t, uint64(1), newNode.ID().Uint64(), "should be a NEW node (different ID)") @@ -2333,7 +2357,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.True(t, found, "node should be registered") if found { assert.Equal(t, "pending-node-2", node.Hostname()) - assert.Equal(t, "second-registration-user", node.User().Name) + assert.Equal(t, "second-registration-user", node.User().Name()) } // First registration should still be in cache (not completed) @@ -2593,7 +2617,7 @@ func TestNodeStoreLookup(t *testing.T) { nodeKey := key.NewNode() user := app.state.CreateUserForTest("test-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) require.NoError(t, err) // Register a node @@ -2642,9 +2666,9 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { user2 := app.state.CreateUserForTest("user2") // Create pre-auth keys for both users - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) require.NoError(t, err) - pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil) + pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil) require.NoError(t, err) // Create machine and node keys for 4 nodes (2 per user) @@ -2720,7 +2744,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { t.Logf("All nodes logged out") // Create a new pre-auth key for user1 (reusable for all nodes) - newPak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + newPak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) require.NoError(t, err) // Re-login all nodes using user1's new pre-auth key @@ -2765,7 +2789,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // User1's original nodes should still be owned by user1 registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID)) require.True(t, found, "User1's original node %s should still exist", node.hostname) - require.Equal(t, user1.ID, registeredNode.UserID(), "Node %s should still belong to user1", node.hostname) + require.Equal(t, user1.ID, registeredNode.UserID().Get(), "Node %s should still belong to user1", node.hostname) t.Logf("✓ User1's original node %s (ID=%d) still owned by user1", node.hostname, registeredNode.ID().Uint64()) } @@ -2774,7 +2798,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // User2's original nodes should still be owned by user2 registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user2.ID)) require.True(t, found, "User2's original node %s should still exist", node.hostname) - require.Equal(t, user2.ID, registeredNode.UserID(), "Node %s should still belong to user2", node.hostname) + require.Equal(t, user2.ID, registeredNode.UserID().Get(), "Node %s should still belong to user2", node.hostname) t.Logf("✓ User2's original node %s (ID=%d) still owned by user2", node.hostname, registeredNode.ID().Uint64()) } @@ -2785,7 +2809,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Should be able to find a node with user1 and this machine key (the new one) newNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID)) require.True(t, found, "Should have created new node for user1 with machine key from %s", node.hostname) - require.Equal(t, user1.ID, newNode.UserID(), "New node should belong to user1") + require.Equal(t, user1.ID, newNode.UserID().Get(), "New node should belong to user1") t.Logf("✓ New node created for user1 with machine key from %s (ID=%d)", node.hostname, newNode.ID().Uint64()) } } @@ -2813,7 +2837,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { // Step 1: Register node for user1 via pre-auth key (simulating initial web flow registration) user1 := app.state.CreateUserForTest("user1") - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) require.NoError(t, err) regReq1 := tailcfg.RegisterRequest{ @@ -2834,7 +2858,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { // Verify node exists for user1 user1Node, found := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) require.True(t, found, "Node should exist for user1") - require.Equal(t, user1.ID, user1Node.UserID(), "Node should belong to user1") + require.Equal(t, user1.ID, user1Node.UserID().Get(), "Node should belong to user1") user1NodeID := user1Node.ID() t.Logf("✓ User1 node created with ID: %d", user1NodeID) @@ -2896,7 +2920,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { t.Fatal("User1's node was transferred or deleted - this breaks the integration test!") } - assert.Equal(t, user1.ID, user1NodeAfter.UserID(), "User1's node should still belong to user1") + assert.Equal(t, user1.ID, user1NodeAfter.UserID().Get(), "User1's node should still belong to user1") assert.Equal(t, user1NodeID, user1NodeAfter.ID(), "Should be the same node (same ID)") assert.True(t, user1NodeAfter.IsExpired(), "User1's node should still be expired") t.Logf("✓ User1's original node still exists (ID: %d, expired: %v)", user1NodeAfter.ID(), user1NodeAfter.IsExpired()) @@ -2911,7 +2935,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { t.Fatal("User2 doesn't have a node - registration failed!") } - assert.Equal(t, user2.ID, user2Node.UserID(), "User2's node should belong to user2") + assert.Equal(t, user2.ID, user2Node.UserID().Get(), "User2's node should belong to user2") assert.NotEqual(t, user1NodeID, user2Node.ID(), "Should be a NEW node (different ID), not transfer!") assert.Equal(t, machineKey.Public(), user2Node.MachineKey(), "Should have same machine key") assert.Equal(t, nodeKey2.Public(), user2Node.NodeKey(), "Should have new node key") @@ -2921,7 +2945,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { t.Run("returned_node_is_user2_new_node", func(t *testing.T) { // The node returned from HandleNodeFromAuthPath should be user2's NEW node - assert.Equal(t, user2.ID, node.UserID(), "Returned node should belong to user2") + assert.Equal(t, user2.ID, node.UserID().Get(), "Returned node should belong to user2") assert.NotEqual(t, user1NodeID, node.ID(), "Returned node should be NEW, not transferred from user1") t.Logf("✓ HandleNodeFromAuthPath returned user2's new node (ID: %d)", node.ID()) }) @@ -2949,10 +2973,11 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { user2Nodes := 0 for i := 0; i < allNodesSlice.Len(); i++ { n := allNodesSlice.At(i) - if n.UserID() == user1.ID { + if n.UserID().Get() == user1.ID { user1Nodes++ } - if n.UserID() == user2.ID { + + if n.UserID().Get() == user2.ID { user2Nodes++ } } @@ -3026,7 +3051,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // Create user and single-use pre-auth key user := app.state.CreateUserForTest("test-user") - pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // reusable=false + pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // reusable=false require.NoError(t, err) // Fetch the full pre-auth key to check Reusable field @@ -3117,7 +3142,7 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) { app := createTestApp(t) user := app.state.CreateUserForTest("test-user") - pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) // reusable=true + pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) // reusable=true require.NoError(t, err) // Fetch the full pre-auth key to check Reusable field @@ -3173,7 +3198,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) { user := app.state.CreateUserForTest("test-user") expiry := time.Now().Add(-1 * time.Hour) // Already expired - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil) + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) require.NoError(t, err) machineKey := key.NewMachine() @@ -3306,7 +3331,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing. // Create a SINGLE-USE pre-auth key (reusable=false) // This is the type of key that triggers the bug in issue #2830 - preAuthKeyNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + preAuthKeyNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) // Fetch the full pre-auth key to check Reusable and Used fields diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index b3e6b704..ad1a8a25 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -577,6 +577,21 @@ AND auth_key_id NOT IN ( }, Rollback: func(db *gorm.DB) error { return nil }, }, + { + // Rename forced_tags column to tags in nodes table. + // This must run after migration 202505141324 which creates tables with forced_tags. + ID: "202511131445-node-forced-tags-to-tags", + Migrate: func(tx *gorm.DB) error { + // Rename the column from forced_tags to tags + err := tx.Migrator().RenameColumn(&types.Node{}, "forced_tags", "tags") + if err != nil { + return fmt.Errorf("renaming forced_tags to tags: %w", err) + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 2a30027e..c50ad37c 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -231,8 +231,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) { name string dbPath string wantFunc func(*testing.T, *HSDatabase) - }{ - } + }{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 3ec81c9f..7ba335e8 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -95,7 +95,7 @@ func TestIPAllocatorSequential(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -123,7 +123,7 @@ func TestIPAllocatorSequential(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.2"), IPv6: nap("fd7a:115c:a1e0::2"), }) @@ -309,7 +309,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), }) @@ -334,7 +334,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -359,7 +359,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -383,7 +383,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -407,19 +407,19 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.2"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.3"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.4"), }) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index cfeefb82..bf407bb4 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -196,8 +196,9 @@ func SetTags( tags []string, ) error { if len(tags) == 0 { - // if no tags are provided, we remove all forced tags - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil { + // if no tags are provided, we remove all tags + err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", "[]").Error + if err != nil { return fmt.Errorf("removing tags: %w", err) } @@ -211,7 +212,8 @@ func SetTags( return err } - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil { + err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", string(b)).Error + if err != nil { return fmt.Errorf("updating tags: %w", err) } @@ -349,12 +351,20 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n panic("RegisterNodeForTest can only be called during tests") } - log.Debug(). + logEvent := log.Debug(). Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). - Str("node_key", node.NodeKey.ShortString()). - Str("user", node.User.Username()). - Msg("Registering test node") + Str("node_key", node.NodeKey.ShortString()) + + if node.User != nil { + logEvent = logEvent.Str("user", node.User.Username()) + } else if node.UserID != nil { + logEvent = logEvent.Uint("user_id", *node.UserID) + } else { + logEvent = logEvent.Str("user", "none") + } + + logEvent.Msg("Registering test node") // If the a new node is registered with the same machine key, to the same user, // update the existing node. @@ -642,7 +652,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) } // Create a preauth key for the node - pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := hsdb.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) if err != nil { panic(fmt.Sprintf("failed to create preauth key for test node: %v", err)) } @@ -656,7 +666,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) NodeKey: nodeKey.Public(), DiscoKey: discoKey.Public(), Hostname: nodeName, - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 0efd0e8b..f0fac74c 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -83,7 +83,7 @@ func (s *Suite) TestExpireNode(c *check.C) { user, err := db.CreateUser(types.User{Name: "test"}) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "testnode") @@ -97,7 +97,7 @@ func (s *Suite) TestExpireNode(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), Expiry: &time.Time{}, @@ -124,7 +124,7 @@ func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser(types.User{Name: "test"}) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "testnode") @@ -138,7 +138,7 @@ func (s *Suite) TestSetTags(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -152,7 +152,7 @@ func (s *Suite) TestSetTags(c *check.C) { c.Assert(err, check.IsNil) node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, sTags) + c.Assert(node.Tags, check.DeepEquals, sTags) // assign duplicate tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} @@ -161,17 +161,10 @@ func (s *Suite) TestSetTags(c *check.C) { node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert( - node.ForcedTags, + node.Tags, check.DeepEquals, []string{"tag:bar", "tag:test", "tag:unknown"}, ) - - // test removing tags - err = db.SetTags(node.ID, []string{}) - c.Assert(err, check.IsNil) - node, err = db.getNode(types.UserID(user.ID), "testnode") - c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, []string{}) } func TestHeadscale_generateGivenName(t *testing.T) { @@ -430,7 +423,7 @@ func TestAutoApproveRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.routes, @@ -446,12 +439,12 @@ func TestAutoApproveRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "taggednode", - UserID: taggedUser.ID, + UserID: &taggedUser.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.routes, }, - ForcedTags: []string{"tag:exit"}, + Tags: []string{"tag:exit"}, IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), } @@ -593,10 +586,10 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) - pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) require.NoError(t, err) node := types.Node{ @@ -604,7 +597,7 @@ func TestListEphemeralNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -614,7 +607,7 @@ func TestListEphemeralNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "ephemeral", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pakEph.ID), } @@ -657,7 +650,7 @@ func TestNodeNaming(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -667,7 +660,7 @@ func TestNodeNaming(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -680,7 +673,7 @@ func TestNodeNaming(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "我的电脑", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, } @@ -688,7 +681,7 @@ func TestNodeNaming(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "a", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, } @@ -808,7 +801,7 @@ func TestRenameNodeComprehensive(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -931,7 +924,7 @@ func TestListPeers(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test1", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -941,7 +934,7 @@ func TestListPeers(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test2", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -1016,7 +1009,7 @@ func TestListNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test1", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -1026,7 +1019,7 @@ func TestListNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test2", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 8580c5b4..105495f1 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -15,15 +15,15 @@ import ( ) var ( - ErrPreAuthKeyNotFound = errors.New("AuthKey not found") - ErrPreAuthKeyExpired = errors.New("AuthKey expired") - ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used") + ErrPreAuthKeyNotFound = errors.New("auth-key not found") + ErrPreAuthKeyExpired = errors.New("auth-key expired") + ErrSingleUseAuthKeyHasBeenUsed = errors.New("auth-key has already been used") ErrUserMismatch = errors.New("user mismatch") - ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") + ErrPreAuthKeyACLTagInvalid = errors.New("auth-key tag is invalid") ) func (hsdb *HSDatabase) CreatePreAuthKey( - uid types.UserID, + uid *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, @@ -41,17 +41,40 @@ const ( ) // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. +// The uid parameter can be nil for system-created tagged keys. +// For tagged keys, uid tracks "created by" (who created the key). +// For user-owned keys, uid tracks the node owner. func CreatePreAuthKey( tx *gorm.DB, - uid types.UserID, + uid *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKeyNew, error) { - user, err := GetUserByID(tx, uid) - if err != nil { - return nil, err + // Validate: must be tagged OR user-owned, not neither + if uid == nil && len(aclTags) == 0 { + return nil, ErrPreAuthKeyNotTaggedOrOwned + } + + // If uid != nil && len(aclTags) > 0: + // Both are allowed: UserID tracks "created by", tags define node ownership + // This is valid per the new model + + var ( + user *types.User + userID *uint + ) + + if uid != nil { + var err error + + user, err = GetUserByID(tx, *uid) + if err != nil { + return nil, err + } + + userID = &user.ID } // Remove duplicates and sort for consistency @@ -108,15 +131,15 @@ func CreatePreAuthKey( } key := types.PreAuthKey{ - UserID: user.ID, - User: *user, + UserID: userID, // nil for system-created keys, or "created by" for tagged keys + User: user, // nil for system-created keys Reusable: reusable, Ephemeral: ephemeral, CreatedAt: &now, Expiration: expiration, - Tags: aclTags, - Prefix: prefix, // Store prefix - Hash: hash, // Store hash + Tags: aclTags, // empty for user-owned keys + Prefix: prefix, // Store prefix + Hash: hash, // Store hash } if err := tx.Save(&key).Error; err != nil { @@ -149,14 +172,19 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e } keys := []types.PreAuthKey{} - if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + + err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error + if err != nil { return nil, err } return keys, nil } -var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey") +var ( + ErrPreAuthKeyFailedToParse = errors.New("failed to parse auth-key") + ErrPreAuthKeyNotTaggedOrOwned = errors.New("auth-key must be either tagged or owned by user") +) func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) { var pak types.PreAuthKey diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 9ad8ae42..643b579c 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -24,7 +24,7 @@ func TestCreatePreAuthKey(t *testing.T) { test: func(t *testing.T, db *HSDatabase) { t.Helper() - _, err := db.CreatePreAuthKey(12345, true, false, nil, nil) + _, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil) assert.Error(t, err) }, }, @@ -36,7 +36,7 @@ func TestCreatePreAuthKey(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + key, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) require.NoError(t, err) assert.NotEmpty(t, key.Key) @@ -83,7 +83,7 @@ func TestPreAuthKeyACLTags(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test-tags-1"}) require.NoError(t, err) - _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"badtag"}) assert.Error(t, err) }, }, @@ -98,7 +98,7 @@ func TestPreAuthKeyACLTags(t *testing.T) { expectedTags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate) require.NoError(t, err) listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) @@ -128,13 +128,13 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test8"}) require.NoError(t, err) - key, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"tag:good"}) + key, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:good"}) require.NoError(t, err) node := types.Node{ ID: 0, Hostname: "testest", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(key.ID), } @@ -180,7 +180,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { validateResult: func(t *testing.T, pak *types.PreAuthKey) { t.Helper() - assert.Equal(t, user.ID, pak.UserID) + assert.Equal(t, user.ID, *pak.UserID) assert.NotEmpty(t, pak.Key) // Legacy keys have Key populated assert.Empty(t, pak.Prefix) // Legacy keys have empty Prefix assert.Nil(t, pak.Hash) // Legacy keys have nil Hash @@ -191,7 +191,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { setupKey: func() string { // Create new key via API keyStr, err := db.CreatePreAuthKey( - types.UserID(user.ID), + user.TypedID(), true, false, nil, []string{"tag:test"}, ) require.NoError(t, err) @@ -203,7 +203,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { validateResult: func(t *testing.T, pak *types.PreAuthKey) { t.Helper() - assert.Equal(t, user.ID, pak.UserID) + assert.Equal(t, user.ID, *pak.UserID) assert.Empty(t, pak.Key) // New keys have empty Key assert.NotEmpty(t, pak.Prefix) // New keys have Prefix assert.NotNil(t, pak.Hash) // New keys have Hash @@ -214,7 +214,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { name: "new_key_format_validation", setupKey: func() string { keyStr, err := db.CreatePreAuthKey( - types.UserID(user.ID), + user.TypedID(), true, false, nil, nil, ) require.NoError(t, err) @@ -244,7 +244,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { setupKey: func() string { // Create valid key key, err := db.CreatePreAuthKey( - types.UserID(user.ID), + user.TypedID(), true, false, nil, nil, ) require.NoError(t, err) @@ -415,11 +415,11 @@ func TestMultipleLegacyKeysAllowed(t *testing.T) { assert.Len(t, legacyKeys, 5, "should have created 5 legacy keys") // Now create new bcrypt-based keys - these should have unique prefixes - key1, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + key1, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) require.NoError(t, err) assert.NotEmpty(t, key1.Key) - key2, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + key2, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) require.NoError(t, err) assert.NotEmpty(t, key2.Key) diff --git a/hscontrol/db/schema.sql b/hscontrol/db/schema.sql index 1fb9528e..ef0a2a0e 100644 --- a/hscontrol/db/schema.sql +++ b/hscontrol/db/schema.sql @@ -81,7 +81,7 @@ CREATE TABLE nodes( given_name varchar(63), user_id integer, register_method text, - forced_tags text, + tags text, auth_key_id integer, last_seen datetime, expiry datetime, diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index dff235e4..92c3292d 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -189,7 +189,11 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { // ListNodesByUser gets all the nodes in a given user. func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { nodes := types.Nodes{} - if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil { + + uidPtr := uint(uid) + + err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: &uidPtr}).Find(&nodes).Error + if err != nil { return nil, err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 1ea0772c..a3fd49b3 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -50,7 +50,7 @@ func TestDestroyUserErrors(t *testing.T) { user := db.CreateUserForTest("test") - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) err = db.DestroyUser(types.UserID(user.ID)) @@ -71,13 +71,13 @@ func TestDestroyUserErrors(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) node := types.Node{ ID: 0, Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index c2e9cee7..a0409e4e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -172,7 +172,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( } preAuthKey, err := api.h.state.CreatePreAuthKey( - types.UserID(user.ID), + user.TypedID(), request.GetReusable(), request.GetEphemeral(), &expiration, @@ -341,6 +341,17 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { + // Validate tags not empty - tagged nodes must have at least one tag + if len(request.GetTags()) == 0 { + return &v1.SetTagsResponse{ + Node: nil, + }, status.Error( + codes.InvalidArgument, + "cannot remove all tags from a node - tagged nodes must have at least one tag", + ) + } + + // Validate tag format for _, tag := range request.GetTags() { err := validateTag(tag) if err != nil { @@ -348,6 +359,16 @@ func (api headscaleV1APIServer) SetTags( } } + // User XOR Tags: nodes are either tagged or user-owned, never both. + // Setting tags on a user-owned node converts it to a tagged node. + // Once tagged, a node cannot be converted back to user-owned. + _, found := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !found { + return &v1.SetTagsResponse{ + Node: nil, + }, status.Error(codes.NotFound, "node not found") + } + node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags()) if err != nil { return &v1.SetTagsResponse{ @@ -529,13 +550,19 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N for index, node := range nodes.All() { resp := node.Proto() + // Tags-as-identity: tagged nodes show as TaggedDevices user in API responses + // (UserID may be set internally for "created by" tracking) + if node.IsTagged() { + resp.User = types.TaggedDevices.Proto() + } + var tags []string for _, tag := range node.RequestTags() { if state.NodeCanHaveTag(node, tag) { tags = append(tags, tag) } } - resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...)) + resp.ValidTags = lo.Uniq(append(tags, node.Tags().AsSlice()...)) resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...)) response[index] = resp @@ -780,7 +807,7 @@ func (api headscaleV1APIServer) DebugCreateNode( NodeKey: key.NewNode().Public(), MachineKey: key.NewMachine().Public(), Hostname: request.GetName(), - User: *user, + User: user, Expiry: &time.Time{}, LastSeen: &time.Time{}, diff --git a/hscontrol/grpcv1_test.go b/hscontrol/grpcv1_test.go index 1d87bfe0..8a50dc59 100644 --- a/hscontrol/grpcv1_test.go +++ b/hscontrol/grpcv1_test.go @@ -1,6 +1,17 @@ package hscontrol -import "testing" +import ( + "context" + "testing" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) func Test_validateTag(t *testing.T) { type args struct { @@ -40,3 +51,212 @@ func Test_validateTag(t *testing.T) { }) } } + +// TestSetTags_Conversion tests the conversion of user-owned nodes to tagged nodes. +// The tags-as-identity model allows one-way conversion from user-owned to tagged. +// Tag authorization is checked via the policy manager - unauthorized tags are rejected. +func TestSetTags_Conversion(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and nodes + user := app.state.CreateUserForTest("test-user") + + // Create a pre-auth key WITHOUT tags for user-owned node + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + // Register a user-owned node (via untagged PreAuthKey) + userOwnedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "user-owned-node", + }, + } + _, err = app.handleRegisterWithAuthKey(userOwnedReq, machineKey1.Public()) + require.NoError(t, err) + + // Get the created node + userOwnedNode, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + tests := []struct { + name string + nodeID uint64 + tags []string + wantErr bool + wantCode codes.Code + wantErrMessage string + }{ + { + // Conversion is allowed, but tag authorization fails without tagOwners + name: "reject unauthorized tags on user-owned node", + nodeID: uint64(userOwnedNode.ID()), + tags: []string{"tag:server"}, + wantErr: true, + wantCode: codes.InvalidArgument, + wantErrMessage: "invalid or unauthorized tags", + }, + { + // Conversion is allowed, but tag authorization fails without tagOwners + name: "reject multiple unauthorized tags", + nodeID: uint64(userOwnedNode.ID()), + tags: []string{"tag:server", "tag:database"}, + wantErr: true, + wantCode: codes.InvalidArgument, + wantErrMessage: "invalid or unauthorized tags", + }, + { + name: "reject non-existent node", + nodeID: 99999, + tags: []string{"tag:server"}, + wantErr: true, + wantCode: codes.NotFound, + wantErrMessage: "node not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: tt.nodeID, + Tags: tt.tags, + }) + + if tt.wantErr { + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, tt.wantCode, st.Code()) + assert.Contains(t, st.Message(), tt.wantErrMessage) + assert.Nil(t, resp.GetNode()) + } else { + require.NoError(t, err) + assert.NotNil(t, resp) + assert.NotNil(t, resp.GetNode()) + } + }) + } +} + +// TestSetTags_TaggedNode tests that SetTags correctly identifies tagged nodes +// and doesn't reject them with the "user-owned nodes" error. +// Note: This test doesn't validate ACL tag authorization - that's tested elsewhere. +func TestSetTags_TaggedNode(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and tagged pre-auth key + user := app.state.CreateUserForTest("test-user") + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:initial"}) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Register a tagged node (via tagged PreAuthKey) + taggedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + } + _, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public()) + require.NoError(t, err) + + // Get the created node + taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, taggedNode.IsTagged(), "Node should be tagged") + assert.True(t, taggedNode.UserID().Valid(), "Tagged node should have UserID for tracking") + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + // Test: SetTags should NOT reject tagged nodes with "user-owned" error + // (Even though they have UserID set, IsTagged() identifies them correctly) + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: uint64(taggedNode.ID()), + Tags: []string{"tag:initial"}, // Keep existing tag to avoid ACL validation issues + }) + + // The call should NOT fail with "cannot set tags on user-owned nodes" + if err != nil { + st, ok := status.FromError(err) + require.True(t, ok) + // If error is about unauthorized tags, that's fine - ACL validation is working + // If error is about user-owned nodes, that's the bug we're testing for + assert.NotContains(t, st.Message(), "user-owned nodes", "Should not reject tagged nodes as user-owned") + } else { + // Success is also fine + assert.NotNil(t, resp) + } +} + +// TestSetTags_CannotRemoveAllTags tests that SetTags rejects attempts to remove +// all tags from a tagged node, enforcing Tailscale's requirement that tagged +// nodes must have at least one tag. +func TestSetTags_CannotRemoveAllTags(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and tagged pre-auth key + user := app.state.CreateUserForTest("test-user") + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:server"}) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Register a tagged node + taggedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + } + _, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public()) + require.NoError(t, err) + + // Get the created node + taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, taggedNode.IsTagged()) + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + // Attempt to remove all tags (empty array) + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: uint64(taggedNode.ID()), + Tags: []string{}, // Empty - attempting to remove all tags + }) + + // Should fail with InvalidArgument error + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "cannot remove all tags") + assert.Nil(t, resp.GetNode()) +} diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 372bb557..37778dc0 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -73,15 +73,17 @@ func generateUserProfiles( node types.NodeView, peers views.Slice[types.NodeView], ) []tailcfg.UserProfile { - userMap := make(map[uint]*types.User) + userMap := make(map[uint]*types.UserView) ids := make([]uint, 0, len(userMap)) user := node.User() - userMap[user.ID] = &user - ids = append(ids, user.ID) + userID := user.Model().ID + userMap[userID] = &user + ids = append(ids, userID) for _, peer := range peers.All() { peerUser := peer.User() - userMap[peerUser.ID] = &peerUser - ids = append(ids, peerUser.ID) + peerUserID := peerUser.Model().ID + userMap[peerUserID] = &peerUser + ids = append(ids, peerUserID) } slices.Sort(ids) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index b801f7dd..1bafd135 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -14,6 +14,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/types/ptr" ) var iap = func(ipStr string) *netip.Addr { @@ -50,8 +51,8 @@ func TestDNSConfigMapResponse(t *testing.T) { mach := func(hostname, username string, userid uint) *types.Node { return &types.Node{ Hostname: hostname, - UserID: userid, - User: types.User{ + UserID: ptr.To(userid), + User: &types.User{ Name: username, }, } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 3a518d94..3153f62b 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -83,7 +83,8 @@ func tailNode( tags = append(tags, tag) } } - for _, tag := range node.ForcedTags().All() { + + for _, tag := range node.Tags().All() { tags = append(tags, tag) } tags = lo.Uniq(tags) @@ -99,7 +100,7 @@ func tailNode( Name: hostname, Cap: capVer, - User: tailcfg.UserID(node.UserID()), + User: node.TailscaleUserID(), Key: node.NodeKey(), KeyExpiry: keyExpiry.UTC(), diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 3a3b39d1..9b0765ba 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -15,6 +15,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func TestTailNode(t *testing.T) { @@ -97,14 +98,14 @@ func TestTailNode(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{ + UserID: ptr.To(uint(0)), + User: &types.User{ Name: "mini", }, - ForcedTags: []string{}, - AuthKey: &types.PreAuthKey{}, - LastSeen: &lastSeen, - Expiry: &expire, + Tags: []string{}, + AuthKey: &types.PreAuthKey{}, + LastSeen: &lastSeen, + Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ tsaddr.AllIPv4(), diff --git a/hscontrol/oidc_template_test.go b/hscontrol/oidc_template_test.go index 4c5ecaa7..367451b1 100644 --- a/hscontrol/oidc_template_test.go +++ b/hscontrol/oidc_template_test.go @@ -1,13 +1,10 @@ package hscontrol import ( - "os" - "path/filepath" "testing" "github.com/juanfont/headscale/hscontrol/templates" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestOIDCCallbackTemplate(t *testing.T) { @@ -49,15 +46,6 @@ func TestOIDCCallbackTemplate(t *testing.T) { assert.Contains(t, html, "= len(newIPs) || oldIP != newIPs[i] { - affectedUsers[newNode.User().ID] = struct{}{} + affectedUsers[newNode.User().ID()] = struct{}{} break } } @@ -750,7 +750,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S // Check in new nodes first for _, node := range newNodes.All() { if node.ID() == nodeID { - nodeUserID = node.User().ID + nodeUserID = node.User().ID() found = true break } @@ -760,7 +760,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S if !found { for _, node := range oldNodes.All() { if node.ID() == nodeID { - nodeUserID = node.User().ID + nodeUserID = node.User().ID() found = true break } diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 5b36b79e..94e631e7 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { @@ -19,8 +20,8 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) Hostname: name, IPv4: ap(ipv4), IPv6: ap(ipv6), - User: user, - UserID: user.ID, + User: ptr.To(user), + UserID: ptr.To(user.ID), Hostinfo: hostinfo, } } @@ -456,8 +457,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { Hostname: "test-1-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: users[0], - UserID: users[0].ID, + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -467,9 +468,9 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { Hostname: "test-2-router", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: users[1], - UserID: users[1].ID, - ForcedTags: []string{"tag:node-router"}, + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + Tags: []string{"tag:node-router"}, Hostinfo: &tailcfg.Hostinfo{}, } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 7b4b2b28..0635a557 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -206,7 +206,12 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. continue } - if node.User().ID == user.ID { + // Skip nodes without a user (defensive check for tests) + if !node.User().Valid() { + continue + } + + if node.User().ID() == user.ID { node.AppendToIPSet(&ips) } } @@ -311,8 +316,8 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeV } for _, node := range nodes.All() { - // Check if node has this tag in all tags (ForcedTags + AuthKey.Tags) - if slices.Contains(node.Tags(), string(t)) { + // Check if node has this tag + if node.HasTag(string(t)) { node.AppendToIPSet(&ips) } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index d5a8730a..2d379b4d 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1549,7 +1549,17 @@ func TestResolvePolicy(t *testing.T) { "groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"}, "groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"}, "notme": {Model: gorm.Model{ID: 5}, Name: "notme"}, + "testuser2": {Model: gorm.Model{ID: 6}, Name: "testuser2"}, } + + // Extract users to variables so we can take their addresses + testuser := users["testuser"] + groupuser := users["groupuser"] + groupuser1 := users["groupuser1"] + groupuser2 := users["groupuser2"] + notme := users["notme"] + testuser2 := users["testuser2"] + tests := []struct { name string nodes types.Nodes @@ -1579,29 +1589,27 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], + User: ptr.To(notme), IPv4: ap("100.100.101.1"), }, // Not matching forced tags { - User: users["testuser"], - ForcedTags: []string{"tag:anything"}, + User: ptr.To(testuser), + Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.2"), }, - // not matching pak tag + // not matching because it's tagged (tags copied from AuthKey) { - User: users["testuser"], - AuthKey: &types.PreAuthKey{ - Tags: []string{"alsotagged"}, - }, + User: ptr.To(testuser), + Tags: []string{"alsotagged"}, IPv4: ap("100.100.101.3"), }, { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.103"), }, { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.104"), }, }, @@ -1613,29 +1621,27 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], + User: ptr.To(notme), IPv4: ap("100.100.101.4"), }, // Not matching forced tags { - User: users["groupuser"], - ForcedTags: []string{"tag:anything"}, + User: ptr.To(groupuser), + Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.5"), }, - // not matching pak tag + // not matching because it's tagged (tags copied from AuthKey) { - User: users["groupuser"], - AuthKey: &types.PreAuthKey{ - Tags: []string{"tag:alsotagged"}, - }, + User: ptr.To(groupuser), + Tags: []string{"tag:alsotagged"}, IPv4: ap("100.100.101.6"), }, { - User: users["groupuser"], + User: ptr.To(groupuser), IPv4: ap("100.100.101.203"), }, { - User: users["groupuser"], + User: ptr.To(groupuser), IPv4: ap("100.100.101.204"), }, }, @@ -1653,12 +1659,12 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], + User: ptr.To(notme), IPv4: ap("100.100.101.9"), }, // Not matching forced tags { - ForcedTags: []string{"tag:anything"}, + Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.10"), }, // not matching pak tag @@ -1670,14 +1676,12 @@ func TestResolvePolicy(t *testing.T) { }, // Not matching forced tags { - ForcedTags: []string{"tag:test"}, + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.234"), }, - // not matching pak tag + // matching tag (tags copied from AuthKey during registration) { - AuthKey: &types.PreAuthKey{ - Tags: []string{"tag:test"}, - }, + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.239"), }, }, @@ -1706,11 +1710,11 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(Group("group:testgroup")), nodes: types.Nodes{ { - User: users["groupuser1"], + User: ptr.To(groupuser1), IPv4: ap("100.100.101.203"), }, { - User: users["groupuser2"], + User: ptr.To(groupuser2), IPv4: ap("100.100.101.204"), }, }, @@ -1731,7 +1735,7 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(Username("invaliduser@")), nodes: types.Nodes{ { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.103"), }, }, @@ -1742,7 +1746,7 @@ func TestResolvePolicy(t *testing.T) { toResolve: tp("tag:invalid"), nodes: types.Nodes{ { - ForcedTags: []string{"tag:test"}, + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.234"), }, }, @@ -1763,18 +1767,18 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Node with no tags (should be included) { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.1"), }, // Node with forced tags (should be excluded) { - User: users["testuser"], - ForcedTags: []string{"tag:test"}, + User: ptr.To(testuser), + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1782,7 +1786,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with non-allowed requested tag (should be included) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed"}, }, @@ -1790,7 +1794,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, one allowed (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test", "tag:notallowed"}, }, @@ -1798,7 +1802,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, none allowed (should be included) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed1", "tag:notallowed2"}, }, @@ -1822,18 +1826,18 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Node with no tags (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.1"), }, // Node with forced tag (should be included) { - User: users["testuser"], - ForcedTags: []string{"tag:test"}, + User: ptr.To(testuser), + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be included) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1841,7 +1845,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with non-allowed requested tag (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed"}, }, @@ -1849,7 +1853,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, one allowed (should be included) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test", "tag:notallowed"}, }, @@ -1857,7 +1861,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, none allowed (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed1", "tag:notallowed2"}, }, @@ -1865,8 +1869,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple forced tags (should be included) { - User: users["testuser"], - ForcedTags: []string{"tag:test", "tag:other"}, + User: ptr.To(testuser), + Tags: []string{"tag:test", "tag:other"}, IPv4: ap("100.100.101.7"), }, }, @@ -1886,20 +1890,20 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(AutoGroupSelf), nodes: types.Nodes{ { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.1"), }, { - User: users["testuser2"], + User: ptr.To(testuser2), IPv4: ap("100.100.101.2"), }, { - User: users["testuser"], - ForcedTags: []string{"tag:test"}, + User: ptr.To(testuser), + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.3"), }, { - User: users["testuser2"], + User: ptr.To(testuser2), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1961,23 +1965,23 @@ func TestResolveAutoApprovers(t *testing.T) { nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: users[0], + User: &users[0], }, { IPv4: ap("100.64.0.2"), - User: users[1], + User: &users[1], }, { IPv4: ap("100.64.0.3"), - User: users[2], + User: &users[2], }, { IPv4: ap("100.64.0.4"), - ForcedTags: []string{"tag:testtag"}, + Tags: []string{"tag:testtag"}, }, { IPv4: ap("100.64.0.5"), - ForcedTags: []string{"tag:exittest"}, + Tags: []string{"tag:exittest"}, }, } @@ -2280,15 +2284,15 @@ func TestNodeCanApproveRoute(t *testing.T) { nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: users[0], + User: &users[0], }, { IPv4: ap("100.64.0.2"), - User: users[1], + User: &users[1], }, { IPv4: ap("100.64.0.3"), - User: users[2], + User: &users[2], }, } @@ -2413,15 +2417,15 @@ func TestResolveTagOwners(t *testing.T) { nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: users[0], + User: &users[0], }, { IPv4: ap("100.64.0.2"), - User: users[1], + User: &users[1], }, { IPv4: ap("100.64.0.3"), - User: users[2], + User: &users[2], }, } @@ -2498,15 +2502,15 @@ func TestNodeCanHaveTag(t *testing.T) { nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: users[0], + User: &users[0], }, { IPv4: ap("100.64.0.2"), - User: users[1], + User: &users[1], }, { IPv4: ap("100.64.0.3"), - User: users[2], + User: &users[2], }, } @@ -2580,6 +2584,49 @@ func TestNodeCanHaveTag(t *testing.T) { tag: "tag:test", want: false, }, + { + name: "node-with-unauthorized-tag-different-user", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[2], // user3's node + tag: "tag:prod", + want: false, + }, + { + name: "node-with-multiple-tags-one-unauthorized", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:web"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + }, + }, + node: nodes[0], // user1's node + tag: "tag:database", + want: false, // user1 cannot have tag:database (owned by user2) + }, + { + name: "empty-tagowners-map", + policy: &Policy{ + TagOwners: TagOwners{}, + }, + node: nodes[0], + tag: "tag:test", + want: false, // No one can have tags if tagOwners is empty + }, + { + name: "tag-not-in-tagowners", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[0], + tag: "tag:dev", // This tag is not defined in tagOwners + want: false, + }, } for _, tt := range tests { diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 4324ffba..b0453051 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" "tailscale.com/tailcfg" @@ -42,11 +43,6 @@ type mapSession struct { node *types.Node w http.ResponseWriter - - warnf func(string, ...any) - infof func(string, ...any) - tracef func(string, ...any) - errf func(error, string, ...any) } func (h *Headscale) newMapSession( @@ -55,8 +51,6 @@ func (h *Headscale) newMapSession( w http.ResponseWriter, node *types.Node, ) *mapSession { - warnf, infof, tracef, errf := logPollFunc(req, node) - ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) return &mapSession{ @@ -73,12 +67,6 @@ func (h *Headscale) newMapSession( keepAlive: ka, keepAliveTicker: nil, - - // Loggers - warnf: warnf, - infof: infof, - tracef: tracef, - errf: errf, } } @@ -295,6 +283,7 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error { } data := make([]byte, reservedResponseHeaderSize) + //nolint:gosec // G115: JSON response size will not exceed uint32 max binary.LittleEndian.PutUint32(data, uint32(len(jsonBody))) data = append(data, jsonBody...) @@ -365,45 +354,22 @@ func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcf trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received") } -func logPollFunc( - mapRequest tailcfg.MapRequest, - node *types.Node, -) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) { - return func(msg string, a ...any) { - log.Warn(). - Caller(). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node.name", node.Hostname). - Msgf(msg, a...) - }, - func(msg string, a ...any) { - log.Info(). - Caller(). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node.name", node.Hostname). - Msgf(msg, a...) - }, - func(msg string, a ...any) { - log.Trace(). - Caller(). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node.name", node.Hostname). - Msgf(msg, a...) - }, - func(err error, msg string, a ...any) { - log.Error(). - Caller(). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node.name", node.Hostname). - Err(err). - Msgf(msg, a...) - } +// logf adds common mapSession context to a zerolog event. +func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event { + return event. + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname) +} + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) } + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) } + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) errf(err error, msg string, a ...any) { + m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...) } diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 757591ad..ef6ef50c 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -78,7 +78,7 @@ func (s *State) DebugOverview() string { now := time.Now() for _, node := range allNodes.All() { if node.Valid() { - userName := node.User().Name + userName := node.User().Name() userNodeCounts[userName]++ if node.IsOnline().Valid() && node.IsOnline().Get() { @@ -281,7 +281,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo { for _, node := range allNodes.All() { if node.Valid() { - userName := node.User().Name + userName := node.User().Name() info.Users[userName]++ if node.IsOnline().Valid() && node.IsOnline().Get() { diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index 865d3eb4..99f781d4 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func TestNetInfoFromMapRequest(t *testing.T) { @@ -148,8 +149,8 @@ func createTestNodeSimple(id types.NodeID) *types.Node { node := &types.Node{ ID: id, Hostname: "test-node", - UserID: uint(id), - User: user, + UserID: ptr.To(uint(id)), + User: &user, MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), IPv4: &netip.Addr{}, diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index adcc410a..241d2f46 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -408,7 +408,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S // Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps for _, n := range nodes { nodeView := n.View() - userID := types.UserID(n.UserID) + userID := n.TypedUserID() newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView) newSnap.nodesByNodeKey[n.NodeKey] = nodeView @@ -515,7 +515,7 @@ func (s *NodeStore) DebugString() string { if len(nodes) > 0 { userName := "unknown" if len(nodes) > 0 && nodes[0].Valid() { - userName = nodes[0].User().Name + userName = nodes[0].User().Name() } sb.WriteString(fmt.Sprintf(" - User %d (%s): %d nodes\n", userID, userName, len(nodes))) } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 788721b9..82f1a255 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func TestSnapshotFromNodes(t *testing.T) { @@ -173,8 +174,8 @@ func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) DiscoKey: discoKey.Public(), Hostname: hostname, GivenName: hostname, - UserID: userID, - User: types.User{ + UserID: ptr.To(userID), + User: &types.User{ Name: username, DisplayName: username, }, @@ -627,7 +628,7 @@ func TestNodeStoreOperations(t *testing.T) { go func() { resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { - n.ForcedTags = []string{"tag1", "tag2"} + n.Tags = []string{"tag1", "tag2"} }) close(done3) }() @@ -648,24 +649,24 @@ func TestNodeStoreOperations(t *testing.T) { // resultNode1 (from hostname update) should also have the givenname and tags changes assert.Equal(t, "multi-update-hostname", resultNode1.Hostname()) assert.Equal(t, "multi-update-givenname", resultNode1.GivenName()) - assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.Tags().AsSlice()) // resultNode2 (from givenname update) should also have the hostname and tags changes assert.Equal(t, "multi-update-hostname", resultNode2.Hostname()) assert.Equal(t, "multi-update-givenname", resultNode2.GivenName()) - assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.Tags().AsSlice()) // resultNode3 (from tags update) should also have the hostname and givenname changes assert.Equal(t, "multi-update-hostname", resultNode3.Hostname()) assert.Equal(t, "multi-update-givenname", resultNode3.GivenName()) - assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.Tags().AsSlice()) // Verify the snapshot also has all changes snapshot := store.data.Load() finalNode := snapshot.nodesByID[1] assert.Equal(t, "multi-update-hostname", finalNode.Hostname) assert.Equal(t, "multi-update-givenname", finalNode.GivenName) - assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags) + assert.Equal(t, []string{"tag1", "tag2"}, finalNode.Tags) }, }, }, @@ -687,7 +688,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode, ok := store.UpdateNode(1, func(n *types.Node) { n.Hostname = "db-save-hostname" n.GivenName = "db-save-given" - n.ForcedTags = []string{"db-tag1", "db-tag2"} + n.Tags = []string{"db-tag1", "db-tag2"} }) assert.True(t, ok, "UpdateNode should succeed") @@ -696,21 +697,21 @@ func TestNodeStoreOperations(t *testing.T) { // Verify the returned node has all expected values assert.Equal(t, "db-save-hostname", resultNode.Hostname()) assert.Equal(t, "db-save-given", resultNode.GivenName()) - assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice()) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.Tags().AsSlice()) // Convert to struct as would be done for database save nodePtr := resultNode.AsStruct() assert.NotNil(t, nodePtr) assert.Equal(t, "db-save-hostname", nodePtr.Hostname) assert.Equal(t, "db-save-given", nodePtr.GivenName) - assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.Tags) // Verify the snapshot also reflects the same state snapshot := store.data.Load() storedNode := snapshot.nodesByID[1] assert.Equal(t, "db-save-hostname", storedNode.Hostname) assert.Equal(t, "db-save-given", storedNode.GivenName) - assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.Tags) }, }, { @@ -742,7 +743,7 @@ func TestNodeStoreOperations(t *testing.T) { go func() { result3, ok3 = store.UpdateNode(1, func(n *types.Node) { - n.ForcedTags = []string{"concurrent-tag"} + n.Tags = []string{"concurrent-tag"} }) close(done3) }() @@ -767,22 +768,22 @@ func TestNodeStoreOperations(t *testing.T) { // All should have the complete final state assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname) assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName) - assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.Tags) assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname) assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName) - assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.Tags) assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname) assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName) - assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.Tags) // Verify consistency with stored state snapshot := store.data.Load() storedNode := snapshot.nodesByID[1] assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname) assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName) - assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags) + assert.Equal(t, nodePtr1.Tags, storedNode.Tags) }, }, { @@ -855,8 +856,8 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { Hostname: hostname, MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), - UserID: 1, - User: types.User{ + UserID: ptr.To(uint(1)), + User: &types.User{ Name: "concurrent-test-user", }, } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 68aefbc4..149bae4d 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -53,6 +53,9 @@ const ( // ErrUnsupportedPolicyMode is returned for invalid policy modes. Valid modes are "file" and "db". var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode") +// ErrNodeNotFound is returned when a node cannot be found by its ID. +var ErrNodeNotFound = errors.New("node not found") + // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { @@ -651,13 +654,36 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node return s.persistNodeToDB(n) } -// SetNodeTags assigns tags to a node for use in access control policies. +// SetNodeTags assigns tags to a node, making it a "tagged node". +// Once a node is tagged, it cannot be un-tagged (only tags can be changed). +// The UserID is preserved as "created by" information. func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) { + // CANNOT REMOVE ALL TAGS + if len(tags) == 0 { + return types.NodeView{}, change.EmptySet, types.ErrCannotRemoveAllTags + } + + // Get node for validation + existingNode, exists := s.nodeStore.GetNode(nodeID) + if !exists { + return types.NodeView{}, change.EmptySet, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID) + } + + // Validate tags against policy + validatedTags, err := s.validateAndNormalizeTags(existingNode.AsStruct(), tags) + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + + // Log the operation + logTagOperation(existingNode, validatedTags) + // Update NodeStore before database to ensure consistency. The NodeStore update is // blocking and will be the source of truth for the batcher. The database update must // make the exact same change. n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { - node.ForcedTags = tags + node.Tags = validatedTags + // UserID is preserved as "created by" - do NOT set to nil }) if !ok { @@ -927,7 +953,8 @@ func (s *State) DestroyAPIKey(key types.APIKey) error { } // CreatePreAuthKey generates a new pre-authentication key for a user. -func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) { +// The userID parameter is now optional (can be nil) for system-created tagged keys. +func (s *State) CreatePreAuthKey(userID *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) { return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags) } @@ -1063,8 +1090,6 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro // Prepare the node for registration nodeToRegister := types.Node{ Hostname: params.Hostname, - UserID: params.User.ID, - User: params.User, MachineKey: params.MachineKey, NodeKey: params.NodeKey, DiscoKey: params.DiscoKey, @@ -1075,11 +1100,38 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro Expiry: params.Expiry, } - // Pre-auth key specific fields + // Assign ownership based on PreAuthKey if params.PreAuthKey != nil { - nodeToRegister.ForcedTags = params.PreAuthKey.Proto().GetAclTags() + if params.PreAuthKey.IsTagged() { + // TAGGED NODE + // Tags from PreAuthKey are assigned ONLY during initial authentication + nodeToRegister.Tags = params.PreAuthKey.Proto().GetAclTags() + + // Set UserID to track "created by" (who created the PreAuthKey) + if params.PreAuthKey.UserID != nil { + nodeToRegister.UserID = params.PreAuthKey.UserID + nodeToRegister.User = params.PreAuthKey.User + } + // If PreAuthKey.UserID is nil, the node is "orphaned" (system-created) + } else { + // USER-OWNED NODE + nodeToRegister.UserID = ¶ms.PreAuthKey.User.ID + nodeToRegister.User = params.PreAuthKey.User + nodeToRegister.Tags = nil + } nodeToRegister.AuthKey = params.PreAuthKey nodeToRegister.AuthKeyID = ¶ms.PreAuthKey.ID + } else { + // Non-PreAuthKey registration (OIDC, CLI) - always user-owned + nodeToRegister.UserID = ¶ms.User.ID + nodeToRegister.User = ¶ms.User + nodeToRegister.Tags = nil + } + + // Validate before saving + err := validateNodeOwnership(&nodeToRegister) + if err != nil { + return types.NodeView{}, err } // Allocate new IPs @@ -1156,7 +1208,7 @@ func (s *State) HandleNodeFromAuthPath( logHostinfoValidation( regEntry.Node.MachineKey.ShortString(), regEntry.Node.NodeKey.String(), - user.Username(), + user.Name, hostname, regEntry.Node.Hostinfo, ) @@ -1171,7 +1223,7 @@ func (s *State) HandleNodeFromAuthPath( log.Debug(). Caller(). Str("registration_id", registrationID.String()). - Str("user.name", user.Username()). + Str("user.name", user.Name). Str("registrationMethod", registrationMethod). Str("node.name", existingNodeSameUser.Hostname()). Uint64("node.id", existingNodeSameUser.ID().Uint64()). @@ -1233,7 +1285,7 @@ func (s *State) HandleNodeFromAuthPath( // Check if node exists with this machine key for a different user (for netinfo preservation) existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(regEntry.Node.MachineKey) - if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != user.ID { + if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != user.ID { // Node exists but belongs to a different user // Create a NEW node for the new user (do not transfer) // This allows the same machine to have separate node identities per user @@ -1243,8 +1295,8 @@ func (s *State) HandleNodeFromAuthPath( Str("existing.node.name", existingNodeAnyUser.Hostname()). Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). Str("machine.key", regEntry.Node.MachineKey.ShortString()). - Str("old.user", oldUser.Username()). - Str("new.user", user.Username()). + Str("old.user", oldUser.Name()). + Str("new.user", user.Name). Str("method", registrationMethod). Msg("Creating new node for different user (same machine key exists for another user)") } @@ -1253,7 +1305,7 @@ func (s *State) HandleNodeFromAuthPath( log.Debug(). Caller(). Str("registration_id", registrationID.String()). - Str("user.name", user.Username()). + Str("user.name", user.Name). Str("registrationMethod", registrationMethod). Str("expiresAt", fmt.Sprintf("%v", expiry)). Msg("Registering new node from auth callback") @@ -1416,8 +1468,11 @@ func (s *State) HandleNodeFromPreAuthKey( node.RegisterMethod = util.RegisterMethodAuthKey - // TODO(kradalby): This might need a rework as part of #2417 - node.ForcedTags = pak.Proto().GetAclTags() + // CRITICAL: Tags from PreAuthKey are ONLY applied during initial authentication + // On re-registration, we MUST NOT change tags or node ownership + // The node keeps whatever tags/user ownership it already has + // + // Only update AuthKey reference node.AuthKey = pak node.AuthKeyID = &pak.ID node.IsOnline = ptr.To(false) @@ -1467,7 +1522,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Check if node exists with this machine key for a different user existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) - if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != pak.User.ID { + if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != pak.User.ID { // Node exists but belongs to a different user // Create a NEW node for the new user (do not transfer) // This allows the same machine to have separate node identities per user @@ -1477,7 +1532,7 @@ func (s *State) HandleNodeFromPreAuthKey( Str("existing.node.name", existingNodeAnyUser.Hostname()). Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). Str("machine.key", machineKey.ShortString()). - Str("old.user", oldUser.Username()). + Str("old.user", oldUser.Name()). Str("new.user", pak.User.Username()). Msg("Creating new node for different user (same machine key exists for another user)") } @@ -1488,7 +1543,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Create and save new node var err error finalNode, err = s.createAndSaveNewNode(newNodeParams{ - User: pak.User, + User: *pak.User, MachineKey: machineKey, NodeKey: regReq.NodeKey, DiscoKey: key.DiscoPublic{}, // DiscoKey not available in RegisterRequest diff --git a/hscontrol/state/tags.go b/hscontrol/state/tags.go new file mode 100644 index 00000000..c1dd3127 --- /dev/null +++ b/hscontrol/state/tags.go @@ -0,0 +1,107 @@ +package state + +import ( + "errors" + "fmt" + "slices" + "strings" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" +) + +var ( + // ErrNodeMarkedTaggedButHasNoTags is returned when a node is marked as tagged but has no tags. + ErrNodeMarkedTaggedButHasNoTags = errors.New("node marked as tagged but has no tags") + + // ErrNodeHasNeitherUserNorTags is returned when a node has neither a user nor tags. + ErrNodeHasNeitherUserNorTags = errors.New("node has neither user nor tags - must be owned by user or tagged") + + // ErrInvalidOrUnauthorizedTags is returned when tags are invalid or unauthorized. + ErrInvalidOrUnauthorizedTags = errors.New("invalid or unauthorized tags") +) + +// validateNodeOwnership ensures proper node ownership model. +// A node must be EITHER user-owned OR tagged (mutually exclusive by behavior). +// Tagged nodes CAN have a UserID for "created by" tracking, but the tag is the owner. +func validateNodeOwnership(node *types.Node) error { + isTagged := node.IsTagged() + + // Tagged nodes: Must have tags, UserID is optional (just "created by") + if isTagged { + if len(node.Tags) == 0 { + return fmt.Errorf("%w: %q", ErrNodeMarkedTaggedButHasNoTags, node.Hostname) + } + // UserID can be set (created by) or nil (orphaned), both valid for tagged nodes + return nil + } + + // User-owned nodes: Must have UserID, must NOT have tags + if node.UserID == nil { + return fmt.Errorf("%w: %q", ErrNodeHasNeitherUserNorTags, node.Hostname) + } + + return nil +} + +// validateAndNormalizeTags validates tags against policy and normalizes them. +// Returns validated and normalized tags, or an error if validation fails. +func (s *State) validateAndNormalizeTags(node *types.Node, requestedTags []string) ([]string, error) { + if len(requestedTags) == 0 { + return nil, nil + } + + var ( + validTags []string + invalidTags []string + ) + + for _, tag := range requestedTags { + // Validate format + if !strings.HasPrefix(tag, "tag:") { + invalidTags = append(invalidTags, tag) + continue + } + + // Validate against policy + nodeView := node.View() + if s.polMan.NodeCanHaveTag(nodeView, tag) { + validTags = append(validTags, tag) + } else { + invalidTags = append(invalidTags, tag) + } + } + + if len(invalidTags) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidOrUnauthorizedTags, invalidTags) + } + + // Normalize: sort and deduplicate + slices.Sort(validTags) + + return slices.Compact(validTags), nil +} + +// logTagOperation logs tag assignment operations for audit purposes. +func logTagOperation(existingNode types.NodeView, newTags []string) { + if existingNode.IsTagged() { + log.Info(). + Uint64("node.id", existingNode.ID().Uint64()). + Str("node.name", existingNode.Hostname()). + Strs("old.tags", existingNode.Tags().AsSlice()). + Strs("new.tags", newTags). + Msg("Updating tags on already-tagged node") + } else { + var userID uint + if existingNode.UserID().Valid() { + userID = existingNode.UserID().Get() + } + + log.Info(). + Uint64("node.id", existingNode.ID().Uint64()). + Str("node.name", existingNode.Hostname()). + Uint("created.by.user", userID). + Strs("new.tags", newTags). + Msg("Converting user-owned node to tagged node (irreversible)") + } +} diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 05eb8a35..f8264392 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -6,7 +6,6 @@ import ( "net/netip" "regexp" "slices" - "sort" "strconv" "strings" "time" @@ -28,6 +27,7 @@ var ( ErrHostnameTooLong = errors.New("hostname too long, cannot except 255 ASCII chars") ErrNodeHasNoGivenName = errors.New("node has no given name") ErrNodeUserHasNoName = errors.New("node user has no name") + ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node") invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") ) @@ -97,16 +97,21 @@ type Node struct { // GivenName is the name used in all DNS related // parts of headscale. GivenName string `gorm:"type:varchar(63);unique_index"` - UserID uint - User User `gorm:"constraint:OnDelete:CASCADE;"` + + // UserID is set for ALL nodes (tagged and user-owned) to track "created by". + // For tagged nodes, this is informational only - the tag is the owner. + // For user-owned nodes, this identifies the owner. + // Only nil for orphaned nodes (should not happen in normal operation). + UserID *uint + User *User `gorm:"constraint:OnDelete:CASCADE;"` RegisterMethod string - // ForcedTags are tags set by CLI/API. It is not considered - // the source of truth, but is one of the sources from - // which a tag might originate. - // ForcedTags are _always_ applied to the node. - ForcedTags []string `gorm:"column:forced_tags;serializer:json"` + // Tags is the definitive owner for tagged nodes. + // When non-empty, the node is "tagged" and tags define its identity. + // Empty for user-owned nodes. + // Tags cannot be removed once set (one-way transition). + Tags []string `gorm:"column:tags;serializer:json"` // When a node has been created with a PreAuthKey, we need to // prevent the preauthkey from being deleted before the node. @@ -196,55 +201,32 @@ func (node *Node) HasIP(i netip.Addr) bool { return false } -// IsTagged reports if a device is tagged -// and therefore should not be treated as a -// user owned device. -// Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys). +// IsTagged reports if a device is tagged and therefore should not be treated +// as a user-owned device. +// When a node has tags, the tags define its identity (not the user). func (node *Node) IsTagged() bool { - if len(node.ForcedTags) > 0 { - return true - } + return len(node.Tags) > 0 +} - if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 { - return true - } - - if node.Hostinfo == nil { - return false - } - - // TODO(kradalby): Figure out how tagging should work - // and hostinfo.requestedtags. - // Do this in other work. - - return false +// IsUserOwned returns true if node is owned by a user (not tagged). +// Tagged nodes may have a UserID for "created by" tracking, but the tag is the owner. +func (node *Node) IsUserOwned() bool { + return !node.IsTagged() } // HasTag reports if a node has a given tag. -// Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys). func (node *Node) HasTag(tag string) bool { - return slices.Contains(node.Tags(), tag) + return slices.Contains(node.Tags, tag) } -func (node *Node) Tags() []string { - var tags []string - - if node.AuthKey != nil { - tags = append(tags, node.AuthKey.Tags...) +// TypedUserID returns the UserID as a typed UserID type. +// Returns 0 if UserID is nil. +func (node *Node) TypedUserID() UserID { + if node.UserID == nil { + return 0 } - // TODO(kradalby): Figure out how tagging should work - // and hostinfo.requestedtags. - // Do this in other work. - // #2417 - - tags = append(tags, node.ForcedTags...) - sort.Strings(tags) - tags = slices.Compact(tags) - - return tags + return UserID(*node.UserID) } func (node *Node) RequestTags() []string { @@ -389,8 +371,8 @@ func (node *Node) Proto() *v1.Node { IpAddresses: node.IPsAsString(), Name: node.Hostname, GivenName: node.GivenName, - User: node.User.Proto(), - ForcedTags: node.ForcedTags, + User: nil, // Will be set below based on node type + ForcedTags: node.Tags, Online: node.IsOnline != nil && *node.IsOnline, // Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has @@ -404,6 +386,13 @@ func (node *Node) Proto() *v1.Node { CreatedAt: timestamppb.New(node.CreatedAt), } + // Set User field based on node ownership + // Note: User will be set to TaggedDevices in the gRPC layer (grpcv1.go) + // for proper MapResponse formatting + if node.User != nil { + nodeProto.User = node.User.Proto() + } + if node.AuthKey != nil { nodeProto.PreAuthKey = node.AuthKey.Proto() } @@ -701,8 +690,20 @@ func (nodes Nodes) DebugString() string { func (node Node) DebugString() string { var sb strings.Builder fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID) - fmt.Fprintf(&sb, "\tUser: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) - fmt.Fprintf(&sb, "\tTags: %v\n", node.Tags()) + + // Show ownership status + if node.IsTagged() { + fmt.Fprintf(&sb, "\tTagged: %v\n", node.Tags) + + if node.User != nil { + fmt.Fprintf(&sb, "\tCreated by: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) + } + } else if node.User != nil { + fmt.Fprintf(&sb, "\tUser-owned: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) + } else { + fmt.Fprintf(&sb, "\tOrphaned: no user or tags\n") + } + fmt.Fprintf(&sb, "\tIPs: %v\n", node.IPs()) fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes) fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) @@ -714,8 +715,7 @@ func (node Node) DebugString() string { } func (v NodeView) UserView() UserView { - u := v.User() - return u.View() + return v.User() } func (v NodeView) IPs() []netip.Addr { @@ -790,13 +790,6 @@ func (v NodeView) RequestTagsSlice() views.Slice[string] { return v.Hostinfo().RequestTags() } -func (v NodeView) Tags() []string { - if !v.Valid() { - return nil - } - return v.ж.Tags() -} - // IsTagged reports if a device is tagged // and therefore should not be treated as a // user owned device. @@ -893,6 +886,32 @@ func (v NodeView) HasTag(tag string) bool { return v.ж.HasTag(tag) } +// TypedUserID returns the UserID as a typed UserID type. +// Returns 0 if UserID is nil or node is invalid. +func (v NodeView) TypedUserID() UserID { + if !v.Valid() { + return 0 + } + + return v.ж.TypedUserID() +} + +// TailscaleUserID returns the user ID to use in Tailscale protocol. +// Tagged nodes always return TaggedDevices.ID, user-owned nodes return their actual UserID. +func (v NodeView) TailscaleUserID() tailcfg.UserID { + if !v.Valid() { + return 0 + } + + if v.IsTagged() { + //nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64 + return tailcfg.UserID(int64(TaggedDevices.ID)) + } + + //nolint:gosec // G115: UserID values are within int64 range + return tailcfg.UserID(int64(v.UserID().Get())) +} + // Prefixes returns the node IPs as netip.Prefix. func (v NodeView) Prefixes() []netip.Prefix { if !v.Valid() { diff --git a/hscontrol/types/node_tags_test.go b/hscontrol/types/node_tags_test.go new file mode 100644 index 00000000..72598b3c --- /dev/null +++ b/hscontrol/types/node_tags_test.go @@ -0,0 +1,295 @@ +package types + +import ( + "testing" + + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + "tailscale.com/types/ptr" +) + +// TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged. +func TestNodeIsTagged(t *testing.T) { + tests := []struct { + name string + node Node + want bool + }{ + { + name: "node with tags - is tagged", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + want: true, + }, + { + name: "node with single tag - is tagged", + node: Node{ + Tags: []string{"tag:web"}, + }, + want: true, + }, + { + name: "node with no tags - not tagged", + node: Node{ + Tags: []string{}, + }, + want: false, + }, + { + name: "node with nil tags - not tagged", + node: Node{ + Tags: nil, + }, + want: false, + }, + { + // Tags should be copied from AuthKey during registration, so a node + // with only AuthKey.Tags and no Tags would be invalid in practice. + // IsTagged() only checks node.Tags, not AuthKey.Tags. + name: "node registered with tagged authkey only - not tagged (tags should be copied)", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + want: false, + }, + { + name: "node with both tags and authkey tags - is tagged", + node: Node{ + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + want: true, + }, + { + name: "node with user and no tags - not tagged", + node: Node{ + UserID: ptr.To(uint(42)), + Tags: []string{}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.IsTagged() + assert.Equal(t, tt.want, got, "IsTagged() returned unexpected value") + }) + } +} + +// TestNodeViewIsTagged tests the IsTagged() method on NodeView. +func TestNodeViewIsTagged(t *testing.T) { + tests := []struct { + name string + node Node + want bool + }{ + { + name: "tagged node via Tags field", + node: Node{ + Tags: []string{"tag:server"}, + }, + want: true, + }, + { + // Tags should be copied from AuthKey during registration, so a node + // with only AuthKey.Tags and no Tags would be invalid in practice. + name: "node with only AuthKey tags - not tagged (tags should be copied)", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:web"}, + }, + }, + want: false, // IsTagged() only checks node.Tags + }, + { + name: "user-owned node", + node: Node{ + UserID: ptr.To(uint(1)), + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + view := tt.node.View() + got := view.IsTagged() + assert.Equal(t, tt.want, got, "NodeView.IsTagged() returned unexpected value") + }) + } +} + +// TestNodeHasTag tests the HasTag() method for checking specific tag membership. +func TestNodeHasTag(t *testing.T) { + tests := []struct { + name string + node Node + tag string + want bool + }{ + { + name: "node has the tag", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + tag: "tag:server", + want: true, + }, + { + name: "node does not have the tag", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + tag: "tag:web", + want: false, + }, + { + // Tags should be copied from AuthKey during registration + // HasTag() only checks node.Tags, not AuthKey.Tags + name: "node has tag only in authkey - returns false", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + tag: "tag:database", + want: false, + }, + { + // node.Tags is what matters, not AuthKey.Tags + name: "node has tag in Tags but not in AuthKey", + node: Node{ + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + tag: "tag:server", + want: true, + }, + { + name: "invalid tag format still returns false", + node: Node{ + Tags: []string{"tag:server"}, + }, + tag: "invalid-tag", + want: false, + }, + { + name: "empty tag returns false", + node: Node{ + Tags: []string{"tag:server"}, + }, + tag: "", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.HasTag(tt.tag) + assert.Equal(t, tt.want, got, "HasTag() returned unexpected value") + }) + } +} + +// TestNodeTagsImmutableAfterRegistration tests that tags can only be set during registration. +func TestNodeTagsImmutableAfterRegistration(t *testing.T) { + // Test that a node registered with tags keeps them + taggedNode := Node{ + ID: 1, + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:server"}, + }, + RegisterMethod: util.RegisterMethodAuthKey, + } + + // Node should be tagged + assert.True(t, taggedNode.IsTagged(), "Node registered with tags should be tagged") + + // Node should have the tag + has := taggedNode.HasTag("tag:server") + assert.True(t, has, "Node should have the tag it was registered with") + + // Test that a user-owned node is not tagged + userNode := Node{ + ID: 2, + UserID: ptr.To(uint(42)), + Tags: []string{}, + RegisterMethod: util.RegisterMethodOIDC, + } + + assert.False(t, userNode.IsTagged(), "User-owned node should not be tagged") +} + +// TestNodeOwnershipModel tests the tags-as-identity model. +func TestNodeOwnershipModel(t *testing.T) { + tests := []struct { + name string + node Node + wantIsTagged bool + description string + }{ + { + name: "tagged node has tags, UserID is informational", + node: Node{ + ID: 1, + UserID: ptr.To(uint(5)), // "created by" user 5 + Tags: []string{"tag:server"}, + }, + wantIsTagged: true, + description: "Tagged nodes may have UserID set for tracking, but ownership is defined by tags", + }, + { + name: "user-owned node has no tags", + node: Node{ + ID: 2, + UserID: ptr.To(uint(5)), + Tags: []string{}, + }, + wantIsTagged: false, + description: "User-owned nodes are owned by the user, not by tags", + }, + { + // Tags should be copied from AuthKey to Node during registration + // IsTagged() only checks node.Tags, not AuthKey.Tags + name: "node with only authkey tags - not tagged (tags should be copied)", + node: Node{ + ID: 3, + UserID: ptr.To(uint(5)), // "created by" user 5 + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + wantIsTagged: false, + description: "IsTagged() only checks node.Tags; AuthKey.Tags should be copied during registration", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.IsTagged() + assert.Equal(t, tt.wantIsTagged, got, tt.description) + }) + } +} + +// TestUserTypedID tests the TypedID() helper method. +func TestUserTypedID(t *testing.T) { + user := User{ + Model: gorm.Model{ID: 42}, + } + + typedID := user.TypedID() + assert.NotNil(t, typedID, "TypedID() should return non-nil pointer") + assert.Equal(t, UserID(42), *typedID, "TypedID() should return correct UserID value") +} diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index c992219e..9518833f 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -139,7 +139,7 @@ func TestNodeFQDN(t *testing.T) { name: "no-dnsconfig-with-username", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, @@ -150,7 +150,7 @@ func TestNodeFQDN(t *testing.T) { name: "all-set", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, @@ -160,7 +160,7 @@ func TestNodeFQDN(t *testing.T) { { name: "no-given-name", node: Node{ - User: User{ + User: &User{ Name: "user", }, }, @@ -179,7 +179,7 @@ func TestNodeFQDN(t *testing.T) { name: "no-dnsconfig", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 1081f451..2ce02f02 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -23,16 +23,19 @@ type PreAuthKey struct { Prefix string Hash []byte // bcrypt - UserID uint - User User `gorm:"constraint:OnDelete:SET NULL;"` + // For tagged keys: UserID tracks who created the key (informational) + // For user-owned keys: UserID tracks the node owner + // Can be nil for system-created tagged keys + UserID *uint + User *User `gorm:"constraint:OnDelete:SET NULL;"` + Reusable bool Ephemeral bool `gorm:"default:false"` Used bool `gorm:"default:false"` - // Tags are always applied to the node and is one of - // the sources of tags a node might have. They are copied - // from the PreAuthKey when the node logs in the first time, - // and ignored after. + // Tags to assign to nodes registered with this key. + // Tags are copied to the node during registration. + // If non-empty, this creates tagged nodes (not user-owned). Tags []string `gorm:"serializer:json"` CreatedAt *time.Time @@ -48,19 +51,23 @@ type PreAuthKeyNew struct { Tags []string Expiration *time.Time CreatedAt *time.Time - User User + User *User // Can be nil for system-created tagged keys } func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ Id: key.ID, Key: key.Key, - User: key.User.Proto(), + User: nil, // Will be set below if not nil Reusable: key.Reusable, Ephemeral: key.Ephemeral, AclTags: key.Tags, } + if key.User != nil { + protoKey.User = key.User.Proto() + } + if key.Expiration != nil { protoKey.Expiration = timestamppb.New(*key.Expiration) } @@ -74,7 +81,7 @@ func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey { func (key *PreAuthKey) Proto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ - User: key.User.Proto(), + User: nil, // Will be set below if not nil Id: key.ID, Ephemeral: key.Ephemeral, Reusable: key.Reusable, @@ -82,6 +89,10 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { AclTags: key.Tags, } + if key.User != nil { + protoKey.User = key.User.Proto() + } + // For new keys (with prefix/hash), show the prefix so users can identify the key // For legacy keys (with plaintext key), show the full key for backwards compatibility if key.Prefix != "" { @@ -139,3 +150,9 @@ func (pak *PreAuthKey) Validate() error { return nil } + +// IsTagged returns true if this PreAuthKey creates tagged nodes. +// When a PreAuthKey has tags, nodes registered with it will be tagged nodes. +func (pak *PreAuthKey) IsTagged() bool { + return len(pak.Tags) > 0 +} diff --git a/hscontrol/types/types_clone.go b/hscontrol/types/types_clone.go index 7699fb8f..4dfeedc2 100644 --- a/hscontrol/types/types_clone.go +++ b/hscontrol/types/types_clone.go @@ -54,7 +54,13 @@ func (src *Node) Clone() *Node { if dst.IPv6 != nil { dst.IPv6 = ptr.To(*src.IPv6) } - dst.ForcedTags = append(src.ForcedTags[:0:0], src.ForcedTags...) + if dst.UserID != nil { + dst.UserID = ptr.To(*src.UserID) + } + if dst.User != nil { + dst.User = ptr.To(*src.User) + } + dst.Tags = append(src.Tags[:0:0], src.Tags...) if dst.AuthKeyID != nil { dst.AuthKeyID = ptr.To(*src.AuthKeyID) } @@ -87,10 +93,10 @@ var _NodeCloneNeedsRegeneration = Node(struct { IPv6 *netip.Addr Hostname string GivenName string - UserID uint - User User + UserID *uint + User *User RegisterMethod string - ForcedTags []string + Tags []string AuthKeyID *uint64 AuthKey *PreAuthKey Expiry *time.Time @@ -111,6 +117,12 @@ func (src *PreAuthKey) Clone() *PreAuthKey { dst := new(PreAuthKey) *dst = *src dst.Hash = append(src.Hash[:0:0], src.Hash...) + if dst.UserID != nil { + dst.UserID = ptr.To(*src.UserID) + } + if dst.User != nil { + dst.User = ptr.To(*src.User) + } dst.Tags = append(src.Tags[:0:0], src.Tags...) if dst.CreatedAt != nil { dst.CreatedAt = ptr.To(*src.CreatedAt) @@ -127,8 +139,8 @@ var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct { Key string Prefix string Hash []byte - UserID uint - User User + UserID *uint + User *User Reusable bool Ephemeral bool Used bool diff --git a/hscontrol/types/types_view.go b/hscontrol/types/types_view.go index 076f5dbb..753e86d3 100644 --- a/hscontrol/types/types_view.go +++ b/hscontrol/types/types_view.go @@ -139,12 +139,13 @@ func (v NodeView) IPv4() views.ValuePointer[netip.Addr] { return views.ValuePo func (v NodeView) IPv6() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv6) } -func (v NodeView) Hostname() string { return v.ж.Hostname } -func (v NodeView) GivenName() string { return v.ж.GivenName } -func (v NodeView) UserID() uint { return v.ж.UserID } -func (v NodeView) User() User { return v.ж.User } +func (v NodeView) Hostname() string { return v.ж.Hostname } +func (v NodeView) GivenName() string { return v.ж.GivenName } +func (v NodeView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) } + +func (v NodeView) User() UserView { return v.ж.User.View() } func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod } -func (v NodeView) ForcedTags() views.Slice[string] { return views.SliceOf(v.ж.ForcedTags) } +func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } func (v NodeView) AuthKeyID() views.ValuePointer[uint64] { return views.ValuePointerOf(v.ж.AuthKeyID) } func (v NodeView) AuthKey() PreAuthKeyView { return v.ж.AuthKey.View() } @@ -179,10 +180,10 @@ var _NodeViewNeedsRegeneration = Node(struct { IPv6 *netip.Addr Hostname string GivenName string - UserID uint - User User + UserID *uint + User *User RegisterMethod string - ForcedTags []string + Tags []string AuthKeyID *uint64 AuthKey *PreAuthKey Expiry *time.Time @@ -239,16 +240,17 @@ func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error { return nil } -func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } -func (v PreAuthKeyView) Key() string { return v.ж.Key } -func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix } -func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) } -func (v PreAuthKeyView) UserID() uint { return v.ж.UserID } -func (v PreAuthKeyView) User() User { return v.ж.User } -func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } -func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } -func (v PreAuthKeyView) Used() bool { return v.ж.Used } -func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } +func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } +func (v PreAuthKeyView) Key() string { return v.ж.Key } +func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix } +func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) } +func (v PreAuthKeyView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) } + +func (v PreAuthKeyView) User() UserView { return v.ж.User.View() } +func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } +func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } +func (v PreAuthKeyView) Used() bool { return v.ж.Used } +func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] { return views.ValuePointerOf(v.ж.CreatedAt) } @@ -263,8 +265,8 @@ var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct { Key string Prefix string Hash []byte - UserID uint - User User + UserID *uint + User *User Reusable bool Ephemeral bool Used bool diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index b7cb1038..2e78386c 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -22,6 +22,21 @@ type UserID uint64 type Users []User +const ( + // TaggedDevicesUserID is the special user ID for tagged devices. + // This ID is used when rendering tagged nodes in the Tailscale protocol. + TaggedDevicesUserID = 2147455555 +) + +// TaggedDevices is a special user used in MapResponse for tagged nodes. +// Tagged nodes don't belong to a real user - the tag is their identity. +// This special user ID is used when rendering tagged nodes in the Tailscale protocol. +var TaggedDevices = User{ + Model: gorm.Model{ID: TaggedDevicesUserID}, + Name: "tagged-devices", + DisplayName: "Tagged Devices", +} + func (u Users) String() string { var sb strings.Builder sb.WriteString("[ ") @@ -77,6 +92,13 @@ func (u *User) StringID() string { return strconv.FormatUint(uint64(u.ID), 10) } +// TypedID returns a pointer to the user's ID as a UserID type. +// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)). +func (u *User) TypedID() *UserID { + uid := UserID(u.ID) + return &uid +} + // Username is the main way to get the username of a user, // it will return the email if it exists, the name if it exists, // the OIDCIdentifier if it exists, and the ID if nothing else exists. @@ -117,6 +139,13 @@ func (u UserView) TailscaleUser() tailcfg.User { return u.ж.TailscaleUser() } +// ID returns the user's ID. +// This is a custom accessor because gorm.Model.ID is embedded +// and the viewer generator doesn't always produce it. +func (u UserView) ID() uint { + return u.ж.ID +} + func (u *User) TailscaleLogin() tailcfg.Login { return tailcfg.Login{ ID: tailcfg.LoginID(u.ID), diff --git a/integration/cli_test.go b/integration/cli_test.go index fd2321b4..cf1badb2 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -4,6 +4,7 @@ import ( "cmp" "encoding/json" "fmt" + "slices" "strconv" "strings" "testing" @@ -18,7 +19,6 @@ import ( "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/slices" "tailscale.com/tailcfg" ) @@ -643,7 +643,9 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { status, err := client.Status() assert.NoError(ct, err) assert.Equal(ct, "Running", status.BackendState, "Expected node to be logged in, backend state: %s", status.BackendState) - assert.Equal(ct, "userid:2", status.Self.UserID.String(), "Expected node to be logged in as userid:2") + // With tags-as-identity model, tagged nodes show as TaggedDevices user (2147455555) + // The PreAuthKey was created with tags, so the node is tagged + assert.Equal(ct, "userid:2147455555", status.Self.UserID.String(), "Expected node to be logged in as tagged-devices user") }, 30*time.Second, 2*time.Second) assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -652,7 +654,8 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { assert.NoError(ct, err) assert.Len(ct, listNodes, 2, "Should have 2 nodes after re-login") assert.Equal(ct, user1, listNodes[0].GetUser().GetName(), "First node should belong to user1") - assert.Equal(ct, user2, listNodes[1].GetUser().GetName(), "Second node should belong to user2") + // Second node is tagged (created with tagged PreAuthKey), so it shows as "tagged-devices" + assert.Equal(ct, "tagged-devices", listNodes[1].GetUser().GetName(), "Second node should be tagged-devices") }, 20*time.Second, 1*time.Second) } @@ -847,118 +850,455 @@ func TestNodeTagCommand(t *testing.T) { headscale, err := scenario.Headscale() require.NoError(t, err) - regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - } - nodes := make([]*v1.Node, len(regIDs)) + // Test 1: Verify that tags require authorization via ACL policy + // The tags-as-identity model allows conversion from user-owned to tagged, but only + // if the tag is authorized via tagOwners in the ACL policy. + regID := types.MustRegistrationID().String() + + _, err = headscale.Execute( + []string{ + "headscale", + "debug", + "create-node", + "--name", + "user-owned-node", + "--user", + "user1", + "--key", + regID, + "--output", + "json", + }, + ) assert.NoError(t, err) - for index, regID := range regIDs { - _, err := headscale.Execute( + var userOwnedNode v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, []string{ "headscale", - "debug", - "create-node", - "--name", - fmt.Sprintf("node-%d", index+1), + "nodes", "--user", "user1", + "register", "--key", regID, "--output", "json", }, - ) - assert.NoError(t, err) - - var node v1.Node - assert.EventuallyWithT(t, func(c *assert.CollectT) { - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "user1", - "register", - "--key", - regID, - "--output", - "json", - }, - &node, - ) - assert.NoError(c, err) - }, 10*time.Second, 200*time.Millisecond, "Waiting for node registration") - - nodes[index] = &node - } - assert.EventuallyWithT(t, func(ct *assert.CollectT) { - assert.Len(ct, nodes, len(regIDs), "Should have correct number of nodes after CLI operations") - }, 15*time.Second, 1*time.Second) - - var node v1.Node - assert.EventuallyWithT(t, func(c *assert.CollectT) { - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "tag", - "-i", "1", - "-t", "tag:test", - "--output", "json", - }, - &node, + &userOwnedNode, ) assert.NoError(c, err) - }, 10*time.Second, 200*time.Millisecond, "Waiting for node tag command") + }, 10*time.Second, 200*time.Millisecond, "Waiting for user-owned node registration") - assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) + // Verify node is user-owned (no tags) + assert.Empty(t, userOwnedNode.GetValidTags(), "User-owned node should not have tags") + assert.Empty(t, userOwnedNode.GetForcedTags(), "User-owned node should not have forced tags") + // Attempt to set tags on user-owned node should FAIL because there's no ACL policy + // authorizing the tag. The tags-as-identity model allows conversion from user-owned + // to tagged, but only if the tag is authorized via tagOwners in the ACL policy. _, err = headscale.Execute( []string{ "headscale", "nodes", "tag", - "-i", "2", - "-t", "wrong-tag", + "-i", strconv.FormatUint(userOwnedNode.GetId(), 10), + "-t", "tag:test", "--output", "json", }, ) - assert.ErrorContains(t, err, "tag must start with the string 'tag:'") + require.ErrorContains(t, err, "invalid or unauthorized tags", "Setting unauthorized tags should fail") + + // Test 2: Verify tag format validation + // Create a PreAuthKey with tags to create a tagged node + // Get the user ID from the node + userID := userOwnedNode.GetUser().GetId() + + var preAuthKey v1.PreAuthKey - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, len(regIDs)) assert.EventuallyWithT(t, func(c *assert.CollectT) { err = executeAndUnmarshal( headscale, []string{ "headscale", - "nodes", - "list", + "preauthkeys", + "--user", strconv.FormatUint(userID, 10), + "create", + "--reusable", + "--tags", "tag:integration-test", "--output", "json", }, - &resultMachines, + &preAuthKey, ) assert.NoError(c, err) - }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after tagging") - found := false - for _, node := range resultMachines { - if node.GetForcedTags() != nil { - for _, tag := range node.GetForcedTags() { - if tag == "tag:test" { - found = true - } - } + }, 10*time.Second, 200*time.Millisecond, "Creating PreAuthKey with tags") + + // Verify PreAuthKey has tags + assert.Contains(t, preAuthKey.GetAclTags(), "tag:integration-test", "PreAuthKey should have tags") + + // Test 3: Verify invalid tag format is rejected + _, err = headscale.Execute( + []string{ + "headscale", + "preauthkeys", + "--user", strconv.FormatUint(userID, 10), + "create", + "--tags", "wrong-tag", // Missing "tag:" prefix + "--output", "json", + }, + ) + assert.ErrorContains(t, err, "tag must start with the string 'tag:'", "Invalid tag format should be rejected") +} + +func TestTaggedNodeRegistration(t *testing.T) { + IntegrationSkip(t) + + // ACL policy that authorizes the tags used in tagged PreAuthKeys + // user1 and user2 can assign these tags when creating PreAuthKeys + policy := &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:server": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")}, + "tag:prod": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")}, + "tag:forbidden": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{policyv2.Wildcard}, + Destinations: []policyv2.AliasWithPorts{{Alias: policyv2.Wildcard, Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}}}, + }, + }, + } + + spec := ScenarioSpec{ + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tagged-reg"), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Get users (they were already created by ScenarioSpec) + users, err := headscale.ListUsers() + require.NoError(t, err) + require.Len(t, users, 2, "Should have 2 users") + + var user1, user2 *v1.User + + for _, u := range users { + if u.GetName() == "user1" { + user1 = u + } else if u.GetName() == "user2" { + user2 = u } } - assert.True( - t, - found, - "should find a node with the tag 'tag:test' in the list of nodes", + + require.NotNil(t, user1, "Should find user1") + require.NotNil(t, user2, "Should find user2") + + // Test 1: Create a PreAuthKey with tags + var taggedKey v1.PreAuthKey + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", strconv.FormatUint(user1.GetId(), 10), + "create", + "--reusable", + "--tags", "tag:server,tag:prod", + "--output", "json", + }, + &taggedKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Creating tagged PreAuthKey") + + // Verify PreAuthKey has both tags + assert.Contains(t, taggedKey.GetAclTags(), "tag:server", "PreAuthKey should have tag:server") + assert.Contains(t, taggedKey.GetAclTags(), "tag:prod", "PreAuthKey should have tag:prod") + assert.Len(t, taggedKey.GetAclTags(), 2, "PreAuthKey should have exactly 2 tags") + + // Test 2: Register a node using the tagged PreAuthKey + err = scenario.CreateTailscaleNodesInUser("user1", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0])) + require.NoError(t, err) + + err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Wait for the node to be registered + var registeredNode *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node") + + // Find the tagged node - it will have user "tagged-devices" per tags-as-identity model + for _, node := range nodes { + if node.GetUser().GetName() == "tagged-devices" && len(node.GetValidTags()) > 0 { + registeredNode = node + break + } + } + + assert.NotNil(c, registeredNode, "Should find a tagged node") + }, 30*time.Second, 500*time.Millisecond, "Waiting for tagged node registration") + + // Test 3: Verify the registered node has the tags from the PreAuthKey + assert.Contains(t, registeredNode.GetValidTags(), "tag:server", "Node should have tag:server") + assert.Contains(t, registeredNode.GetValidTags(), "tag:prod", "Node should have tag:prod") + assert.Len(t, registeredNode.GetValidTags(), 2, "Node should have exactly 2 tags") + + // Test 4: Verify the node shows as TaggedDevices user (tags-as-identity model) + // Tagged nodes always show as "tagged-devices" in API responses, even though + // internally UserID may be set for "created by" tracking + assert.Equal(t, "tagged-devices", registeredNode.GetUser().GetName(), "Tagged node should show as tagged-devices user") + + // Test 5: Verify the node is identified as tagged + assert.NotEmpty(t, registeredNode.GetValidTags(), "Tagged node should have tags") + + // Test 6: Verify tag modification on tagged nodes + // NOTE: Changing tags requires complex ACL authorization where the node's IP + // must be authorized for the new tags via tagOwners. For simplicity, we skip + // this test and instead verify that tags cannot be arbitrarily changed without + // proper ACL authorization. + // + // This is expected behavior - tag changes must be authorized by ACL policy. + _, err = headscale.Execute( + []string{ + "headscale", + "nodes", + "tag", + "-i", strconv.FormatUint(registeredNode.GetId(), 10), + "-t", "tag:unauthorized", + "--output", "json", + }, ) + // This SHOULD fail because tag:unauthorized is not in our ACL policy + require.ErrorContains(t, err, "invalid or unauthorized tags", "Unauthorized tag should be rejected") + + // Test 7: Create a user-owned node for comparison + var userOwnedKey v1.PreAuthKey + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", strconv.FormatUint(user2.GetId(), 10), + "create", + "--reusable", + "--output", "json", + }, + &userOwnedKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Creating user-owned PreAuthKey") + + // Verify this PreAuthKey has NO tags + assert.Empty(t, userOwnedKey.GetAclTags(), "User-owned PreAuthKey should have no tags") + + err = scenario.CreateTailscaleNodesInUser("user2", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0])) + require.NoError(t, err) + + err = scenario.RunTailscaleUp("user2", headscale.GetEndpoint(), userOwnedKey.GetKey()) + require.NoError(t, err) + + // Wait for the user-owned node to be registered + var userOwnedNode *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.GreaterOrEqual(c, len(nodes), 2, "Should have at least 2 nodes") + + // Find the node registered with user2 + for _, node := range nodes { + if node.GetUser().GetName() == "user2" { + userOwnedNode = node + break + } + } + + assert.NotNil(c, userOwnedNode, "Should find a node for user2") + }, 30*time.Second, 500*time.Millisecond, "Waiting for user-owned node registration") + + // Test 8: Verify user-owned node has NO tags + assert.Empty(t, userOwnedNode.GetValidTags(), "User-owned node should have no tags") + assert.NotZero(t, userOwnedNode.GetUser().GetId(), "User-owned node should have UserID") + + // Test 9: Verify attempting to set UNAUTHORIZED tags on user-owned node fails + // Note: Under tags-as-identity model, user-owned nodes CAN be converted to tagged nodes + // if the tags are authorized. We use an unauthorized tag to test rejection. + _, err = headscale.Execute( + []string{ + "headscale", + "nodes", + "tag", + "-i", strconv.FormatUint(userOwnedNode.GetId(), 10), + "-t", "tag:not-in-policy", + "--output", "json", + }, + ) + require.ErrorContains(t, err, "invalid or unauthorized tags", "Setting unauthorized tags should fail") + + // Test 10: Verify basic connectivity - wait for sync + err = scenario.WaitForTailscaleSync() + require.NoError(t, err, "Clients should be able to sync") +} + +// TestTagPersistenceAcrossRestart validates that tags persist across container +// restarts and that re-authentication doesn't re-apply tags from PreAuthKey. +// This is a regression test for issue #2830. +func TestTagPersistenceAcrossRestart(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("tag-persist")) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Get user + users, err := headscale.ListUsers() + require.NoError(t, err) + require.Len(t, users, 1) + user1 := users[0] + + // Create a reusable PreAuthKey with tags + var taggedKey v1.PreAuthKey + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", strconv.FormatUint(user1.GetId(), 10), + "create", + "--reusable", // Critical: key must be reusable for container restart + "--tags", "tag:server,tag:prod", + "--output", "json", + }, + &taggedKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Creating reusable tagged PreAuthKey") + + require.True(t, taggedKey.GetReusable(), "PreAuthKey must be reusable for restart scenario") + require.Contains(t, taggedKey.GetAclTags(), "tag:server") + require.Contains(t, taggedKey.GetAclTags(), "tag:prod") + + // Register initial node with tagged PreAuthKey + err = scenario.CreateTailscaleNodesInUser("user1", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0])) + require.NoError(t, err) + + err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Wait for node registration and get initial node state + var initialNode *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node") + + for _, node := range nodes { + if node.GetUser().GetId() == user1.GetId() || node.GetUser().GetName() == "tagged-devices" { + initialNode = node + break + } + } + + assert.NotNil(c, initialNode, "Should find the registered node") + }, 30*time.Second, 500*time.Millisecond, "Waiting for initial node registration") + + // Verify initial tags + require.Contains(t, initialNode.GetValidTags(), "tag:server", "Initial node should have tag:server") + require.Contains(t, initialNode.GetValidTags(), "tag:prod", "Initial node should have tag:prod") + require.Len(t, initialNode.GetValidTags(), 2, "Initial node should have exactly 2 tags") + + initialNodeID := initialNode.GetId() + t.Logf("Initial node registered with ID %d and tags %v", initialNodeID, initialNode.GetValidTags()) + + // Simulate container restart by shutting down and restarting Tailscale client + allClients, err := scenario.ListTailscaleClients() + require.NoError(t, err) + require.Len(t, allClients, 1, "Should have exactly 1 client") + + client := allClients[0] + + // Stop the client (simulates container stop) + err = client.Down() + require.NoError(t, err) + + // Wait a bit to ensure the client is fully stopped + time.Sleep(2 * time.Second) + + // Restart the client with the SAME PreAuthKey (container restart scenario) + // This simulates what happens when a Docker container restarts with a reusable PreAuthKey + err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Wait for re-authentication + var nodeAfterRestart *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + + for _, node := range nodes { + if node.GetId() == initialNodeID { + nodeAfterRestart = node + break + } + } + + assert.NotNil(c, nodeAfterRestart, "Should find the same node after restart") + }, 30*time.Second, 500*time.Millisecond, "Waiting for node re-authentication") + + // CRITICAL ASSERTION: Tags should NOT be re-applied from PreAuthKey + // Tags are only applied during INITIAL authentication, not re-authentication + // The node should keep its existing tags (which happen to be the same in this case) + assert.Contains(t, nodeAfterRestart.GetValidTags(), "tag:server", "Node should still have tag:server after restart") + assert.Contains(t, nodeAfterRestart.GetValidTags(), "tag:prod", "Node should still have tag:prod after restart") + assert.Len(t, nodeAfterRestart.GetValidTags(), 2, "Node should still have exactly 2 tags after restart") + + // Verify it's the SAME node (same ID), not a new registration + assert.Equal(t, initialNodeID, nodeAfterRestart.GetId(), "Should be the same node, not a new registration") + + // Verify node count hasn't increased (no duplicate nodes) + finalNodes, err := headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, finalNodes, 1, "Should still have exactly 1 node (no duplicates from restart)") + + t.Logf("Container restart validation complete - node %d maintained tags across restart", initialNodeID) } func TestNodeAdvertiseTagCommand(t *testing.T) {