From eb788cd007119472b99787c8c3af6d6204e8b834 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 2 Dec 2025 12:01:25 +0100 Subject: [PATCH] make tags first class node owner (#2885) This PR changes tags to be something that exists on nodes in addition to users, to being its own thing. It is part of moving our tags support towards the correct tailscale compatible implementation. There are probably rough edges in this PR, but the intention is to get it in, and then start fixing bugs from 0.28.0 milestone (long standing tags issue) to discover what works and what doesnt. Updates #2417 Closes #2619 --- .github/workflows/test-integration.yaml | 3 +- AGENTS.md | 336 +++++++++++ CHANGELOG.md | 6 + hscontrol/auth.go | 18 +- hscontrol/auth_tags_test.go | 535 ++++++++++++++++++ hscontrol/auth_test.go | 141 +++-- hscontrol/db/db.go | 15 + hscontrol/db/db_test.go | 3 +- hscontrol/db/ip_test.go | 20 +- hscontrol/db/node.go | 28 +- hscontrol/db/node_test.go | 51 +- hscontrol/db/preauth_keys.go | 60 +- hscontrol/db/preauth_keys_test.go | 26 +- hscontrol/db/schema.sql | 2 +- hscontrol/db/users.go | 6 +- hscontrol/db/users_test.go | 6 +- hscontrol/grpcv1.go | 33 +- hscontrol/grpcv1_test.go | 222 +++++++- hscontrol/mapper/mapper.go | 12 +- hscontrol/mapper/mapper_test.go | 5 +- hscontrol/mapper/tail.go | 5 +- hscontrol/mapper/tail_test.go | 13 +- hscontrol/oidc_template_test.go | 12 - hscontrol/policy/policy_autoapprove_test.go | 14 +- .../policy/policy_route_approval_test.go | 14 +- hscontrol/policy/policy_test.go | 209 ++++--- hscontrol/policy/policyutil/reduce_test.go | 49 +- hscontrol/policy/route_approval_test.go | 33 +- hscontrol/policy/v2/filter.go | 6 +- hscontrol/policy/v2/filter_test.go | 359 ++++++++++-- hscontrol/policy/v2/policy.go | 20 +- hscontrol/policy/v2/policy_test.go | 15 +- hscontrol/policy/v2/types.go | 11 +- hscontrol/policy/v2/types_test.go | 179 +++--- hscontrol/poll.go | 74 +-- hscontrol/state/debug.go | 4 +- hscontrol/state/maprequest_test.go | 5 +- hscontrol/state/node_store.go | 4 +- hscontrol/state/node_store_test.go | 37 +- hscontrol/state/state.go | 91 ++- hscontrol/state/tags.go | 107 ++++ hscontrol/types/node.go | 139 +++-- hscontrol/types/node_tags_test.go | 295 ++++++++++ hscontrol/types/node_test.go | 8 +- hscontrol/types/preauth_key.go | 35 +- hscontrol/types/types_clone.go | 24 +- hscontrol/types/types_view.go | 42 +- hscontrol/types/users.go | 29 + integration/cli_test.go | 498 +++++++++++++--- 49 files changed, 3102 insertions(+), 757 deletions(-) create mode 100644 hscontrol/auth_tags_test.go create mode 100644 hscontrol/state/tags.go create mode 100644 hscontrol/types/node_tags_test.go diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 62ea3a97..eebadb74 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -55,11 +55,12 @@ jobs: - TestPreAuthKeyCorrectUserLoggedInCommand - TestApiKeyCommand - TestNodeTagCommand + - TestTaggedNodeRegistration + - TestTagPersistenceAcrossRestart - TestNodeAdvertiseTagCommand - TestNodeCommand - TestNodeExpireCommand - TestNodeRenameCommand - - TestNodeMoveCommand - TestPolicyCommand - TestPolicyBrokenConfigCommand - TestDERPVerifyEndpoint diff --git a/AGENTS.md b/AGENTS.md index e5dd1b01..42da654c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -237,6 +237,21 @@ headscale/ - `policy.go`: Policy storage and retrieval - Schema migrations in `schema.sql` with extensive test data coverage +**CRITICAL DATABASE MIGRATION RULES**: + +1. **NEVER reorder existing migrations** - Migration order is immutable once committed +2. **ONLY add new migrations to the END** of the migrations array +3. **NEVER disable foreign keys** in new migrations - no new migrations should be added to `migrationsRequiringFKDisabled` +4. **Migration ID format**: `YYYYMMDDHHSS-short-description` (timestamp + descriptive suffix) + - Example: `202511131500-add-user-roles` + - The timestamp must be chronologically ordered +5. **New migrations go after the comment** "As of 2025-07-02, no new IDs should be added here" +6. If you need to rename a column that other migrations depend on: + - Accept that the old column name will exist in intermediate migration states + - Update code to work with the new column name + - Let AutoMigrate create the new column if needed + - Do NOT try to rename columns that later migrations reference + **Policy Engine (`hscontrol/policy/`)** - `policy.go`: Core ACL evaluation logic, HuJSON parsing @@ -687,6 +702,326 @@ assert.EventuallyWithT(t, func(c *assert.CollectT) { }, 10*time.Second, 500*time.Millisecond, "mixed operations") ``` +## Tags-as-Identity Architecture + +### Overview + +Headscale implements a **tags-as-identity** model where tags and user ownership are mutually exclusive ways to identify nodes. This is a fundamental architectural principle that affects node registration, ownership, ACL evaluation, and API behavior. + +### Core Principle: Tags XOR User Ownership + +Every node in Headscale is **either** tagged **or** user-owned, never both: + +- **Tagged Nodes**: Ownership is defined by tags (e.g., `tag:server`, `tag:database`) + - Tags are set during registration via tagged PreAuthKey + - Tags are immutable after registration (cannot be changed via API) + - May have `UserID` set for "created by" tracking, but ownership is via tags + - Identified by: `node.IsTagged()` returns `true` + +- **User-Owned Nodes**: Ownership is defined by user assignment + - Registered via OIDC, web auth, or untagged PreAuthKey + - Node belongs to a specific user's namespace + - No tags (empty tags array) + - Identified by: `node.UserID().Valid() && !node.IsTagged()` + +### Critical Implementation Details + +#### Node Identification Methods + +```go +// Primary methods for determining node ownership +node.IsTagged() // Returns true if node has tags OR AuthKey.Tags +node.HasTag(tag) // Returns true if node has specific tag +node.IsUserOwned() // Returns true if UserID set AND not tagged + +// IMPORTANT: UserID can be set on tagged nodes for tracking! +// Always use IsTagged() to determine actual ownership, not just UserID.Valid() +``` + +#### UserID Field Semantics + +**Critical distinction**: `UserID` has different meanings depending on node type: + +- **Tagged nodes**: `UserID` is optional "created by" tracking + - Indicates which user created the tagged PreAuthKey + - Does NOT define ownership (tags define ownership) + - Example: User "alice" creates tagged PreAuthKey with `tag:server`, node gets `UserID=alice.ID` + `Tags=["tag:server"]` + +- **User-owned nodes**: `UserID` defines ownership + - Required field for non-tagged nodes + - Defines which user namespace the node belongs to + - Example: User "bob" registers via OIDC, node gets `UserID=bob.ID` + `Tags=[]` + +#### Mapper Behavior (mapper/tail.go) + +The mapper converts internal nodes to Tailscale protocol format, handling the TaggedDevices special user: + +```go +// From mapper/tail.go:102-116 +User: func() tailcfg.UserID { + // IMPORTANT: Tags-as-identity model + // Tagged nodes ALWAYS use TaggedDevices user, even if UserID is set + if node.IsTagged() { + return tailcfg.UserID(int64(types.TaggedDevices.ID)) + } + // User-owned nodes: use the actual user ID + return tailcfg.UserID(int64(node.UserID().Get())) +}() +``` + +**TaggedDevices constant** (`types.TaggedDevices.ID = 2147455555`): Special user ID for all tagged nodes in MapResponse protocol. + +#### Registration Flow + +**Tagged Node Registration** (via tagged PreAuthKey): + +1. User creates PreAuthKey with tags: `pak.Tags = ["tag:server"]` +2. Node registers with PreAuthKey +3. Node gets: `Tags = ["tag:server"]`, `UserID = user.ID` (optional tracking), `AuthKeyID = pak.ID` +4. `IsTagged()` returns `true` (ownership via tags) +5. MapResponse sends `User = TaggedDevices.ID` + +**User-Owned Node Registration** (via OIDC/web/untagged PreAuthKey): + +1. User authenticates or uses untagged PreAuthKey +2. Node registers +3. Node gets: `Tags = []`, `UserID = user.ID` (required) +4. `IsTagged()` returns `false` (ownership via user) +5. MapResponse sends `User = user.ID` + +#### API Validation (SetTags) + +The SetTags gRPC API enforces tags-as-identity rules: + +```go +// From grpcv1.go:340-347 +// User-owned nodes are nodes with UserID that are NOT tagged +isUserOwned := nodeView.UserID().Valid() && !nodeView.IsTagged() +if isUserOwned && len(request.GetTags()) > 0 { + return error("cannot set tags on user-owned nodes") +} +``` + +**Key validation rules**: + +- ✅ Can call SetTags on tagged nodes (tags already define ownership) +- ❌ Cannot set tags on user-owned nodes (would violate XOR rule) +- ❌ Cannot remove all tags from tagged nodes (would orphan the node) + +#### Database Layer (db/node.go) + +**Tag storage**: Tags are stored in PostgreSQL ARRAY column and SQLite JSON column: + +```sql +-- From schema.sql +tags TEXT[] DEFAULT '{}' NOT NULL, -- PostgreSQL +tags TEXT DEFAULT '[]' NOT NULL, -- SQLite (JSON array) +``` + +**Validation** (`state/tags.go`): + +- `validateNodeOwnership()`: Enforces tags XOR user rule +- `validateAndNormalizeTags()`: Validates tag format (`tag:name`) and uniqueness + +#### Policy Layer + +**Tag Ownership** (policy/v2/policy.go): + +```go +func NodeCanHaveTag(node types.NodeView, tag string) bool { + // Checks if node's IP is in the tagOwnerMap IP set + // This is IP-based authorization, not UserID-based + if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok { + if slices.ContainsFunc(node.IPs(), ips.Contains) { + return true + } + } + return false +} +``` + +**Important**: Tag authorization is based on IP ranges in ACL, not UserID. Tags define identity, ACL authorizes that identity. + +### Testing Tags-as-Identity + +**Unit Tests** (`hscontrol/types/node_tags_test.go`): + +- `TestNodeIsTagged`: Validates IsTagged() for various scenarios +- `TestNodeOwnershipModel`: Tests tags XOR user ownership +- `TestUserTypedID`: Helper method validation + +**API Tests** (`hscontrol/grpcv1_test.go`): + +- `TestSetTags_UserXORTags`: Validates rejection of setting tags on user-owned nodes +- `TestSetTags_TaggedNode`: Validates that tagged nodes (even with UserID) are not rejected + +**Auth Tests** (`hscontrol/auth_test.go:890-928`): + +- Tests node registration with tagged PreAuthKey +- Validates tags are applied during registration + +### Common Pitfalls + +1. **Don't check only `UserID.Valid()` to determine user ownership** + - ❌ Wrong: `if node.UserID().Valid() { /* user-owned */ }` + - ✅ Correct: `if node.UserID().Valid() && !node.IsTagged() { /* user-owned */ }` + +2. **Don't assume tagged nodes never have UserID set** + - Tagged nodes MAY have UserID for "created by" tracking + - Always use `IsTagged()` to determine ownership type + +3. **Don't allow setting tags on user-owned nodes** + - This violates the tags XOR user principle + - Use API validation to prevent this + +4. **Don't forget TaggedDevices in mapper** + - All tagged nodes MUST use `TaggedDevices.ID` in MapResponse + - User ID is only for actual user-owned nodes + +### Migration Considerations + +When nodes transition between ownership models: + +- **No automatic migration**: Tags-as-identity is set at registration and immutable +- **Re-registration required**: To change from user-owned to tagged (or vice versa), node must be deleted and re-registered +- **UserID persistence**: UserID on tagged nodes is informational and not cleared + +### Architecture Benefits + +The tags-as-identity model provides: + +1. **Clear ownership semantics**: No ambiguity about who/what owns a node +2. **ACL simplicity**: Tag-based access control without user conflicts +3. **API safety**: Validation prevents invalid ownership states +4. **Protocol compatibility**: TaggedDevices special user aligns with Tailscale's model + +## Logging Patterns + +### Incremental Log Event Building + +When building log statements with multiple fields, especially with conditional fields, use the **incremental log event pattern** instead of long single-line chains. This improves readability and allows conditional field addition. + +**Pattern:** + +```go +// GOOD: Incremental building with conditional fields +logEvent := log.Debug(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Str("node_key", node.NodeKey.ShortString()) + +if node.User != nil { + logEvent = logEvent.Str("user", node.User.Username()) +} else if node.UserID != nil { + logEvent = logEvent.Uint("user_id", *node.UserID) +} else { + logEvent = logEvent.Str("user", "none") +} + +logEvent.Msg("Registering node") +``` + +**Key rules:** + +1. **Assign chained calls back to the variable**: `logEvent = logEvent.Str(...)` - zerolog methods return a new event, so you must capture the return value +2. **Use for conditional fields**: When fields depend on runtime conditions, build incrementally +3. **Use for long log lines**: When a log line exceeds ~100 characters, split it for readability +4. **Call `.Msg()` at the end**: The final `.Msg()` or `.Msgf()` sends the log event + +**Anti-pattern to avoid:** + +```go +// BAD: Long single-line chains are hard to read and can't have conditional fields +log.Debug().Caller().Str("node", node.Hostname).Str("machine_key", node.MachineKey.ShortString()).Str("node_key", node.NodeKey.ShortString()).Str("user", node.User.Username()).Msg("Registering node") + +// BAD: Forgetting to assign the return value (field is lost!) +logEvent := log.Debug().Str("node", node.Hostname) +logEvent.Str("user", username) // This field is LOST - not assigned back +logEvent.Msg("message") // Only has "node" field +``` + +**When to use this pattern:** + +- Log statements with 4+ fields +- Any log with conditional fields +- Complex logging in loops or error handling +- When you need to add context incrementally + +**Example from codebase** (`hscontrol/db/node.go`): + +```go +logEvent := log.Debug(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Str("node_key", node.NodeKey.ShortString()) + +if node.User != nil { + logEvent = logEvent.Str("user", node.User.Username()) +} else if node.UserID != nil { + logEvent = logEvent.Uint("user_id", *node.UserID) +} else { + logEvent = logEvent.Str("user", "none") +} + +logEvent.Msg("Registering test node") +``` + +### Avoiding Log Helper Functions + +Prefer the incremental log event pattern over creating helper functions that return multiple logging closures. Helper functions like `logPollFunc` create unnecessary indirection and allocate closures. + +**Instead of:** + +```go +// AVOID: Helper function returning closures +func logPollFunc(req tailcfg.MapRequest, node *types.Node) ( + func(string, ...any), // warnf + func(string, ...any), // infof + func(string, ...any), // tracef + func(error, string, ...any), // errf +) { + return func(msg string, a ...any) { + log.Warn(). + Caller(). + Bool("omitPeers", req.OmitPeers). + Bool("stream", req.Stream). + Uint64("node.id", node.ID.Uint64()). + Str("node.name", node.Hostname). + Msgf(msg, a...) + }, + // ... more closures +} +``` + +**Prefer:** + +```go +// BETTER: Build log events inline with shared context +func (m *mapSession) logTrace(msg string) { + log.Trace(). + Caller(). + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname). + Msg(msg) +} + +// Or use incremental building for complex cases +logEvent := log.Trace(). + Caller(). + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname) + +if additionalContext { + logEvent = logEvent.Str("extra", value) +} + +logEvent.Msg("Operation completed") +``` + ## Important Notes - **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting) @@ -697,3 +1032,4 @@ assert.EventuallyWithT(t, func(c *assert.CollectT) { - **Integration Tests**: Require Docker and can consume significant disk space - use headscale-integration-tester agent - **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management - **Quality Assurance**: Always use appropriate specialized agents for testing and validation tasks +- **Tags-as-Identity**: Tags and user ownership are mutually exclusive - always use `IsTagged()` to determine ownership diff --git a/CHANGELOG.md b/CHANGELOG.md index d812bcef..e5a4163b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,10 @@ at creation time. When listing keys, only the prefix is shown (e.g., `hskey-auth-{prefix}-{secret}`. Legacy plaintext keys continue to work for backwards compatibility. +### Tags + +Tags are now implemented following the Tailscale model where tags and user ownership are mutually exclusive. Devices can be either user-owned (authenticated via web/OIDC) or tagged (authenticated via tagged PreAuthKeys). Tagged devices receive their identity from tags rather than users, making them suitable for servers and infrastructure. Applying a tag to a device removes user-based authentication. See the [Tailscale tags documentation](https://tailscale.com/kb/1068/tags) for details on how tags work. + ### Database migration support removed for pre-0.25.0 databases Headscale no longer supports direct upgrades from databases created before @@ -30,6 +34,8 @@ release. ### BREAKING +- **Tags**: The gRPC `SetTags` endpoint now allows converting user-owned nodes to tagged nodes by setting tags. Once a node is tagged, it cannot be converted back to a user-owned node. + - Database migration support removed for pre-0.25.0 databases [#2883](https://github.com/juanfont/headscale/pull/2883) - If you are running a version older than 0.25.0, you must upgrade to 0.25.1 first, then upgrade to this release - See the [upgrade path documentation](https://headscale.net/stable/about/faq/#what-is-the-recommended-update-path-can-i-skip-multiple-versions-while-updating) for detailed guidance diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 1ab22709..5c72ea81 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -233,11 +233,7 @@ func isAuthKey(req tailcfg.RegisterRequest) bool { } func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse { - return &tailcfg.RegisterResponse{ - // TODO(kradalby): Only send for user-owned nodes - // and not tagged nodes when tags is working. - User: node.UserView().TailscaleUser(), - Login: node.UserView().TailscaleLogin(), + resp := &tailcfg.RegisterResponse{ NodeKeyExpired: node.IsExpired(), // Headscale does not implement the concept of machine authorization @@ -245,6 +241,18 @@ func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse { // Revisit this if #2176 gets implemented. MachineAuthorized: true, } + + // For tagged nodes, use the TaggedDevices special user + // For user-owned nodes, include User and Login information from the actual user + if node.IsTagged() { + resp.User = types.TaggedDevices.View().TailscaleUser() + resp.Login = types.TaggedDevices.View().TailscaleLogin() + } else if node.UserView().Valid() { + resp.User = node.UserView().TailscaleUser() + resp.Login = node.UserView().TailscaleLogin() + } + + return resp } func (h *Headscale) waitForFollowup( diff --git a/hscontrol/auth_tags_test.go b/hscontrol/auth_tags_test.go new file mode 100644 index 00000000..4bc32a6a --- /dev/null +++ b/hscontrol/auth_tags_test.go @@ -0,0 +1,535 @@ +package hscontrol + +import ( + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// TestTaggedPreAuthKeyCreatesTaggedNode tests that a PreAuthKey with tags creates +// a tagged node with: +// - Tags from the PreAuthKey +// - UserID tracking who created the key (informational "created by") +// - IsTagged() returns true. +func TestTaggedPreAuthKeyCreatesTaggedNode(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server", "tag:prod"} + + // Create a tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + require.NotEmpty(t, pak.Tags, "PreAuthKey should have tags") + require.ElementsMatch(t, tags, pak.Tags, "PreAuthKey should have specified tags") + + // Register a node using the tagged key + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify the node was created with tags + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertions for tags-as-identity model + assert.True(t, node.IsTagged(), "Node should be tagged") + assert.ElementsMatch(t, tags, node.Tags().AsSlice(), "Node should have tags from PreAuthKey") + assert.True(t, node.UserID().Valid(), "Node should have UserID tracking creator") + assert.Equal(t, user.ID, node.UserID().Get(), "UserID should track PreAuthKey creator") + + // Verify node is identified correctly + assert.True(t, node.IsTagged(), "Tagged node is not user-owned") + assert.True(t, node.HasTag("tag:server"), "Node should have tag:server") + assert.True(t, node.HasTag("tag:prod"), "Node should have tag:prod") + assert.False(t, node.HasTag("tag:other"), "Node should not have tag:other") +} + +// TestReAuthDoesNotReapplyTags tests that when a node re-authenticates using the +// same PreAuthKey, the tags are NOT re-applied. Tags are only set during initial +// authentication. This is critical for the container restart scenario (#2830). +// +// NOTE: This test verifies that re-authentication preserves the node's current tags +// without testing tag modification via SetNodeTags (which requires ACL policy setup). +func TestReAuthDoesNotReapplyTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + initialTags := []string{"tag:server", "tag:dev"} + + // Create a tagged PreAuthKey with reusable=true for re-auth + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, initialTags) + require.NoError(t, err) + + // Initial registration + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify initial tags + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + require.True(t, node.IsTagged()) + require.ElementsMatch(t, initialTags, node.Tags().AsSlice()) + + // Re-authenticate with the SAME PreAuthKey (container restart scenario) + // Key behavior: Tags should NOT be re-applied during re-auth + reAuthReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same key + }, + NodeKey: nodeKey.Public(), // Same node key + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + reAuthResp, err := app.handleRegisterWithAuthKey(reAuthReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, reAuthResp.MachineAuthorized) + + // CRITICAL: Tags should remain unchanged after re-auth + // They should match the original tags, proving they weren't re-applied + nodeAfterReauth, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, nodeAfterReauth.IsTagged(), "Node should still be tagged") + assert.ElementsMatch(t, initialTags, nodeAfterReauth.Tags().AsSlice(), "Tags should remain unchanged on re-auth") + + // Verify only one node was created (no duplicates) + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 1, nodes.Len(), "Should have exactly one node") +} + +// NOTE: TestSetTagsOnUserOwnedNode functionality is covered by gRPC tests in grpcv1_test.go +// which properly handle ACL policy setup. The test verifies that SetTags can convert +// user-owned nodes to tagged nodes while preserving UserID. + +// TestCannotRemoveAllTags tests that attempting to remove all tags from a +// tagged node fails with ErrCannotRemoveAllTags. Once a node is tagged, +// it must always have at least one tag (Tailscale requirement). +func TestCannotRemoveAllTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a tagged node + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify node is tagged + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + require.True(t, node.IsTagged()) + + // Attempt to remove all tags by setting empty array + _, _, err = app.state.SetNodeTags(node.ID(), []string{}) + require.Error(t, err, "Should not be able to remove all tags") + require.ErrorIs(t, err, types.ErrCannotRemoveAllTags, "Error should be ErrCannotRemoveAllTags") + + // Verify node still has original tags + nodeAfter, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, nodeAfter.IsTagged(), "Node should still be tagged") + assert.ElementsMatch(t, tags, nodeAfter.Tags().AsSlice(), "Tags should be unchanged") +} + +// TestUserOwnedNodeCreatedWithUntaggedPreAuthKey tests that using a PreAuthKey +// without tags creates a user-owned node (no tags, UserID is the owner). +func TestUserOwnedNodeCreatedWithUntaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("node-owner") + + // Create an untagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + require.NoError(t, err) + require.Empty(t, pak.Tags, "PreAuthKey should not be tagged") + require.Empty(t, pak.Tags, "PreAuthKey should have no tags") + + // Register a node + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "user-owned-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify node is user-owned + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertions for user-owned node + assert.False(t, node.IsTagged(), "Node should not be tagged") + assert.False(t, node.IsTagged(), "Node should be user-owned (not tagged)") + assert.Empty(t, node.Tags().AsSlice(), "Node should have no tags") + assert.True(t, node.UserID().Valid(), "Node should have UserID") + assert.Equal(t, user.ID, node.UserID().Get(), "UserID should be the PreAuthKey owner") +} + +// TestMultipleNodesWithSameReusableTaggedPreAuthKey tests that a reusable +// PreAuthKey with tags can be used to register multiple nodes, and all nodes +// receive the same tags from the key. +func TestMultipleNodesWithSameReusableTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server", "tag:prod"} + + // Create a REUSABLE tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Register first node + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Register second node with SAME PreAuthKey + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same key + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + // Verify both nodes exist and have the same tags + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found) + + // Both nodes should be tagged with the same tags + assert.True(t, node1.IsTagged(), "First node should be tagged") + assert.True(t, node2.IsTagged(), "Second node should be tagged") + assert.ElementsMatch(t, tags, node1.Tags().AsSlice(), "First node should have PreAuthKey tags") + assert.ElementsMatch(t, tags, node2.Tags().AsSlice(), "Second node should have PreAuthKey tags") + + // Both nodes should track the same creator + assert.Equal(t, user.ID, node1.UserID().Get(), "First node should track creator") + assert.Equal(t, user.ID, node2.UserID().Get(), "Second node should track creator") + + // Verify we have exactly 2 nodes + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 2, nodes.Len(), "Should have exactly two nodes") +} + +// TestNonReusableTaggedPreAuthKey tests that a non-reusable PreAuthKey with tags +// can only be used once. The second attempt should fail. +func TestNonReusableTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a NON-REUSABLE tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Register first node - should succeed + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Verify first node was created with tags + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + assert.ElementsMatch(t, tags, node1.Tags().AsSlice()) + + // Attempt to register second node with SAME non-reusable key - should fail + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same non-reusable key + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.Error(t, err, "Should not be able to reuse non-reusable PreAuthKey") + + // Verify only one node was created + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 1, nodes.Len(), "Should have exactly one node") +} + +// TestExpiredTaggedPreAuthKey tests that an expired PreAuthKey with tags +// cannot be used to register a node. +func TestExpiredTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a PreAuthKey that expires immediately + expiration := time.Now().Add(-1 * time.Hour) // Already expired + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiration, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Attempt to register with expired key + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.Error(t, err, "Should not be able to use expired PreAuthKey") + + // Verify no node was created + _, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + assert.False(t, found, "No node should be created with expired key") +} + +// TestSingleVsMultipleTags tests that PreAuthKeys work correctly with both +// a single tag and multiple tags. +func TestSingleVsMultipleTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + + // Test with single tag + singleTag := []string{"tag:server"} + pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, singleTag) + require.NoError(t, err) + + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "single-tag-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + assert.ElementsMatch(t, singleTag, node1.Tags().AsSlice()) + + // Test with multiple tags + multipleTags := []string{"tag:server", "tag:prod", "tag:database"} + pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, multipleTags) + require.NoError(t, err) + + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak2.Key, + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "multi-tag-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found) + assert.True(t, node2.IsTagged()) + assert.ElementsMatch(t, multipleTags, node2.Tags().AsSlice()) + + // Verify HasTag works for all tags + assert.True(t, node2.HasTag("tag:server")) + assert.True(t, node2.HasTag("tag:prod")) + assert.True(t, node2.HasTag("tag:database")) + assert.False(t, node2.HasTag("tag:other")) +} + +// TestReAuthWithDifferentMachineKey tests the edge case where a node attempts +// to re-authenticate with the same NodeKey but a DIFFERENT MachineKey. +// This scenario should be handled gracefully (currently creates a new node). +func TestReAuthWithDifferentMachineKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a reusable tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + + // Initial registration + machineKey1 := key.NewMachine() + nodeKey := key.NewNode() // Same NodeKey for both attempts + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Verify initial node + node1, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + + // Re-authenticate with DIFFERENT MachineKey but SAME NodeKey + machineKey2 := key.NewMachine() // Different machine key + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), // Same NodeKey + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + // Verify the node still exists and has tags + // Note: Depending on implementation, this might be the same node or a new node + node2, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, node2.IsTagged()) + assert.ElementsMatch(t, tags, node2.Tags().AsSlice()) +} diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index b4ab8f16..23e1f226 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -70,7 +70,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_valid_new_node", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("preauth-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -111,7 +112,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_reusable_multiple_nodes", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("reusable-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -177,7 +179,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_single_use_exhausted", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("single-use-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) if err != nil { return "", err } @@ -264,7 +267,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_ephemeral_node", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("ephemeral-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) if err != nil { return "", err } @@ -370,7 +374,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_logout", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("logout-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -429,7 +434,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_machine_key_mismatch", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("mismatch-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -477,7 +483,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_key_extension_not_allowed", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("extend-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -525,7 +532,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_expired_forces_reauth", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("reauth-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -585,7 +593,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "ephemeral_node_logout_deletion", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("ephemeral-logout-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) if err != nil { return "", err } @@ -767,7 +776,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "empty_hostname", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("empty-hostname-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -805,7 +815,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "nil_hostinfo", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("nil-hostinfo-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -848,7 +859,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("expired-pak-user") expiry := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) if err != nil { return "", err } @@ -880,7 +892,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("tagged-pak-user") tags := []string{"tag:server", "tag:database"} - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, tags) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) if err != nil { return "", err } @@ -926,7 +939,7 @@ func TestAuthenticationFlows(t *testing.T) { user := app.state.CreateUserForTest("reauth-user") // First, register with initial auth key - pak1, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -953,7 +966,7 @@ func TestAuthenticationFlows(t *testing.T) { }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") // Create new auth key for re-authentication - pak2, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -992,7 +1005,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "existing_node_reauth_interactive_flow", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("interactive-reauth-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1053,7 +1067,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "node_key_rotation_same_machine", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("rotation-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1081,7 +1096,7 @@ func TestAuthenticationFlows(t *testing.T) { }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") // Create new auth key for rotation - pakRotation, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pakRotation, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1129,7 +1144,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "malformed_expiry_zero_time", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("zero-expiry-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1167,7 +1183,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "malformed_hostinfo_invalid_data", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("invalid-hostinfo-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1353,7 +1370,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "preauth_key_usage_count_tracking", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("usage-count-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // Single use + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // Single use if err != nil { return "", err } @@ -1432,7 +1450,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "concurrent_registration_same_node_key", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("concurrent-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1473,7 +1492,8 @@ func TestAuthenticationFlows(t *testing.T) { user := app.state.CreateUserForTest("future-expiry-user") // Auth key expires in the future expiry := time.Now().Add(48 * time.Hour) - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) if err != nil { return "", err } @@ -1517,7 +1537,7 @@ func TestAuthenticationFlows(t *testing.T) { user2 := app.state.CreateUserForTest("user2-context") // Register node with user1's auth key - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1544,7 +1564,7 @@ func TestAuthenticationFlows(t *testing.T) { }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") // Return user2's auth key for re-authentication - pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil) + pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1571,15 +1591,15 @@ func TestAuthenticationFlows(t *testing.T) { // Verify NEW node was created for user2 node2, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2)) require.True(t, found, "new node should exist for user2") - assert.Equal(t, uint(2), node2.UserID(), "new node should belong to user2") + assert.Equal(t, uint(2), node2.UserID().Get(), "new node should belong to user2") user := node2.User() - assert.Equal(t, "user2-context", user.Username(), "new node should show user2 username") + assert.Equal(t, "user2-context", user.Name(), "new node should show user2 username") // Verify original node still exists for user1 node1, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1)) require.True(t, found, "original node should still exist for user1") - assert.Equal(t, uint(1), node1.UserID(), "original node should still belong to user1") + assert.Equal(t, uint(1), node1.UserID().Get(), "original node should still belong to user1") // Verify they are different nodes (different IDs) assert.NotEqual(t, node1.ID(), node2.ID(), "should be different node IDs") @@ -1595,7 +1615,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { // Create user1 and register a node with auth key user1 := app.state.CreateUserForTest("interactive-user-1") - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1645,16 +1666,16 @@ func TestAuthenticationFlows(t *testing.T) { // User1's original node should STILL exist (not transferred) node1, found1 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1)) require.True(t, found1, "user1's original node should still exist") - assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1") + assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1") assert.Equal(t, nodeKey1.Public(), node1.NodeKey(), "user1's node should have original node key") // User2 should have a NEW node created node2, found2 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2)) require.True(t, found2, "user2 should have new node created") - assert.Equal(t, uint(2), node2.UserID(), "user2's node should belong to user2") + assert.Equal(t, uint(2), node2.UserID().Get(), "user2's node should belong to user2") user := node2.User() - assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should show correct username") + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should show correct username") // Both nodes should have the same machine key but different IDs assert.NotEqual(t, node1.ID(), node2.ID(), "should be different nodes (different IDs)") @@ -1720,7 +1741,8 @@ func TestAuthenticationFlows(t *testing.T) { name: "logout_with_exactly_now_expiry", setupFunc: func(t *testing.T, app *Headscale) (string, error) { user := app.state.CreateUserForTest("exact-now-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1813,7 +1835,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { // First create a node under user1 user1 := app.state.CreateUserForTest("existing-user-1") - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -1863,7 +1886,7 @@ func TestAuthenticationFlows(t *testing.T) { // User1's original node with nodeKey1 should STILL exist node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) require.True(t, found1, "user1's original node with nodeKey1 should still exist") - assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1") + assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1") assert.Equal(t, uint64(1), node1.ID().Uint64(), "user1's node should be ID=1") // User2 should have a NEW node with nodeKey2 @@ -1872,7 +1895,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.Equal(t, "existing-node-user2", node2.Hostname(), "hostname should be from new registration") user := node2.User() - assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2") + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2") assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "machine key should be the same") // Verify it's a NEW node, not transferred @@ -2022,7 +2045,8 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { // Register initial node user := app.state.CreateUserForTest("rotation-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } @@ -2072,7 +2096,7 @@ func TestAuthenticationFlows(t *testing.T) { // User1's original node with nodeKey1 should STILL exist oldNode, foundOld := app.state.GetNodeByNodeKey(nodeKey1.Public()) require.True(t, foundOld, "user1's original node with nodeKey1 should still exist") - assert.Equal(t, uint(1), oldNode.UserID(), "user1's node should still belong to user1") + assert.Equal(t, uint(1), oldNode.UserID().Get(), "user1's node should still belong to user1") assert.Equal(t, uint64(1), oldNode.ID().Uint64(), "user1's node should be ID=1") // User2 should have a NEW node with nodeKey2 @@ -2082,7 +2106,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.Equal(t, machineKey1.Public(), newNode.MachineKey()) user := newNode.User() - assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2") + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2") // Verify it's a NEW node, not transferred assert.NotEqual(t, uint64(1), newNode.ID().Uint64(), "should be a NEW node (different ID)") @@ -2333,7 +2357,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.True(t, found, "node should be registered") if found { assert.Equal(t, "pending-node-2", node.Hostname()) - assert.Equal(t, "second-registration-user", node.User().Name) + assert.Equal(t, "second-registration-user", node.User().Name()) } // First registration should still be in cache (not completed) @@ -2593,7 +2617,7 @@ func TestNodeStoreLookup(t *testing.T) { nodeKey := key.NewNode() user := app.state.CreateUserForTest("test-user") - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) require.NoError(t, err) // Register a node @@ -2642,9 +2666,9 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { user2 := app.state.CreateUserForTest("user2") // Create pre-auth keys for both users - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) require.NoError(t, err) - pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil) + pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil) require.NoError(t, err) // Create machine and node keys for 4 nodes (2 per user) @@ -2720,7 +2744,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { t.Logf("All nodes logged out") // Create a new pre-auth key for user1 (reusable for all nodes) - newPak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + newPak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) require.NoError(t, err) // Re-login all nodes using user1's new pre-auth key @@ -2765,7 +2789,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // User1's original nodes should still be owned by user1 registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID)) require.True(t, found, "User1's original node %s should still exist", node.hostname) - require.Equal(t, user1.ID, registeredNode.UserID(), "Node %s should still belong to user1", node.hostname) + require.Equal(t, user1.ID, registeredNode.UserID().Get(), "Node %s should still belong to user1", node.hostname) t.Logf("✓ User1's original node %s (ID=%d) still owned by user1", node.hostname, registeredNode.ID().Uint64()) } @@ -2774,7 +2798,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // User2's original nodes should still be owned by user2 registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user2.ID)) require.True(t, found, "User2's original node %s should still exist", node.hostname) - require.Equal(t, user2.ID, registeredNode.UserID(), "Node %s should still belong to user2", node.hostname) + require.Equal(t, user2.ID, registeredNode.UserID().Get(), "Node %s should still belong to user2", node.hostname) t.Logf("✓ User2's original node %s (ID=%d) still owned by user2", node.hostname, registeredNode.ID().Uint64()) } @@ -2785,7 +2809,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Should be able to find a node with user1 and this machine key (the new one) newNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID)) require.True(t, found, "Should have created new node for user1 with machine key from %s", node.hostname) - require.Equal(t, user1.ID, newNode.UserID(), "New node should belong to user1") + require.Equal(t, user1.ID, newNode.UserID().Get(), "New node should belong to user1") t.Logf("✓ New node created for user1 with machine key from %s (ID=%d)", node.hostname, newNode.ID().Uint64()) } } @@ -2813,7 +2837,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { // Step 1: Register node for user1 via pre-auth key (simulating initial web flow registration) user1 := app.state.CreateUserForTest("user1") - pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil) + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) require.NoError(t, err) regReq1 := tailcfg.RegisterRequest{ @@ -2834,7 +2858,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { // Verify node exists for user1 user1Node, found := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) require.True(t, found, "Node should exist for user1") - require.Equal(t, user1.ID, user1Node.UserID(), "Node should belong to user1") + require.Equal(t, user1.ID, user1Node.UserID().Get(), "Node should belong to user1") user1NodeID := user1Node.ID() t.Logf("✓ User1 node created with ID: %d", user1NodeID) @@ -2896,7 +2920,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { t.Fatal("User1's node was transferred or deleted - this breaks the integration test!") } - assert.Equal(t, user1.ID, user1NodeAfter.UserID(), "User1's node should still belong to user1") + assert.Equal(t, user1.ID, user1NodeAfter.UserID().Get(), "User1's node should still belong to user1") assert.Equal(t, user1NodeID, user1NodeAfter.ID(), "Should be the same node (same ID)") assert.True(t, user1NodeAfter.IsExpired(), "User1's node should still be expired") t.Logf("✓ User1's original node still exists (ID: %d, expired: %v)", user1NodeAfter.ID(), user1NodeAfter.IsExpired()) @@ -2911,7 +2935,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { t.Fatal("User2 doesn't have a node - registration failed!") } - assert.Equal(t, user2.ID, user2Node.UserID(), "User2's node should belong to user2") + assert.Equal(t, user2.ID, user2Node.UserID().Get(), "User2's node should belong to user2") assert.NotEqual(t, user1NodeID, user2Node.ID(), "Should be a NEW node (different ID), not transfer!") assert.Equal(t, machineKey.Public(), user2Node.MachineKey(), "Should have same machine key") assert.Equal(t, nodeKey2.Public(), user2Node.NodeKey(), "Should have new node key") @@ -2921,7 +2945,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { t.Run("returned_node_is_user2_new_node", func(t *testing.T) { // The node returned from HandleNodeFromAuthPath should be user2's NEW node - assert.Equal(t, user2.ID, node.UserID(), "Returned node should belong to user2") + assert.Equal(t, user2.ID, node.UserID().Get(), "Returned node should belong to user2") assert.NotEqual(t, user1NodeID, node.ID(), "Returned node should be NEW, not transferred from user1") t.Logf("✓ HandleNodeFromAuthPath returned user2's new node (ID: %d)", node.ID()) }) @@ -2949,10 +2973,11 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { user2Nodes := 0 for i := 0; i < allNodesSlice.Len(); i++ { n := allNodesSlice.At(i) - if n.UserID() == user1.ID { + if n.UserID().Get() == user1.ID { user1Nodes++ } - if n.UserID() == user2.ID { + + if n.UserID().Get() == user2.ID { user2Nodes++ } } @@ -3026,7 +3051,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // Create user and single-use pre-auth key user := app.state.CreateUserForTest("test-user") - pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // reusable=false + pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // reusable=false require.NoError(t, err) // Fetch the full pre-auth key to check Reusable field @@ -3117,7 +3142,7 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) { app := createTestApp(t) user := app.state.CreateUserForTest("test-user") - pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) // reusable=true + pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) // reusable=true require.NoError(t, err) // Fetch the full pre-auth key to check Reusable field @@ -3173,7 +3198,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) { user := app.state.CreateUserForTest("test-user") expiry := time.Now().Add(-1 * time.Hour) // Already expired - pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil) + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) require.NoError(t, err) machineKey := key.NewMachine() @@ -3306,7 +3331,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing. // Create a SINGLE-USE pre-auth key (reusable=false) // This is the type of key that triggers the bug in issue #2830 - preAuthKeyNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + preAuthKeyNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) // Fetch the full pre-auth key to check Reusable and Used fields diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index b3e6b704..ad1a8a25 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -577,6 +577,21 @@ AND auth_key_id NOT IN ( }, Rollback: func(db *gorm.DB) error { return nil }, }, + { + // Rename forced_tags column to tags in nodes table. + // This must run after migration 202505141324 which creates tables with forced_tags. + ID: "202511131445-node-forced-tags-to-tags", + Migrate: func(tx *gorm.DB) error { + // Rename the column from forced_tags to tags + err := tx.Migrator().RenameColumn(&types.Node{}, "forced_tags", "tags") + if err != nil { + return fmt.Errorf("renaming forced_tags to tags: %w", err) + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 2a30027e..c50ad37c 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -231,8 +231,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) { name string dbPath string wantFunc func(*testing.T, *HSDatabase) - }{ - } + }{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 3ec81c9f..7ba335e8 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -95,7 +95,7 @@ func TestIPAllocatorSequential(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -123,7 +123,7 @@ func TestIPAllocatorSequential(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.2"), IPv6: nap("fd7a:115c:a1e0::2"), }) @@ -309,7 +309,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), }) @@ -334,7 +334,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -359,7 +359,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -383,7 +383,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -407,19 +407,19 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.2"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.3"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.4"), }) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index cfeefb82..bf407bb4 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -196,8 +196,9 @@ func SetTags( tags []string, ) error { if len(tags) == 0 { - // if no tags are provided, we remove all forced tags - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil { + // if no tags are provided, we remove all tags + err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", "[]").Error + if err != nil { return fmt.Errorf("removing tags: %w", err) } @@ -211,7 +212,8 @@ func SetTags( return err } - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil { + err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", string(b)).Error + if err != nil { return fmt.Errorf("updating tags: %w", err) } @@ -349,12 +351,20 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n panic("RegisterNodeForTest can only be called during tests") } - log.Debug(). + logEvent := log.Debug(). Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). - Str("node_key", node.NodeKey.ShortString()). - Str("user", node.User.Username()). - Msg("Registering test node") + Str("node_key", node.NodeKey.ShortString()) + + if node.User != nil { + logEvent = logEvent.Str("user", node.User.Username()) + } else if node.UserID != nil { + logEvent = logEvent.Uint("user_id", *node.UserID) + } else { + logEvent = logEvent.Str("user", "none") + } + + logEvent.Msg("Registering test node") // If the a new node is registered with the same machine key, to the same user, // update the existing node. @@ -642,7 +652,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) } // Create a preauth key for the node - pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := hsdb.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) if err != nil { panic(fmt.Sprintf("failed to create preauth key for test node: %v", err)) } @@ -656,7 +666,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) NodeKey: nodeKey.Public(), DiscoKey: discoKey.Public(), Hostname: nodeName, - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 0efd0e8b..f0fac74c 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -83,7 +83,7 @@ func (s *Suite) TestExpireNode(c *check.C) { user, err := db.CreateUser(types.User{Name: "test"}) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "testnode") @@ -97,7 +97,7 @@ func (s *Suite) TestExpireNode(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), Expiry: &time.Time{}, @@ -124,7 +124,7 @@ func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser(types.User{Name: "test"}) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.getNode(types.UserID(user.ID), "testnode") @@ -138,7 +138,7 @@ func (s *Suite) TestSetTags(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -152,7 +152,7 @@ func (s *Suite) TestSetTags(c *check.C) { c.Assert(err, check.IsNil) node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, sTags) + c.Assert(node.Tags, check.DeepEquals, sTags) // assign duplicate tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} @@ -161,17 +161,10 @@ func (s *Suite) TestSetTags(c *check.C) { node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert( - node.ForcedTags, + node.Tags, check.DeepEquals, []string{"tag:bar", "tag:test", "tag:unknown"}, ) - - // test removing tags - err = db.SetTags(node.ID, []string{}) - c.Assert(err, check.IsNil) - node, err = db.getNode(types.UserID(user.ID), "testnode") - c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, []string{}) } func TestHeadscale_generateGivenName(t *testing.T) { @@ -430,7 +423,7 @@ func TestAutoApproveRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.routes, @@ -446,12 +439,12 @@ func TestAutoApproveRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "taggednode", - UserID: taggedUser.ID, + UserID: &taggedUser.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.routes, }, - ForcedTags: []string{"tag:exit"}, + Tags: []string{"tag:exit"}, IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), } @@ -593,10 +586,10 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) - pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) + pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) require.NoError(t, err) node := types.Node{ @@ -604,7 +597,7 @@ func TestListEphemeralNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -614,7 +607,7 @@ func TestListEphemeralNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "ephemeral", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pakEph.ID), } @@ -657,7 +650,7 @@ func TestNodeNaming(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -667,7 +660,7 @@ func TestNodeNaming(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -680,7 +673,7 @@ func TestNodeNaming(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "我的电脑", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, } @@ -688,7 +681,7 @@ func TestNodeNaming(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "a", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, } @@ -808,7 +801,7 @@ func TestRenameNodeComprehensive(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -931,7 +924,7 @@ func TestListPeers(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test1", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -941,7 +934,7 @@ func TestListPeers(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test2", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -1016,7 +1009,7 @@ func TestListNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test1", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -1026,7 +1019,7 @@ func TestListNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test2", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 8580c5b4..105495f1 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -15,15 +15,15 @@ import ( ) var ( - ErrPreAuthKeyNotFound = errors.New("AuthKey not found") - ErrPreAuthKeyExpired = errors.New("AuthKey expired") - ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used") + ErrPreAuthKeyNotFound = errors.New("auth-key not found") + ErrPreAuthKeyExpired = errors.New("auth-key expired") + ErrSingleUseAuthKeyHasBeenUsed = errors.New("auth-key has already been used") ErrUserMismatch = errors.New("user mismatch") - ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") + ErrPreAuthKeyACLTagInvalid = errors.New("auth-key tag is invalid") ) func (hsdb *HSDatabase) CreatePreAuthKey( - uid types.UserID, + uid *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, @@ -41,17 +41,40 @@ const ( ) // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. +// The uid parameter can be nil for system-created tagged keys. +// For tagged keys, uid tracks "created by" (who created the key). +// For user-owned keys, uid tracks the node owner. func CreatePreAuthKey( tx *gorm.DB, - uid types.UserID, + uid *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKeyNew, error) { - user, err := GetUserByID(tx, uid) - if err != nil { - return nil, err + // Validate: must be tagged OR user-owned, not neither + if uid == nil && len(aclTags) == 0 { + return nil, ErrPreAuthKeyNotTaggedOrOwned + } + + // If uid != nil && len(aclTags) > 0: + // Both are allowed: UserID tracks "created by", tags define node ownership + // This is valid per the new model + + var ( + user *types.User + userID *uint + ) + + if uid != nil { + var err error + + user, err = GetUserByID(tx, *uid) + if err != nil { + return nil, err + } + + userID = &user.ID } // Remove duplicates and sort for consistency @@ -108,15 +131,15 @@ func CreatePreAuthKey( } key := types.PreAuthKey{ - UserID: user.ID, - User: *user, + UserID: userID, // nil for system-created keys, or "created by" for tagged keys + User: user, // nil for system-created keys Reusable: reusable, Ephemeral: ephemeral, CreatedAt: &now, Expiration: expiration, - Tags: aclTags, - Prefix: prefix, // Store prefix - Hash: hash, // Store hash + Tags: aclTags, // empty for user-owned keys + Prefix: prefix, // Store prefix + Hash: hash, // Store hash } if err := tx.Save(&key).Error; err != nil { @@ -149,14 +172,19 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e } keys := []types.PreAuthKey{} - if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + + err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error + if err != nil { return nil, err } return keys, nil } -var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey") +var ( + ErrPreAuthKeyFailedToParse = errors.New("failed to parse auth-key") + ErrPreAuthKeyNotTaggedOrOwned = errors.New("auth-key must be either tagged or owned by user") +) func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) { var pak types.PreAuthKey diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 9ad8ae42..643b579c 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -24,7 +24,7 @@ func TestCreatePreAuthKey(t *testing.T) { test: func(t *testing.T, db *HSDatabase) { t.Helper() - _, err := db.CreatePreAuthKey(12345, true, false, nil, nil) + _, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil) assert.Error(t, err) }, }, @@ -36,7 +36,7 @@ func TestCreatePreAuthKey(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + key, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) require.NoError(t, err) assert.NotEmpty(t, key.Key) @@ -83,7 +83,7 @@ func TestPreAuthKeyACLTags(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test-tags-1"}) require.NoError(t, err) - _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"badtag"}) assert.Error(t, err) }, }, @@ -98,7 +98,7 @@ func TestPreAuthKeyACLTags(t *testing.T) { expectedTags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate) require.NoError(t, err) listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) @@ -128,13 +128,13 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test8"}) require.NoError(t, err) - key, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"tag:good"}) + key, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:good"}) require.NoError(t, err) node := types.Node{ ID: 0, Hostname: "testest", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(key.ID), } @@ -180,7 +180,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { validateResult: func(t *testing.T, pak *types.PreAuthKey) { t.Helper() - assert.Equal(t, user.ID, pak.UserID) + assert.Equal(t, user.ID, *pak.UserID) assert.NotEmpty(t, pak.Key) // Legacy keys have Key populated assert.Empty(t, pak.Prefix) // Legacy keys have empty Prefix assert.Nil(t, pak.Hash) // Legacy keys have nil Hash @@ -191,7 +191,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { setupKey: func() string { // Create new key via API keyStr, err := db.CreatePreAuthKey( - types.UserID(user.ID), + user.TypedID(), true, false, nil, []string{"tag:test"}, ) require.NoError(t, err) @@ -203,7 +203,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { validateResult: func(t *testing.T, pak *types.PreAuthKey) { t.Helper() - assert.Equal(t, user.ID, pak.UserID) + assert.Equal(t, user.ID, *pak.UserID) assert.Empty(t, pak.Key) // New keys have empty Key assert.NotEmpty(t, pak.Prefix) // New keys have Prefix assert.NotNil(t, pak.Hash) // New keys have Hash @@ -214,7 +214,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { name: "new_key_format_validation", setupKey: func() string { keyStr, err := db.CreatePreAuthKey( - types.UserID(user.ID), + user.TypedID(), true, false, nil, nil, ) require.NoError(t, err) @@ -244,7 +244,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) { setupKey: func() string { // Create valid key key, err := db.CreatePreAuthKey( - types.UserID(user.ID), + user.TypedID(), true, false, nil, nil, ) require.NoError(t, err) @@ -415,11 +415,11 @@ func TestMultipleLegacyKeysAllowed(t *testing.T) { assert.Len(t, legacyKeys, 5, "should have created 5 legacy keys") // Now create new bcrypt-based keys - these should have unique prefixes - key1, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + key1, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) require.NoError(t, err) assert.NotEmpty(t, key1.Key) - key2, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) + key2, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) require.NoError(t, err) assert.NotEmpty(t, key2.Key) diff --git a/hscontrol/db/schema.sql b/hscontrol/db/schema.sql index 1fb9528e..ef0a2a0e 100644 --- a/hscontrol/db/schema.sql +++ b/hscontrol/db/schema.sql @@ -81,7 +81,7 @@ CREATE TABLE nodes( given_name varchar(63), user_id integer, register_method text, - forced_tags text, + tags text, auth_key_id integer, last_seen datetime, expiry datetime, diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index dff235e4..92c3292d 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -189,7 +189,11 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { // ListNodesByUser gets all the nodes in a given user. func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { nodes := types.Nodes{} - if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil { + + uidPtr := uint(uid) + + err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: &uidPtr}).Find(&nodes).Error + if err != nil { return nil, err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 1ea0772c..a3fd49b3 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -50,7 +50,7 @@ func TestDestroyUserErrors(t *testing.T) { user := db.CreateUserForTest("test") - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) err = db.DestroyUser(types.UserID(user.ID)) @@ -71,13 +71,13 @@ func TestDestroyUserErrors(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) node := types.Node{ ID: 0, Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index c2e9cee7..a0409e4e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -172,7 +172,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey( } preAuthKey, err := api.h.state.CreatePreAuthKey( - types.UserID(user.ID), + user.TypedID(), request.GetReusable(), request.GetEphemeral(), &expiration, @@ -341,6 +341,17 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { + // Validate tags not empty - tagged nodes must have at least one tag + if len(request.GetTags()) == 0 { + return &v1.SetTagsResponse{ + Node: nil, + }, status.Error( + codes.InvalidArgument, + "cannot remove all tags from a node - tagged nodes must have at least one tag", + ) + } + + // Validate tag format for _, tag := range request.GetTags() { err := validateTag(tag) if err != nil { @@ -348,6 +359,16 @@ func (api headscaleV1APIServer) SetTags( } } + // User XOR Tags: nodes are either tagged or user-owned, never both. + // Setting tags on a user-owned node converts it to a tagged node. + // Once tagged, a node cannot be converted back to user-owned. + _, found := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !found { + return &v1.SetTagsResponse{ + Node: nil, + }, status.Error(codes.NotFound, "node not found") + } + node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags()) if err != nil { return &v1.SetTagsResponse{ @@ -529,13 +550,19 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N for index, node := range nodes.All() { resp := node.Proto() + // Tags-as-identity: tagged nodes show as TaggedDevices user in API responses + // (UserID may be set internally for "created by" tracking) + if node.IsTagged() { + resp.User = types.TaggedDevices.Proto() + } + var tags []string for _, tag := range node.RequestTags() { if state.NodeCanHaveTag(node, tag) { tags = append(tags, tag) } } - resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...)) + resp.ValidTags = lo.Uniq(append(tags, node.Tags().AsSlice()...)) resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...)) response[index] = resp @@ -780,7 +807,7 @@ func (api headscaleV1APIServer) DebugCreateNode( NodeKey: key.NewNode().Public(), MachineKey: key.NewMachine().Public(), Hostname: request.GetName(), - User: *user, + User: user, Expiry: &time.Time{}, LastSeen: &time.Time{}, diff --git a/hscontrol/grpcv1_test.go b/hscontrol/grpcv1_test.go index 1d87bfe0..8a50dc59 100644 --- a/hscontrol/grpcv1_test.go +++ b/hscontrol/grpcv1_test.go @@ -1,6 +1,17 @@ package hscontrol -import "testing" +import ( + "context" + "testing" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) func Test_validateTag(t *testing.T) { type args struct { @@ -40,3 +51,212 @@ func Test_validateTag(t *testing.T) { }) } } + +// TestSetTags_Conversion tests the conversion of user-owned nodes to tagged nodes. +// The tags-as-identity model allows one-way conversion from user-owned to tagged. +// Tag authorization is checked via the policy manager - unauthorized tags are rejected. +func TestSetTags_Conversion(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and nodes + user := app.state.CreateUserForTest("test-user") + + // Create a pre-auth key WITHOUT tags for user-owned node + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + // Register a user-owned node (via untagged PreAuthKey) + userOwnedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "user-owned-node", + }, + } + _, err = app.handleRegisterWithAuthKey(userOwnedReq, machineKey1.Public()) + require.NoError(t, err) + + // Get the created node + userOwnedNode, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + tests := []struct { + name string + nodeID uint64 + tags []string + wantErr bool + wantCode codes.Code + wantErrMessage string + }{ + { + // Conversion is allowed, but tag authorization fails without tagOwners + name: "reject unauthorized tags on user-owned node", + nodeID: uint64(userOwnedNode.ID()), + tags: []string{"tag:server"}, + wantErr: true, + wantCode: codes.InvalidArgument, + wantErrMessage: "invalid or unauthorized tags", + }, + { + // Conversion is allowed, but tag authorization fails without tagOwners + name: "reject multiple unauthorized tags", + nodeID: uint64(userOwnedNode.ID()), + tags: []string{"tag:server", "tag:database"}, + wantErr: true, + wantCode: codes.InvalidArgument, + wantErrMessage: "invalid or unauthorized tags", + }, + { + name: "reject non-existent node", + nodeID: 99999, + tags: []string{"tag:server"}, + wantErr: true, + wantCode: codes.NotFound, + wantErrMessage: "node not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: tt.nodeID, + Tags: tt.tags, + }) + + if tt.wantErr { + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, tt.wantCode, st.Code()) + assert.Contains(t, st.Message(), tt.wantErrMessage) + assert.Nil(t, resp.GetNode()) + } else { + require.NoError(t, err) + assert.NotNil(t, resp) + assert.NotNil(t, resp.GetNode()) + } + }) + } +} + +// TestSetTags_TaggedNode tests that SetTags correctly identifies tagged nodes +// and doesn't reject them with the "user-owned nodes" error. +// Note: This test doesn't validate ACL tag authorization - that's tested elsewhere. +func TestSetTags_TaggedNode(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and tagged pre-auth key + user := app.state.CreateUserForTest("test-user") + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:initial"}) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Register a tagged node (via tagged PreAuthKey) + taggedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + } + _, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public()) + require.NoError(t, err) + + // Get the created node + taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, taggedNode.IsTagged(), "Node should be tagged") + assert.True(t, taggedNode.UserID().Valid(), "Tagged node should have UserID for tracking") + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + // Test: SetTags should NOT reject tagged nodes with "user-owned" error + // (Even though they have UserID set, IsTagged() identifies them correctly) + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: uint64(taggedNode.ID()), + Tags: []string{"tag:initial"}, // Keep existing tag to avoid ACL validation issues + }) + + // The call should NOT fail with "cannot set tags on user-owned nodes" + if err != nil { + st, ok := status.FromError(err) + require.True(t, ok) + // If error is about unauthorized tags, that's fine - ACL validation is working + // If error is about user-owned nodes, that's the bug we're testing for + assert.NotContains(t, st.Message(), "user-owned nodes", "Should not reject tagged nodes as user-owned") + } else { + // Success is also fine + assert.NotNil(t, resp) + } +} + +// TestSetTags_CannotRemoveAllTags tests that SetTags rejects attempts to remove +// all tags from a tagged node, enforcing Tailscale's requirement that tagged +// nodes must have at least one tag. +func TestSetTags_CannotRemoveAllTags(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and tagged pre-auth key + user := app.state.CreateUserForTest("test-user") + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:server"}) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Register a tagged node + taggedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + } + _, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public()) + require.NoError(t, err) + + // Get the created node + taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, taggedNode.IsTagged()) + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + // Attempt to remove all tags (empty array) + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: uint64(taggedNode.ID()), + Tags: []string{}, // Empty - attempting to remove all tags + }) + + // Should fail with InvalidArgument error + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "cannot remove all tags") + assert.Nil(t, resp.GetNode()) +} diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 372bb557..37778dc0 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -73,15 +73,17 @@ func generateUserProfiles( node types.NodeView, peers views.Slice[types.NodeView], ) []tailcfg.UserProfile { - userMap := make(map[uint]*types.User) + userMap := make(map[uint]*types.UserView) ids := make([]uint, 0, len(userMap)) user := node.User() - userMap[user.ID] = &user - ids = append(ids, user.ID) + userID := user.Model().ID + userMap[userID] = &user + ids = append(ids, userID) for _, peer := range peers.All() { peerUser := peer.User() - userMap[peerUser.ID] = &peerUser - ids = append(ids, peerUser.ID) + peerUserID := peerUser.Model().ID + userMap[peerUserID] = &peerUser + ids = append(ids, peerUserID) } slices.Sort(ids) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index b801f7dd..1bafd135 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -14,6 +14,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/types/ptr" ) var iap = func(ipStr string) *netip.Addr { @@ -50,8 +51,8 @@ func TestDNSConfigMapResponse(t *testing.T) { mach := func(hostname, username string, userid uint) *types.Node { return &types.Node{ Hostname: hostname, - UserID: userid, - User: types.User{ + UserID: ptr.To(userid), + User: &types.User{ Name: username, }, } diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 3a518d94..3153f62b 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -83,7 +83,8 @@ func tailNode( tags = append(tags, tag) } } - for _, tag := range node.ForcedTags().All() { + + for _, tag := range node.Tags().All() { tags = append(tags, tag) } tags = lo.Uniq(tags) @@ -99,7 +100,7 @@ func tailNode( Name: hostname, Cap: capVer, - User: tailcfg.UserID(node.UserID()), + User: node.TailscaleUserID(), Key: node.NodeKey(), KeyExpiry: keyExpiry.UTC(), diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 3a3b39d1..9b0765ba 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -15,6 +15,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func TestTailNode(t *testing.T) { @@ -97,14 +98,14 @@ func TestTailNode(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{ + UserID: ptr.To(uint(0)), + User: &types.User{ Name: "mini", }, - ForcedTags: []string{}, - AuthKey: &types.PreAuthKey{}, - LastSeen: &lastSeen, - Expiry: &expire, + Tags: []string{}, + AuthKey: &types.PreAuthKey{}, + LastSeen: &lastSeen, + Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ tsaddr.AllIPv4(), diff --git a/hscontrol/oidc_template_test.go b/hscontrol/oidc_template_test.go index 4c5ecaa7..367451b1 100644 --- a/hscontrol/oidc_template_test.go +++ b/hscontrol/oidc_template_test.go @@ -1,13 +1,10 @@ package hscontrol import ( - "os" - "path/filepath" "testing" "github.com/juanfont/headscale/hscontrol/templates" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestOIDCCallbackTemplate(t *testing.T) { @@ -49,15 +46,6 @@ func TestOIDCCallbackTemplate(t *testing.T) { assert.Contains(t, html, "= len(newIPs) || oldIP != newIPs[i] { - affectedUsers[newNode.User().ID] = struct{}{} + affectedUsers[newNode.User().ID()] = struct{}{} break } } @@ -750,7 +750,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S // Check in new nodes first for _, node := range newNodes.All() { if node.ID() == nodeID { - nodeUserID = node.User().ID + nodeUserID = node.User().ID() found = true break } @@ -760,7 +760,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S if !found { for _, node := range oldNodes.All() { if node.ID() == nodeID { - nodeUserID = node.User().ID + nodeUserID = node.User().ID() found = true break } diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 5b36b79e..94e631e7 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { @@ -19,8 +20,8 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) Hostname: name, IPv4: ap(ipv4), IPv6: ap(ipv6), - User: user, - UserID: user.ID, + User: ptr.To(user), + UserID: ptr.To(user.ID), Hostinfo: hostinfo, } } @@ -456,8 +457,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { Hostname: "test-1-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: users[0], - UserID: users[0].ID, + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -467,9 +468,9 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { Hostname: "test-2-router", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: users[1], - UserID: users[1].ID, - ForcedTags: []string{"tag:node-router"}, + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + Tags: []string{"tag:node-router"}, Hostinfo: &tailcfg.Hostinfo{}, } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 7b4b2b28..0635a557 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -206,7 +206,12 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. continue } - if node.User().ID == user.ID { + // Skip nodes without a user (defensive check for tests) + if !node.User().Valid() { + continue + } + + if node.User().ID() == user.ID { node.AppendToIPSet(&ips) } } @@ -311,8 +316,8 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeV } for _, node := range nodes.All() { - // Check if node has this tag in all tags (ForcedTags + AuthKey.Tags) - if slices.Contains(node.Tags(), string(t)) { + // Check if node has this tag + if node.HasTag(string(t)) { node.AppendToIPSet(&ips) } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index d5a8730a..2d379b4d 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1549,7 +1549,17 @@ func TestResolvePolicy(t *testing.T) { "groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"}, "groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"}, "notme": {Model: gorm.Model{ID: 5}, Name: "notme"}, + "testuser2": {Model: gorm.Model{ID: 6}, Name: "testuser2"}, } + + // Extract users to variables so we can take their addresses + testuser := users["testuser"] + groupuser := users["groupuser"] + groupuser1 := users["groupuser1"] + groupuser2 := users["groupuser2"] + notme := users["notme"] + testuser2 := users["testuser2"] + tests := []struct { name string nodes types.Nodes @@ -1579,29 +1589,27 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], + User: ptr.To(notme), IPv4: ap("100.100.101.1"), }, // Not matching forced tags { - User: users["testuser"], - ForcedTags: []string{"tag:anything"}, + User: ptr.To(testuser), + Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.2"), }, - // not matching pak tag + // not matching because it's tagged (tags copied from AuthKey) { - User: users["testuser"], - AuthKey: &types.PreAuthKey{ - Tags: []string{"alsotagged"}, - }, + User: ptr.To(testuser), + Tags: []string{"alsotagged"}, IPv4: ap("100.100.101.3"), }, { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.103"), }, { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.104"), }, }, @@ -1613,29 +1621,27 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], + User: ptr.To(notme), IPv4: ap("100.100.101.4"), }, // Not matching forced tags { - User: users["groupuser"], - ForcedTags: []string{"tag:anything"}, + User: ptr.To(groupuser), + Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.5"), }, - // not matching pak tag + // not matching because it's tagged (tags copied from AuthKey) { - User: users["groupuser"], - AuthKey: &types.PreAuthKey{ - Tags: []string{"tag:alsotagged"}, - }, + User: ptr.To(groupuser), + Tags: []string{"tag:alsotagged"}, IPv4: ap("100.100.101.6"), }, { - User: users["groupuser"], + User: ptr.To(groupuser), IPv4: ap("100.100.101.203"), }, { - User: users["groupuser"], + User: ptr.To(groupuser), IPv4: ap("100.100.101.204"), }, }, @@ -1653,12 +1659,12 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], + User: ptr.To(notme), IPv4: ap("100.100.101.9"), }, // Not matching forced tags { - ForcedTags: []string{"tag:anything"}, + Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.10"), }, // not matching pak tag @@ -1670,14 +1676,12 @@ func TestResolvePolicy(t *testing.T) { }, // Not matching forced tags { - ForcedTags: []string{"tag:test"}, + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.234"), }, - // not matching pak tag + // matching tag (tags copied from AuthKey during registration) { - AuthKey: &types.PreAuthKey{ - Tags: []string{"tag:test"}, - }, + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.239"), }, }, @@ -1706,11 +1710,11 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(Group("group:testgroup")), nodes: types.Nodes{ { - User: users["groupuser1"], + User: ptr.To(groupuser1), IPv4: ap("100.100.101.203"), }, { - User: users["groupuser2"], + User: ptr.To(groupuser2), IPv4: ap("100.100.101.204"), }, }, @@ -1731,7 +1735,7 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(Username("invaliduser@")), nodes: types.Nodes{ { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.103"), }, }, @@ -1742,7 +1746,7 @@ func TestResolvePolicy(t *testing.T) { toResolve: tp("tag:invalid"), nodes: types.Nodes{ { - ForcedTags: []string{"tag:test"}, + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.234"), }, }, @@ -1763,18 +1767,18 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Node with no tags (should be included) { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.1"), }, // Node with forced tags (should be excluded) { - User: users["testuser"], - ForcedTags: []string{"tag:test"}, + User: ptr.To(testuser), + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1782,7 +1786,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with non-allowed requested tag (should be included) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed"}, }, @@ -1790,7 +1794,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, one allowed (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test", "tag:notallowed"}, }, @@ -1798,7 +1802,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, none allowed (should be included) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed1", "tag:notallowed2"}, }, @@ -1822,18 +1826,18 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Node with no tags (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.1"), }, // Node with forced tag (should be included) { - User: users["testuser"], - ForcedTags: []string{"tag:test"}, + User: ptr.To(testuser), + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be included) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1841,7 +1845,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with non-allowed requested tag (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed"}, }, @@ -1849,7 +1853,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, one allowed (should be included) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test", "tag:notallowed"}, }, @@ -1857,7 +1861,7 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, none allowed (should be excluded) { - User: users["testuser"], + User: ptr.To(testuser), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed1", "tag:notallowed2"}, }, @@ -1865,8 +1869,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple forced tags (should be included) { - User: users["testuser"], - ForcedTags: []string{"tag:test", "tag:other"}, + User: ptr.To(testuser), + Tags: []string{"tag:test", "tag:other"}, IPv4: ap("100.100.101.7"), }, }, @@ -1886,20 +1890,20 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(AutoGroupSelf), nodes: types.Nodes{ { - User: users["testuser"], + User: ptr.To(testuser), IPv4: ap("100.100.101.1"), }, { - User: users["testuser2"], + User: ptr.To(testuser2), IPv4: ap("100.100.101.2"), }, { - User: users["testuser"], - ForcedTags: []string{"tag:test"}, + User: ptr.To(testuser), + Tags: []string{"tag:test"}, IPv4: ap("100.100.101.3"), }, { - User: users["testuser2"], + User: ptr.To(testuser2), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1961,23 +1965,23 @@ func TestResolveAutoApprovers(t *testing.T) { nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: users[0], + User: &users[0], }, { IPv4: ap("100.64.0.2"), - User: users[1], + User: &users[1], }, { IPv4: ap("100.64.0.3"), - User: users[2], + User: &users[2], }, { IPv4: ap("100.64.0.4"), - ForcedTags: []string{"tag:testtag"}, + Tags: []string{"tag:testtag"}, }, { IPv4: ap("100.64.0.5"), - ForcedTags: []string{"tag:exittest"}, + Tags: []string{"tag:exittest"}, }, } @@ -2280,15 +2284,15 @@ func TestNodeCanApproveRoute(t *testing.T) { nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: users[0], + User: &users[0], }, { IPv4: ap("100.64.0.2"), - User: users[1], + User: &users[1], }, { IPv4: ap("100.64.0.3"), - User: users[2], + User: &users[2], }, } @@ -2413,15 +2417,15 @@ func TestResolveTagOwners(t *testing.T) { nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: users[0], + User: &users[0], }, { IPv4: ap("100.64.0.2"), - User: users[1], + User: &users[1], }, { IPv4: ap("100.64.0.3"), - User: users[2], + User: &users[2], }, } @@ -2498,15 +2502,15 @@ func TestNodeCanHaveTag(t *testing.T) { nodes := types.Nodes{ { IPv4: ap("100.64.0.1"), - User: users[0], + User: &users[0], }, { IPv4: ap("100.64.0.2"), - User: users[1], + User: &users[1], }, { IPv4: ap("100.64.0.3"), - User: users[2], + User: &users[2], }, } @@ -2580,6 +2584,49 @@ func TestNodeCanHaveTag(t *testing.T) { tag: "tag:test", want: false, }, + { + name: "node-with-unauthorized-tag-different-user", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[2], // user3's node + tag: "tag:prod", + want: false, + }, + { + name: "node-with-multiple-tags-one-unauthorized", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:web"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + }, + }, + node: nodes[0], // user1's node + tag: "tag:database", + want: false, // user1 cannot have tag:database (owned by user2) + }, + { + name: "empty-tagowners-map", + policy: &Policy{ + TagOwners: TagOwners{}, + }, + node: nodes[0], + tag: "tag:test", + want: false, // No one can have tags if tagOwners is empty + }, + { + name: "tag-not-in-tagowners", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[0], + tag: "tag:dev", // This tag is not defined in tagOwners + want: false, + }, } for _, tt := range tests { diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 4324ffba..b0453051 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/sasha-s/go-deadlock" "tailscale.com/tailcfg" @@ -42,11 +43,6 @@ type mapSession struct { node *types.Node w http.ResponseWriter - - warnf func(string, ...any) - infof func(string, ...any) - tracef func(string, ...any) - errf func(error, string, ...any) } func (h *Headscale) newMapSession( @@ -55,8 +51,6 @@ func (h *Headscale) newMapSession( w http.ResponseWriter, node *types.Node, ) *mapSession { - warnf, infof, tracef, errf := logPollFunc(req, node) - ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) return &mapSession{ @@ -73,12 +67,6 @@ func (h *Headscale) newMapSession( keepAlive: ka, keepAliveTicker: nil, - - // Loggers - warnf: warnf, - infof: infof, - tracef: tracef, - errf: errf, } } @@ -295,6 +283,7 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error { } data := make([]byte, reservedResponseHeaderSize) + //nolint:gosec // G115: JSON response size will not exceed uint32 max binary.LittleEndian.PutUint32(data, uint32(len(jsonBody))) data = append(data, jsonBody...) @@ -365,45 +354,22 @@ func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcf trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received") } -func logPollFunc( - mapRequest tailcfg.MapRequest, - node *types.Node, -) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) { - return func(msg string, a ...any) { - log.Warn(). - Caller(). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node.name", node.Hostname). - Msgf(msg, a...) - }, - func(msg string, a ...any) { - log.Info(). - Caller(). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node.name", node.Hostname). - Msgf(msg, a...) - }, - func(msg string, a ...any) { - log.Trace(). - Caller(). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node.name", node.Hostname). - Msgf(msg, a...) - }, - func(err error, msg string, a ...any) { - log.Error(). - Caller(). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node.name", node.Hostname). - Err(err). - Msgf(msg, a...) - } +// logf adds common mapSession context to a zerolog event. +func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event { + return event. + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname) +} + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) } + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) } + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) errf(err error, msg string, a ...any) { + m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...) } diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 757591ad..ef6ef50c 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -78,7 +78,7 @@ func (s *State) DebugOverview() string { now := time.Now() for _, node := range allNodes.All() { if node.Valid() { - userName := node.User().Name + userName := node.User().Name() userNodeCounts[userName]++ if node.IsOnline().Valid() && node.IsOnline().Get() { @@ -281,7 +281,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo { for _, node := range allNodes.All() { if node.Valid() { - userName := node.User().Name + userName := node.User().Name() info.Users[userName]++ if node.IsOnline().Valid() && node.IsOnline().Get() { diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index 865d3eb4..99f781d4 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func TestNetInfoFromMapRequest(t *testing.T) { @@ -148,8 +149,8 @@ func createTestNodeSimple(id types.NodeID) *types.Node { node := &types.Node{ ID: id, Hostname: "test-node", - UserID: uint(id), - User: user, + UserID: ptr.To(uint(id)), + User: &user, MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), IPv4: &netip.Addr{}, diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index adcc410a..241d2f46 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -408,7 +408,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S // Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps for _, n := range nodes { nodeView := n.View() - userID := types.UserID(n.UserID) + userID := n.TypedUserID() newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView) newSnap.nodesByNodeKey[n.NodeKey] = nodeView @@ -515,7 +515,7 @@ func (s *NodeStore) DebugString() string { if len(nodes) > 0 { userName := "unknown" if len(nodes) > 0 && nodes[0].Valid() { - userName = nodes[0].User().Name + userName = nodes[0].User().Name() } sb.WriteString(fmt.Sprintf(" - User %d (%s): %d nodes\n", userID, userName, len(nodes))) } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 788721b9..82f1a255 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func TestSnapshotFromNodes(t *testing.T) { @@ -173,8 +174,8 @@ func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) DiscoKey: discoKey.Public(), Hostname: hostname, GivenName: hostname, - UserID: userID, - User: types.User{ + UserID: ptr.To(userID), + User: &types.User{ Name: username, DisplayName: username, }, @@ -627,7 +628,7 @@ func TestNodeStoreOperations(t *testing.T) { go func() { resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { - n.ForcedTags = []string{"tag1", "tag2"} + n.Tags = []string{"tag1", "tag2"} }) close(done3) }() @@ -648,24 +649,24 @@ func TestNodeStoreOperations(t *testing.T) { // resultNode1 (from hostname update) should also have the givenname and tags changes assert.Equal(t, "multi-update-hostname", resultNode1.Hostname()) assert.Equal(t, "multi-update-givenname", resultNode1.GivenName()) - assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.Tags().AsSlice()) // resultNode2 (from givenname update) should also have the hostname and tags changes assert.Equal(t, "multi-update-hostname", resultNode2.Hostname()) assert.Equal(t, "multi-update-givenname", resultNode2.GivenName()) - assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.Tags().AsSlice()) // resultNode3 (from tags update) should also have the hostname and givenname changes assert.Equal(t, "multi-update-hostname", resultNode3.Hostname()) assert.Equal(t, "multi-update-givenname", resultNode3.GivenName()) - assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.Tags().AsSlice()) // Verify the snapshot also has all changes snapshot := store.data.Load() finalNode := snapshot.nodesByID[1] assert.Equal(t, "multi-update-hostname", finalNode.Hostname) assert.Equal(t, "multi-update-givenname", finalNode.GivenName) - assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags) + assert.Equal(t, []string{"tag1", "tag2"}, finalNode.Tags) }, }, }, @@ -687,7 +688,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode, ok := store.UpdateNode(1, func(n *types.Node) { n.Hostname = "db-save-hostname" n.GivenName = "db-save-given" - n.ForcedTags = []string{"db-tag1", "db-tag2"} + n.Tags = []string{"db-tag1", "db-tag2"} }) assert.True(t, ok, "UpdateNode should succeed") @@ -696,21 +697,21 @@ func TestNodeStoreOperations(t *testing.T) { // Verify the returned node has all expected values assert.Equal(t, "db-save-hostname", resultNode.Hostname()) assert.Equal(t, "db-save-given", resultNode.GivenName()) - assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice()) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.Tags().AsSlice()) // Convert to struct as would be done for database save nodePtr := resultNode.AsStruct() assert.NotNil(t, nodePtr) assert.Equal(t, "db-save-hostname", nodePtr.Hostname) assert.Equal(t, "db-save-given", nodePtr.GivenName) - assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.Tags) // Verify the snapshot also reflects the same state snapshot := store.data.Load() storedNode := snapshot.nodesByID[1] assert.Equal(t, "db-save-hostname", storedNode.Hostname) assert.Equal(t, "db-save-given", storedNode.GivenName) - assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.Tags) }, }, { @@ -742,7 +743,7 @@ func TestNodeStoreOperations(t *testing.T) { go func() { result3, ok3 = store.UpdateNode(1, func(n *types.Node) { - n.ForcedTags = []string{"concurrent-tag"} + n.Tags = []string{"concurrent-tag"} }) close(done3) }() @@ -767,22 +768,22 @@ func TestNodeStoreOperations(t *testing.T) { // All should have the complete final state assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname) assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName) - assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.Tags) assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname) assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName) - assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.Tags) assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname) assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName) - assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.Tags) // Verify consistency with stored state snapshot := store.data.Load() storedNode := snapshot.nodesByID[1] assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname) assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName) - assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags) + assert.Equal(t, nodePtr1.Tags, storedNode.Tags) }, }, { @@ -855,8 +856,8 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { Hostname: hostname, MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), - UserID: 1, - User: types.User{ + UserID: ptr.To(uint(1)), + User: &types.User{ Name: "concurrent-test-user", }, } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 68aefbc4..149bae4d 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -53,6 +53,9 @@ const ( // ErrUnsupportedPolicyMode is returned for invalid policy modes. Valid modes are "file" and "db". var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode") +// ErrNodeNotFound is returned when a node cannot be found by its ID. +var ErrNodeNotFound = errors.New("node not found") + // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { @@ -651,13 +654,36 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node return s.persistNodeToDB(n) } -// SetNodeTags assigns tags to a node for use in access control policies. +// SetNodeTags assigns tags to a node, making it a "tagged node". +// Once a node is tagged, it cannot be un-tagged (only tags can be changed). +// The UserID is preserved as "created by" information. func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) { + // CANNOT REMOVE ALL TAGS + if len(tags) == 0 { + return types.NodeView{}, change.EmptySet, types.ErrCannotRemoveAllTags + } + + // Get node for validation + existingNode, exists := s.nodeStore.GetNode(nodeID) + if !exists { + return types.NodeView{}, change.EmptySet, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID) + } + + // Validate tags against policy + validatedTags, err := s.validateAndNormalizeTags(existingNode.AsStruct(), tags) + if err != nil { + return types.NodeView{}, change.EmptySet, err + } + + // Log the operation + logTagOperation(existingNode, validatedTags) + // Update NodeStore before database to ensure consistency. The NodeStore update is // blocking and will be the source of truth for the batcher. The database update must // make the exact same change. n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { - node.ForcedTags = tags + node.Tags = validatedTags + // UserID is preserved as "created by" - do NOT set to nil }) if !ok { @@ -927,7 +953,8 @@ func (s *State) DestroyAPIKey(key types.APIKey) error { } // CreatePreAuthKey generates a new pre-authentication key for a user. -func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) { +// The userID parameter is now optional (can be nil) for system-created tagged keys. +func (s *State) CreatePreAuthKey(userID *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) { return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags) } @@ -1063,8 +1090,6 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro // Prepare the node for registration nodeToRegister := types.Node{ Hostname: params.Hostname, - UserID: params.User.ID, - User: params.User, MachineKey: params.MachineKey, NodeKey: params.NodeKey, DiscoKey: params.DiscoKey, @@ -1075,11 +1100,38 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro Expiry: params.Expiry, } - // Pre-auth key specific fields + // Assign ownership based on PreAuthKey if params.PreAuthKey != nil { - nodeToRegister.ForcedTags = params.PreAuthKey.Proto().GetAclTags() + if params.PreAuthKey.IsTagged() { + // TAGGED NODE + // Tags from PreAuthKey are assigned ONLY during initial authentication + nodeToRegister.Tags = params.PreAuthKey.Proto().GetAclTags() + + // Set UserID to track "created by" (who created the PreAuthKey) + if params.PreAuthKey.UserID != nil { + nodeToRegister.UserID = params.PreAuthKey.UserID + nodeToRegister.User = params.PreAuthKey.User + } + // If PreAuthKey.UserID is nil, the node is "orphaned" (system-created) + } else { + // USER-OWNED NODE + nodeToRegister.UserID = ¶ms.PreAuthKey.User.ID + nodeToRegister.User = params.PreAuthKey.User + nodeToRegister.Tags = nil + } nodeToRegister.AuthKey = params.PreAuthKey nodeToRegister.AuthKeyID = ¶ms.PreAuthKey.ID + } else { + // Non-PreAuthKey registration (OIDC, CLI) - always user-owned + nodeToRegister.UserID = ¶ms.User.ID + nodeToRegister.User = ¶ms.User + nodeToRegister.Tags = nil + } + + // Validate before saving + err := validateNodeOwnership(&nodeToRegister) + if err != nil { + return types.NodeView{}, err } // Allocate new IPs @@ -1156,7 +1208,7 @@ func (s *State) HandleNodeFromAuthPath( logHostinfoValidation( regEntry.Node.MachineKey.ShortString(), regEntry.Node.NodeKey.String(), - user.Username(), + user.Name, hostname, regEntry.Node.Hostinfo, ) @@ -1171,7 +1223,7 @@ func (s *State) HandleNodeFromAuthPath( log.Debug(). Caller(). Str("registration_id", registrationID.String()). - Str("user.name", user.Username()). + Str("user.name", user.Name). Str("registrationMethod", registrationMethod). Str("node.name", existingNodeSameUser.Hostname()). Uint64("node.id", existingNodeSameUser.ID().Uint64()). @@ -1233,7 +1285,7 @@ func (s *State) HandleNodeFromAuthPath( // Check if node exists with this machine key for a different user (for netinfo preservation) existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(regEntry.Node.MachineKey) - if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != user.ID { + if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != user.ID { // Node exists but belongs to a different user // Create a NEW node for the new user (do not transfer) // This allows the same machine to have separate node identities per user @@ -1243,8 +1295,8 @@ func (s *State) HandleNodeFromAuthPath( Str("existing.node.name", existingNodeAnyUser.Hostname()). Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). Str("machine.key", regEntry.Node.MachineKey.ShortString()). - Str("old.user", oldUser.Username()). - Str("new.user", user.Username()). + Str("old.user", oldUser.Name()). + Str("new.user", user.Name). Str("method", registrationMethod). Msg("Creating new node for different user (same machine key exists for another user)") } @@ -1253,7 +1305,7 @@ func (s *State) HandleNodeFromAuthPath( log.Debug(). Caller(). Str("registration_id", registrationID.String()). - Str("user.name", user.Username()). + Str("user.name", user.Name). Str("registrationMethod", registrationMethod). Str("expiresAt", fmt.Sprintf("%v", expiry)). Msg("Registering new node from auth callback") @@ -1416,8 +1468,11 @@ func (s *State) HandleNodeFromPreAuthKey( node.RegisterMethod = util.RegisterMethodAuthKey - // TODO(kradalby): This might need a rework as part of #2417 - node.ForcedTags = pak.Proto().GetAclTags() + // CRITICAL: Tags from PreAuthKey are ONLY applied during initial authentication + // On re-registration, we MUST NOT change tags or node ownership + // The node keeps whatever tags/user ownership it already has + // + // Only update AuthKey reference node.AuthKey = pak node.AuthKeyID = &pak.ID node.IsOnline = ptr.To(false) @@ -1467,7 +1522,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Check if node exists with this machine key for a different user existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) - if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != pak.User.ID { + if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != pak.User.ID { // Node exists but belongs to a different user // Create a NEW node for the new user (do not transfer) // This allows the same machine to have separate node identities per user @@ -1477,7 +1532,7 @@ func (s *State) HandleNodeFromPreAuthKey( Str("existing.node.name", existingNodeAnyUser.Hostname()). Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). Str("machine.key", machineKey.ShortString()). - Str("old.user", oldUser.Username()). + Str("old.user", oldUser.Name()). Str("new.user", pak.User.Username()). Msg("Creating new node for different user (same machine key exists for another user)") } @@ -1488,7 +1543,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Create and save new node var err error finalNode, err = s.createAndSaveNewNode(newNodeParams{ - User: pak.User, + User: *pak.User, MachineKey: machineKey, NodeKey: regReq.NodeKey, DiscoKey: key.DiscoPublic{}, // DiscoKey not available in RegisterRequest diff --git a/hscontrol/state/tags.go b/hscontrol/state/tags.go new file mode 100644 index 00000000..c1dd3127 --- /dev/null +++ b/hscontrol/state/tags.go @@ -0,0 +1,107 @@ +package state + +import ( + "errors" + "fmt" + "slices" + "strings" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" +) + +var ( + // ErrNodeMarkedTaggedButHasNoTags is returned when a node is marked as tagged but has no tags. + ErrNodeMarkedTaggedButHasNoTags = errors.New("node marked as tagged but has no tags") + + // ErrNodeHasNeitherUserNorTags is returned when a node has neither a user nor tags. + ErrNodeHasNeitherUserNorTags = errors.New("node has neither user nor tags - must be owned by user or tagged") + + // ErrInvalidOrUnauthorizedTags is returned when tags are invalid or unauthorized. + ErrInvalidOrUnauthorizedTags = errors.New("invalid or unauthorized tags") +) + +// validateNodeOwnership ensures proper node ownership model. +// A node must be EITHER user-owned OR tagged (mutually exclusive by behavior). +// Tagged nodes CAN have a UserID for "created by" tracking, but the tag is the owner. +func validateNodeOwnership(node *types.Node) error { + isTagged := node.IsTagged() + + // Tagged nodes: Must have tags, UserID is optional (just "created by") + if isTagged { + if len(node.Tags) == 0 { + return fmt.Errorf("%w: %q", ErrNodeMarkedTaggedButHasNoTags, node.Hostname) + } + // UserID can be set (created by) or nil (orphaned), both valid for tagged nodes + return nil + } + + // User-owned nodes: Must have UserID, must NOT have tags + if node.UserID == nil { + return fmt.Errorf("%w: %q", ErrNodeHasNeitherUserNorTags, node.Hostname) + } + + return nil +} + +// validateAndNormalizeTags validates tags against policy and normalizes them. +// Returns validated and normalized tags, or an error if validation fails. +func (s *State) validateAndNormalizeTags(node *types.Node, requestedTags []string) ([]string, error) { + if len(requestedTags) == 0 { + return nil, nil + } + + var ( + validTags []string + invalidTags []string + ) + + for _, tag := range requestedTags { + // Validate format + if !strings.HasPrefix(tag, "tag:") { + invalidTags = append(invalidTags, tag) + continue + } + + // Validate against policy + nodeView := node.View() + if s.polMan.NodeCanHaveTag(nodeView, tag) { + validTags = append(validTags, tag) + } else { + invalidTags = append(invalidTags, tag) + } + } + + if len(invalidTags) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidOrUnauthorizedTags, invalidTags) + } + + // Normalize: sort and deduplicate + slices.Sort(validTags) + + return slices.Compact(validTags), nil +} + +// logTagOperation logs tag assignment operations for audit purposes. +func logTagOperation(existingNode types.NodeView, newTags []string) { + if existingNode.IsTagged() { + log.Info(). + Uint64("node.id", existingNode.ID().Uint64()). + Str("node.name", existingNode.Hostname()). + Strs("old.tags", existingNode.Tags().AsSlice()). + Strs("new.tags", newTags). + Msg("Updating tags on already-tagged node") + } else { + var userID uint + if existingNode.UserID().Valid() { + userID = existingNode.UserID().Get() + } + + log.Info(). + Uint64("node.id", existingNode.ID().Uint64()). + Str("node.name", existingNode.Hostname()). + Uint("created.by.user", userID). + Strs("new.tags", newTags). + Msg("Converting user-owned node to tagged node (irreversible)") + } +} diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 05eb8a35..f8264392 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -6,7 +6,6 @@ import ( "net/netip" "regexp" "slices" - "sort" "strconv" "strings" "time" @@ -28,6 +27,7 @@ var ( ErrHostnameTooLong = errors.New("hostname too long, cannot except 255 ASCII chars") ErrNodeHasNoGivenName = errors.New("node has no given name") ErrNodeUserHasNoName = errors.New("node user has no name") + ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node") invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") ) @@ -97,16 +97,21 @@ type Node struct { // GivenName is the name used in all DNS related // parts of headscale. GivenName string `gorm:"type:varchar(63);unique_index"` - UserID uint - User User `gorm:"constraint:OnDelete:CASCADE;"` + + // UserID is set for ALL nodes (tagged and user-owned) to track "created by". + // For tagged nodes, this is informational only - the tag is the owner. + // For user-owned nodes, this identifies the owner. + // Only nil for orphaned nodes (should not happen in normal operation). + UserID *uint + User *User `gorm:"constraint:OnDelete:CASCADE;"` RegisterMethod string - // ForcedTags are tags set by CLI/API. It is not considered - // the source of truth, but is one of the sources from - // which a tag might originate. - // ForcedTags are _always_ applied to the node. - ForcedTags []string `gorm:"column:forced_tags;serializer:json"` + // Tags is the definitive owner for tagged nodes. + // When non-empty, the node is "tagged" and tags define its identity. + // Empty for user-owned nodes. + // Tags cannot be removed once set (one-way transition). + Tags []string `gorm:"column:tags;serializer:json"` // When a node has been created with a PreAuthKey, we need to // prevent the preauthkey from being deleted before the node. @@ -196,55 +201,32 @@ func (node *Node) HasIP(i netip.Addr) bool { return false } -// IsTagged reports if a device is tagged -// and therefore should not be treated as a -// user owned device. -// Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys). +// IsTagged reports if a device is tagged and therefore should not be treated +// as a user-owned device. +// When a node has tags, the tags define its identity (not the user). func (node *Node) IsTagged() bool { - if len(node.ForcedTags) > 0 { - return true - } + return len(node.Tags) > 0 +} - if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 { - return true - } - - if node.Hostinfo == nil { - return false - } - - // TODO(kradalby): Figure out how tagging should work - // and hostinfo.requestedtags. - // Do this in other work. - - return false +// IsUserOwned returns true if node is owned by a user (not tagged). +// Tagged nodes may have a UserID for "created by" tracking, but the tag is the owner. +func (node *Node) IsUserOwned() bool { + return !node.IsTagged() } // HasTag reports if a node has a given tag. -// Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys). func (node *Node) HasTag(tag string) bool { - return slices.Contains(node.Tags(), tag) + return slices.Contains(node.Tags, tag) } -func (node *Node) Tags() []string { - var tags []string - - if node.AuthKey != nil { - tags = append(tags, node.AuthKey.Tags...) +// TypedUserID returns the UserID as a typed UserID type. +// Returns 0 if UserID is nil. +func (node *Node) TypedUserID() UserID { + if node.UserID == nil { + return 0 } - // TODO(kradalby): Figure out how tagging should work - // and hostinfo.requestedtags. - // Do this in other work. - // #2417 - - tags = append(tags, node.ForcedTags...) - sort.Strings(tags) - tags = slices.Compact(tags) - - return tags + return UserID(*node.UserID) } func (node *Node) RequestTags() []string { @@ -389,8 +371,8 @@ func (node *Node) Proto() *v1.Node { IpAddresses: node.IPsAsString(), Name: node.Hostname, GivenName: node.GivenName, - User: node.User.Proto(), - ForcedTags: node.ForcedTags, + User: nil, // Will be set below based on node type + ForcedTags: node.Tags, Online: node.IsOnline != nil && *node.IsOnline, // Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has @@ -404,6 +386,13 @@ func (node *Node) Proto() *v1.Node { CreatedAt: timestamppb.New(node.CreatedAt), } + // Set User field based on node ownership + // Note: User will be set to TaggedDevices in the gRPC layer (grpcv1.go) + // for proper MapResponse formatting + if node.User != nil { + nodeProto.User = node.User.Proto() + } + if node.AuthKey != nil { nodeProto.PreAuthKey = node.AuthKey.Proto() } @@ -701,8 +690,20 @@ func (nodes Nodes) DebugString() string { func (node Node) DebugString() string { var sb strings.Builder fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID) - fmt.Fprintf(&sb, "\tUser: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) - fmt.Fprintf(&sb, "\tTags: %v\n", node.Tags()) + + // Show ownership status + if node.IsTagged() { + fmt.Fprintf(&sb, "\tTagged: %v\n", node.Tags) + + if node.User != nil { + fmt.Fprintf(&sb, "\tCreated by: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) + } + } else if node.User != nil { + fmt.Fprintf(&sb, "\tUser-owned: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) + } else { + fmt.Fprintf(&sb, "\tOrphaned: no user or tags\n") + } + fmt.Fprintf(&sb, "\tIPs: %v\n", node.IPs()) fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes) fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) @@ -714,8 +715,7 @@ func (node Node) DebugString() string { } func (v NodeView) UserView() UserView { - u := v.User() - return u.View() + return v.User() } func (v NodeView) IPs() []netip.Addr { @@ -790,13 +790,6 @@ func (v NodeView) RequestTagsSlice() views.Slice[string] { return v.Hostinfo().RequestTags() } -func (v NodeView) Tags() []string { - if !v.Valid() { - return nil - } - return v.ж.Tags() -} - // IsTagged reports if a device is tagged // and therefore should not be treated as a // user owned device. @@ -893,6 +886,32 @@ func (v NodeView) HasTag(tag string) bool { return v.ж.HasTag(tag) } +// TypedUserID returns the UserID as a typed UserID type. +// Returns 0 if UserID is nil or node is invalid. +func (v NodeView) TypedUserID() UserID { + if !v.Valid() { + return 0 + } + + return v.ж.TypedUserID() +} + +// TailscaleUserID returns the user ID to use in Tailscale protocol. +// Tagged nodes always return TaggedDevices.ID, user-owned nodes return their actual UserID. +func (v NodeView) TailscaleUserID() tailcfg.UserID { + if !v.Valid() { + return 0 + } + + if v.IsTagged() { + //nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64 + return tailcfg.UserID(int64(TaggedDevices.ID)) + } + + //nolint:gosec // G115: UserID values are within int64 range + return tailcfg.UserID(int64(v.UserID().Get())) +} + // Prefixes returns the node IPs as netip.Prefix. func (v NodeView) Prefixes() []netip.Prefix { if !v.Valid() { diff --git a/hscontrol/types/node_tags_test.go b/hscontrol/types/node_tags_test.go new file mode 100644 index 00000000..72598b3c --- /dev/null +++ b/hscontrol/types/node_tags_test.go @@ -0,0 +1,295 @@ +package types + +import ( + "testing" + + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + "tailscale.com/types/ptr" +) + +// TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged. +func TestNodeIsTagged(t *testing.T) { + tests := []struct { + name string + node Node + want bool + }{ + { + name: "node with tags - is tagged", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + want: true, + }, + { + name: "node with single tag - is tagged", + node: Node{ + Tags: []string{"tag:web"}, + }, + want: true, + }, + { + name: "node with no tags - not tagged", + node: Node{ + Tags: []string{}, + }, + want: false, + }, + { + name: "node with nil tags - not tagged", + node: Node{ + Tags: nil, + }, + want: false, + }, + { + // Tags should be copied from AuthKey during registration, so a node + // with only AuthKey.Tags and no Tags would be invalid in practice. + // IsTagged() only checks node.Tags, not AuthKey.Tags. + name: "node registered with tagged authkey only - not tagged (tags should be copied)", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + want: false, + }, + { + name: "node with both tags and authkey tags - is tagged", + node: Node{ + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + want: true, + }, + { + name: "node with user and no tags - not tagged", + node: Node{ + UserID: ptr.To(uint(42)), + Tags: []string{}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.IsTagged() + assert.Equal(t, tt.want, got, "IsTagged() returned unexpected value") + }) + } +} + +// TestNodeViewIsTagged tests the IsTagged() method on NodeView. +func TestNodeViewIsTagged(t *testing.T) { + tests := []struct { + name string + node Node + want bool + }{ + { + name: "tagged node via Tags field", + node: Node{ + Tags: []string{"tag:server"}, + }, + want: true, + }, + { + // Tags should be copied from AuthKey during registration, so a node + // with only AuthKey.Tags and no Tags would be invalid in practice. + name: "node with only AuthKey tags - not tagged (tags should be copied)", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:web"}, + }, + }, + want: false, // IsTagged() only checks node.Tags + }, + { + name: "user-owned node", + node: Node{ + UserID: ptr.To(uint(1)), + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + view := tt.node.View() + got := view.IsTagged() + assert.Equal(t, tt.want, got, "NodeView.IsTagged() returned unexpected value") + }) + } +} + +// TestNodeHasTag tests the HasTag() method for checking specific tag membership. +func TestNodeHasTag(t *testing.T) { + tests := []struct { + name string + node Node + tag string + want bool + }{ + { + name: "node has the tag", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + tag: "tag:server", + want: true, + }, + { + name: "node does not have the tag", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + tag: "tag:web", + want: false, + }, + { + // Tags should be copied from AuthKey during registration + // HasTag() only checks node.Tags, not AuthKey.Tags + name: "node has tag only in authkey - returns false", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + tag: "tag:database", + want: false, + }, + { + // node.Tags is what matters, not AuthKey.Tags + name: "node has tag in Tags but not in AuthKey", + node: Node{ + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + tag: "tag:server", + want: true, + }, + { + name: "invalid tag format still returns false", + node: Node{ + Tags: []string{"tag:server"}, + }, + tag: "invalid-tag", + want: false, + }, + { + name: "empty tag returns false", + node: Node{ + Tags: []string{"tag:server"}, + }, + tag: "", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.HasTag(tt.tag) + assert.Equal(t, tt.want, got, "HasTag() returned unexpected value") + }) + } +} + +// TestNodeTagsImmutableAfterRegistration tests that tags can only be set during registration. +func TestNodeTagsImmutableAfterRegistration(t *testing.T) { + // Test that a node registered with tags keeps them + taggedNode := Node{ + ID: 1, + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:server"}, + }, + RegisterMethod: util.RegisterMethodAuthKey, + } + + // Node should be tagged + assert.True(t, taggedNode.IsTagged(), "Node registered with tags should be tagged") + + // Node should have the tag + has := taggedNode.HasTag("tag:server") + assert.True(t, has, "Node should have the tag it was registered with") + + // Test that a user-owned node is not tagged + userNode := Node{ + ID: 2, + UserID: ptr.To(uint(42)), + Tags: []string{}, + RegisterMethod: util.RegisterMethodOIDC, + } + + assert.False(t, userNode.IsTagged(), "User-owned node should not be tagged") +} + +// TestNodeOwnershipModel tests the tags-as-identity model. +func TestNodeOwnershipModel(t *testing.T) { + tests := []struct { + name string + node Node + wantIsTagged bool + description string + }{ + { + name: "tagged node has tags, UserID is informational", + node: Node{ + ID: 1, + UserID: ptr.To(uint(5)), // "created by" user 5 + Tags: []string{"tag:server"}, + }, + wantIsTagged: true, + description: "Tagged nodes may have UserID set for tracking, but ownership is defined by tags", + }, + { + name: "user-owned node has no tags", + node: Node{ + ID: 2, + UserID: ptr.To(uint(5)), + Tags: []string{}, + }, + wantIsTagged: false, + description: "User-owned nodes are owned by the user, not by tags", + }, + { + // Tags should be copied from AuthKey to Node during registration + // IsTagged() only checks node.Tags, not AuthKey.Tags + name: "node with only authkey tags - not tagged (tags should be copied)", + node: Node{ + ID: 3, + UserID: ptr.To(uint(5)), // "created by" user 5 + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + wantIsTagged: false, + description: "IsTagged() only checks node.Tags; AuthKey.Tags should be copied during registration", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.IsTagged() + assert.Equal(t, tt.wantIsTagged, got, tt.description) + }) + } +} + +// TestUserTypedID tests the TypedID() helper method. +func TestUserTypedID(t *testing.T) { + user := User{ + Model: gorm.Model{ID: 42}, + } + + typedID := user.TypedID() + assert.NotNil(t, typedID, "TypedID() should return non-nil pointer") + assert.Equal(t, UserID(42), *typedID, "TypedID() should return correct UserID value") +} diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index c992219e..9518833f 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -139,7 +139,7 @@ func TestNodeFQDN(t *testing.T) { name: "no-dnsconfig-with-username", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, @@ -150,7 +150,7 @@ func TestNodeFQDN(t *testing.T) { name: "all-set", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, @@ -160,7 +160,7 @@ func TestNodeFQDN(t *testing.T) { { name: "no-given-name", node: Node{ - User: User{ + User: &User{ Name: "user", }, }, @@ -179,7 +179,7 @@ func TestNodeFQDN(t *testing.T) { name: "no-dnsconfig", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 1081f451..2ce02f02 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -23,16 +23,19 @@ type PreAuthKey struct { Prefix string Hash []byte // bcrypt - UserID uint - User User `gorm:"constraint:OnDelete:SET NULL;"` + // For tagged keys: UserID tracks who created the key (informational) + // For user-owned keys: UserID tracks the node owner + // Can be nil for system-created tagged keys + UserID *uint + User *User `gorm:"constraint:OnDelete:SET NULL;"` + Reusable bool Ephemeral bool `gorm:"default:false"` Used bool `gorm:"default:false"` - // Tags are always applied to the node and is one of - // the sources of tags a node might have. They are copied - // from the PreAuthKey when the node logs in the first time, - // and ignored after. + // Tags to assign to nodes registered with this key. + // Tags are copied to the node during registration. + // If non-empty, this creates tagged nodes (not user-owned). Tags []string `gorm:"serializer:json"` CreatedAt *time.Time @@ -48,19 +51,23 @@ type PreAuthKeyNew struct { Tags []string Expiration *time.Time CreatedAt *time.Time - User User + User *User // Can be nil for system-created tagged keys } func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ Id: key.ID, Key: key.Key, - User: key.User.Proto(), + User: nil, // Will be set below if not nil Reusable: key.Reusable, Ephemeral: key.Ephemeral, AclTags: key.Tags, } + if key.User != nil { + protoKey.User = key.User.Proto() + } + if key.Expiration != nil { protoKey.Expiration = timestamppb.New(*key.Expiration) } @@ -74,7 +81,7 @@ func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey { func (key *PreAuthKey) Proto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ - User: key.User.Proto(), + User: nil, // Will be set below if not nil Id: key.ID, Ephemeral: key.Ephemeral, Reusable: key.Reusable, @@ -82,6 +89,10 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { AclTags: key.Tags, } + if key.User != nil { + protoKey.User = key.User.Proto() + } + // For new keys (with prefix/hash), show the prefix so users can identify the key // For legacy keys (with plaintext key), show the full key for backwards compatibility if key.Prefix != "" { @@ -139,3 +150,9 @@ func (pak *PreAuthKey) Validate() error { return nil } + +// IsTagged returns true if this PreAuthKey creates tagged nodes. +// When a PreAuthKey has tags, nodes registered with it will be tagged nodes. +func (pak *PreAuthKey) IsTagged() bool { + return len(pak.Tags) > 0 +} diff --git a/hscontrol/types/types_clone.go b/hscontrol/types/types_clone.go index 7699fb8f..4dfeedc2 100644 --- a/hscontrol/types/types_clone.go +++ b/hscontrol/types/types_clone.go @@ -54,7 +54,13 @@ func (src *Node) Clone() *Node { if dst.IPv6 != nil { dst.IPv6 = ptr.To(*src.IPv6) } - dst.ForcedTags = append(src.ForcedTags[:0:0], src.ForcedTags...) + if dst.UserID != nil { + dst.UserID = ptr.To(*src.UserID) + } + if dst.User != nil { + dst.User = ptr.To(*src.User) + } + dst.Tags = append(src.Tags[:0:0], src.Tags...) if dst.AuthKeyID != nil { dst.AuthKeyID = ptr.To(*src.AuthKeyID) } @@ -87,10 +93,10 @@ var _NodeCloneNeedsRegeneration = Node(struct { IPv6 *netip.Addr Hostname string GivenName string - UserID uint - User User + UserID *uint + User *User RegisterMethod string - ForcedTags []string + Tags []string AuthKeyID *uint64 AuthKey *PreAuthKey Expiry *time.Time @@ -111,6 +117,12 @@ func (src *PreAuthKey) Clone() *PreAuthKey { dst := new(PreAuthKey) *dst = *src dst.Hash = append(src.Hash[:0:0], src.Hash...) + if dst.UserID != nil { + dst.UserID = ptr.To(*src.UserID) + } + if dst.User != nil { + dst.User = ptr.To(*src.User) + } dst.Tags = append(src.Tags[:0:0], src.Tags...) if dst.CreatedAt != nil { dst.CreatedAt = ptr.To(*src.CreatedAt) @@ -127,8 +139,8 @@ var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct { Key string Prefix string Hash []byte - UserID uint - User User + UserID *uint + User *User Reusable bool Ephemeral bool Used bool diff --git a/hscontrol/types/types_view.go b/hscontrol/types/types_view.go index 076f5dbb..753e86d3 100644 --- a/hscontrol/types/types_view.go +++ b/hscontrol/types/types_view.go @@ -139,12 +139,13 @@ func (v NodeView) IPv4() views.ValuePointer[netip.Addr] { return views.ValuePo func (v NodeView) IPv6() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv6) } -func (v NodeView) Hostname() string { return v.ж.Hostname } -func (v NodeView) GivenName() string { return v.ж.GivenName } -func (v NodeView) UserID() uint { return v.ж.UserID } -func (v NodeView) User() User { return v.ж.User } +func (v NodeView) Hostname() string { return v.ж.Hostname } +func (v NodeView) GivenName() string { return v.ж.GivenName } +func (v NodeView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) } + +func (v NodeView) User() UserView { return v.ж.User.View() } func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod } -func (v NodeView) ForcedTags() views.Slice[string] { return views.SliceOf(v.ж.ForcedTags) } +func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } func (v NodeView) AuthKeyID() views.ValuePointer[uint64] { return views.ValuePointerOf(v.ж.AuthKeyID) } func (v NodeView) AuthKey() PreAuthKeyView { return v.ж.AuthKey.View() } @@ -179,10 +180,10 @@ var _NodeViewNeedsRegeneration = Node(struct { IPv6 *netip.Addr Hostname string GivenName string - UserID uint - User User + UserID *uint + User *User RegisterMethod string - ForcedTags []string + Tags []string AuthKeyID *uint64 AuthKey *PreAuthKey Expiry *time.Time @@ -239,16 +240,17 @@ func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error { return nil } -func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } -func (v PreAuthKeyView) Key() string { return v.ж.Key } -func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix } -func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) } -func (v PreAuthKeyView) UserID() uint { return v.ж.UserID } -func (v PreAuthKeyView) User() User { return v.ж.User } -func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } -func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } -func (v PreAuthKeyView) Used() bool { return v.ж.Used } -func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } +func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } +func (v PreAuthKeyView) Key() string { return v.ж.Key } +func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix } +func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) } +func (v PreAuthKeyView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) } + +func (v PreAuthKeyView) User() UserView { return v.ж.User.View() } +func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } +func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } +func (v PreAuthKeyView) Used() bool { return v.ж.Used } +func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] { return views.ValuePointerOf(v.ж.CreatedAt) } @@ -263,8 +265,8 @@ var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct { Key string Prefix string Hash []byte - UserID uint - User User + UserID *uint + User *User Reusable bool Ephemeral bool Used bool diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index b7cb1038..2e78386c 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -22,6 +22,21 @@ type UserID uint64 type Users []User +const ( + // TaggedDevicesUserID is the special user ID for tagged devices. + // This ID is used when rendering tagged nodes in the Tailscale protocol. + TaggedDevicesUserID = 2147455555 +) + +// TaggedDevices is a special user used in MapResponse for tagged nodes. +// Tagged nodes don't belong to a real user - the tag is their identity. +// This special user ID is used when rendering tagged nodes in the Tailscale protocol. +var TaggedDevices = User{ + Model: gorm.Model{ID: TaggedDevicesUserID}, + Name: "tagged-devices", + DisplayName: "Tagged Devices", +} + func (u Users) String() string { var sb strings.Builder sb.WriteString("[ ") @@ -77,6 +92,13 @@ func (u *User) StringID() string { return strconv.FormatUint(uint64(u.ID), 10) } +// TypedID returns a pointer to the user's ID as a UserID type. +// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)). +func (u *User) TypedID() *UserID { + uid := UserID(u.ID) + return &uid +} + // Username is the main way to get the username of a user, // it will return the email if it exists, the name if it exists, // the OIDCIdentifier if it exists, and the ID if nothing else exists. @@ -117,6 +139,13 @@ func (u UserView) TailscaleUser() tailcfg.User { return u.ж.TailscaleUser() } +// ID returns the user's ID. +// This is a custom accessor because gorm.Model.ID is embedded +// and the viewer generator doesn't always produce it. +func (u UserView) ID() uint { + return u.ж.ID +} + func (u *User) TailscaleLogin() tailcfg.Login { return tailcfg.Login{ ID: tailcfg.LoginID(u.ID), diff --git a/integration/cli_test.go b/integration/cli_test.go index fd2321b4..cf1badb2 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -4,6 +4,7 @@ import ( "cmp" "encoding/json" "fmt" + "slices" "strconv" "strings" "testing" @@ -18,7 +19,6 @@ import ( "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/slices" "tailscale.com/tailcfg" ) @@ -643,7 +643,9 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { status, err := client.Status() assert.NoError(ct, err) assert.Equal(ct, "Running", status.BackendState, "Expected node to be logged in, backend state: %s", status.BackendState) - assert.Equal(ct, "userid:2", status.Self.UserID.String(), "Expected node to be logged in as userid:2") + // With tags-as-identity model, tagged nodes show as TaggedDevices user (2147455555) + // The PreAuthKey was created with tags, so the node is tagged + assert.Equal(ct, "userid:2147455555", status.Self.UserID.String(), "Expected node to be logged in as tagged-devices user") }, 30*time.Second, 2*time.Second) assert.EventuallyWithT(t, func(ct *assert.CollectT) { @@ -652,7 +654,8 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { assert.NoError(ct, err) assert.Len(ct, listNodes, 2, "Should have 2 nodes after re-login") assert.Equal(ct, user1, listNodes[0].GetUser().GetName(), "First node should belong to user1") - assert.Equal(ct, user2, listNodes[1].GetUser().GetName(), "Second node should belong to user2") + // Second node is tagged (created with tagged PreAuthKey), so it shows as "tagged-devices" + assert.Equal(ct, "tagged-devices", listNodes[1].GetUser().GetName(), "Second node should be tagged-devices") }, 20*time.Second, 1*time.Second) } @@ -847,118 +850,455 @@ func TestNodeTagCommand(t *testing.T) { headscale, err := scenario.Headscale() require.NoError(t, err) - regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - } - nodes := make([]*v1.Node, len(regIDs)) + // Test 1: Verify that tags require authorization via ACL policy + // The tags-as-identity model allows conversion from user-owned to tagged, but only + // if the tag is authorized via tagOwners in the ACL policy. + regID := types.MustRegistrationID().String() + + _, err = headscale.Execute( + []string{ + "headscale", + "debug", + "create-node", + "--name", + "user-owned-node", + "--user", + "user1", + "--key", + regID, + "--output", + "json", + }, + ) assert.NoError(t, err) - for index, regID := range regIDs { - _, err := headscale.Execute( + var userOwnedNode v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, []string{ "headscale", - "debug", - "create-node", - "--name", - fmt.Sprintf("node-%d", index+1), + "nodes", "--user", "user1", + "register", "--key", regID, "--output", "json", }, - ) - assert.NoError(t, err) - - var node v1.Node - assert.EventuallyWithT(t, func(c *assert.CollectT) { - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "user1", - "register", - "--key", - regID, - "--output", - "json", - }, - &node, - ) - assert.NoError(c, err) - }, 10*time.Second, 200*time.Millisecond, "Waiting for node registration") - - nodes[index] = &node - } - assert.EventuallyWithT(t, func(ct *assert.CollectT) { - assert.Len(ct, nodes, len(regIDs), "Should have correct number of nodes after CLI operations") - }, 15*time.Second, 1*time.Second) - - var node v1.Node - assert.EventuallyWithT(t, func(c *assert.CollectT) { - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "tag", - "-i", "1", - "-t", "tag:test", - "--output", "json", - }, - &node, + &userOwnedNode, ) assert.NoError(c, err) - }, 10*time.Second, 200*time.Millisecond, "Waiting for node tag command") + }, 10*time.Second, 200*time.Millisecond, "Waiting for user-owned node registration") - assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) + // Verify node is user-owned (no tags) + assert.Empty(t, userOwnedNode.GetValidTags(), "User-owned node should not have tags") + assert.Empty(t, userOwnedNode.GetForcedTags(), "User-owned node should not have forced tags") + // Attempt to set tags on user-owned node should FAIL because there's no ACL policy + // authorizing the tag. The tags-as-identity model allows conversion from user-owned + // to tagged, but only if the tag is authorized via tagOwners in the ACL policy. _, err = headscale.Execute( []string{ "headscale", "nodes", "tag", - "-i", "2", - "-t", "wrong-tag", + "-i", strconv.FormatUint(userOwnedNode.GetId(), 10), + "-t", "tag:test", "--output", "json", }, ) - assert.ErrorContains(t, err, "tag must start with the string 'tag:'") + require.ErrorContains(t, err, "invalid or unauthorized tags", "Setting unauthorized tags should fail") + + // Test 2: Verify tag format validation + // Create a PreAuthKey with tags to create a tagged node + // Get the user ID from the node + userID := userOwnedNode.GetUser().GetId() + + var preAuthKey v1.PreAuthKey - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, len(regIDs)) assert.EventuallyWithT(t, func(c *assert.CollectT) { err = executeAndUnmarshal( headscale, []string{ "headscale", - "nodes", - "list", + "preauthkeys", + "--user", strconv.FormatUint(userID, 10), + "create", + "--reusable", + "--tags", "tag:integration-test", "--output", "json", }, - &resultMachines, + &preAuthKey, ) assert.NoError(c, err) - }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after tagging") - found := false - for _, node := range resultMachines { - if node.GetForcedTags() != nil { - for _, tag := range node.GetForcedTags() { - if tag == "tag:test" { - found = true - } - } + }, 10*time.Second, 200*time.Millisecond, "Creating PreAuthKey with tags") + + // Verify PreAuthKey has tags + assert.Contains(t, preAuthKey.GetAclTags(), "tag:integration-test", "PreAuthKey should have tags") + + // Test 3: Verify invalid tag format is rejected + _, err = headscale.Execute( + []string{ + "headscale", + "preauthkeys", + "--user", strconv.FormatUint(userID, 10), + "create", + "--tags", "wrong-tag", // Missing "tag:" prefix + "--output", "json", + }, + ) + assert.ErrorContains(t, err, "tag must start with the string 'tag:'", "Invalid tag format should be rejected") +} + +func TestTaggedNodeRegistration(t *testing.T) { + IntegrationSkip(t) + + // ACL policy that authorizes the tags used in tagged PreAuthKeys + // user1 and user2 can assign these tags when creating PreAuthKeys + policy := &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:server": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")}, + "tag:prod": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")}, + "tag:forbidden": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{policyv2.Wildcard}, + Destinations: []policyv2.AliasWithPorts{{Alias: policyv2.Wildcard, Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}}}, + }, + }, + } + + spec := ScenarioSpec{ + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tagged-reg"), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Get users (they were already created by ScenarioSpec) + users, err := headscale.ListUsers() + require.NoError(t, err) + require.Len(t, users, 2, "Should have 2 users") + + var user1, user2 *v1.User + + for _, u := range users { + if u.GetName() == "user1" { + user1 = u + } else if u.GetName() == "user2" { + user2 = u } } - assert.True( - t, - found, - "should find a node with the tag 'tag:test' in the list of nodes", + + require.NotNil(t, user1, "Should find user1") + require.NotNil(t, user2, "Should find user2") + + // Test 1: Create a PreAuthKey with tags + var taggedKey v1.PreAuthKey + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", strconv.FormatUint(user1.GetId(), 10), + "create", + "--reusable", + "--tags", "tag:server,tag:prod", + "--output", "json", + }, + &taggedKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Creating tagged PreAuthKey") + + // Verify PreAuthKey has both tags + assert.Contains(t, taggedKey.GetAclTags(), "tag:server", "PreAuthKey should have tag:server") + assert.Contains(t, taggedKey.GetAclTags(), "tag:prod", "PreAuthKey should have tag:prod") + assert.Len(t, taggedKey.GetAclTags(), 2, "PreAuthKey should have exactly 2 tags") + + // Test 2: Register a node using the tagged PreAuthKey + err = scenario.CreateTailscaleNodesInUser("user1", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0])) + require.NoError(t, err) + + err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Wait for the node to be registered + var registeredNode *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node") + + // Find the tagged node - it will have user "tagged-devices" per tags-as-identity model + for _, node := range nodes { + if node.GetUser().GetName() == "tagged-devices" && len(node.GetValidTags()) > 0 { + registeredNode = node + break + } + } + + assert.NotNil(c, registeredNode, "Should find a tagged node") + }, 30*time.Second, 500*time.Millisecond, "Waiting for tagged node registration") + + // Test 3: Verify the registered node has the tags from the PreAuthKey + assert.Contains(t, registeredNode.GetValidTags(), "tag:server", "Node should have tag:server") + assert.Contains(t, registeredNode.GetValidTags(), "tag:prod", "Node should have tag:prod") + assert.Len(t, registeredNode.GetValidTags(), 2, "Node should have exactly 2 tags") + + // Test 4: Verify the node shows as TaggedDevices user (tags-as-identity model) + // Tagged nodes always show as "tagged-devices" in API responses, even though + // internally UserID may be set for "created by" tracking + assert.Equal(t, "tagged-devices", registeredNode.GetUser().GetName(), "Tagged node should show as tagged-devices user") + + // Test 5: Verify the node is identified as tagged + assert.NotEmpty(t, registeredNode.GetValidTags(), "Tagged node should have tags") + + // Test 6: Verify tag modification on tagged nodes + // NOTE: Changing tags requires complex ACL authorization where the node's IP + // must be authorized for the new tags via tagOwners. For simplicity, we skip + // this test and instead verify that tags cannot be arbitrarily changed without + // proper ACL authorization. + // + // This is expected behavior - tag changes must be authorized by ACL policy. + _, err = headscale.Execute( + []string{ + "headscale", + "nodes", + "tag", + "-i", strconv.FormatUint(registeredNode.GetId(), 10), + "-t", "tag:unauthorized", + "--output", "json", + }, ) + // This SHOULD fail because tag:unauthorized is not in our ACL policy + require.ErrorContains(t, err, "invalid or unauthorized tags", "Unauthorized tag should be rejected") + + // Test 7: Create a user-owned node for comparison + var userOwnedKey v1.PreAuthKey + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", strconv.FormatUint(user2.GetId(), 10), + "create", + "--reusable", + "--output", "json", + }, + &userOwnedKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Creating user-owned PreAuthKey") + + // Verify this PreAuthKey has NO tags + assert.Empty(t, userOwnedKey.GetAclTags(), "User-owned PreAuthKey should have no tags") + + err = scenario.CreateTailscaleNodesInUser("user2", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0])) + require.NoError(t, err) + + err = scenario.RunTailscaleUp("user2", headscale.GetEndpoint(), userOwnedKey.GetKey()) + require.NoError(t, err) + + // Wait for the user-owned node to be registered + var userOwnedNode *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.GreaterOrEqual(c, len(nodes), 2, "Should have at least 2 nodes") + + // Find the node registered with user2 + for _, node := range nodes { + if node.GetUser().GetName() == "user2" { + userOwnedNode = node + break + } + } + + assert.NotNil(c, userOwnedNode, "Should find a node for user2") + }, 30*time.Second, 500*time.Millisecond, "Waiting for user-owned node registration") + + // Test 8: Verify user-owned node has NO tags + assert.Empty(t, userOwnedNode.GetValidTags(), "User-owned node should have no tags") + assert.NotZero(t, userOwnedNode.GetUser().GetId(), "User-owned node should have UserID") + + // Test 9: Verify attempting to set UNAUTHORIZED tags on user-owned node fails + // Note: Under tags-as-identity model, user-owned nodes CAN be converted to tagged nodes + // if the tags are authorized. We use an unauthorized tag to test rejection. + _, err = headscale.Execute( + []string{ + "headscale", + "nodes", + "tag", + "-i", strconv.FormatUint(userOwnedNode.GetId(), 10), + "-t", "tag:not-in-policy", + "--output", "json", + }, + ) + require.ErrorContains(t, err, "invalid or unauthorized tags", "Setting unauthorized tags should fail") + + // Test 10: Verify basic connectivity - wait for sync + err = scenario.WaitForTailscaleSync() + require.NoError(t, err, "Clients should be able to sync") +} + +// TestTagPersistenceAcrossRestart validates that tags persist across container +// restarts and that re-authentication doesn't re-apply tags from PreAuthKey. +// This is a regression test for issue #2830. +func TestTagPersistenceAcrossRestart(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("tag-persist")) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Get user + users, err := headscale.ListUsers() + require.NoError(t, err) + require.Len(t, users, 1) + user1 := users[0] + + // Create a reusable PreAuthKey with tags + var taggedKey v1.PreAuthKey + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", strconv.FormatUint(user1.GetId(), 10), + "create", + "--reusable", // Critical: key must be reusable for container restart + "--tags", "tag:server,tag:prod", + "--output", "json", + }, + &taggedKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Creating reusable tagged PreAuthKey") + + require.True(t, taggedKey.GetReusable(), "PreAuthKey must be reusable for restart scenario") + require.Contains(t, taggedKey.GetAclTags(), "tag:server") + require.Contains(t, taggedKey.GetAclTags(), "tag:prod") + + // Register initial node with tagged PreAuthKey + err = scenario.CreateTailscaleNodesInUser("user1", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0])) + require.NoError(t, err) + + err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Wait for node registration and get initial node state + var initialNode *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node") + + for _, node := range nodes { + if node.GetUser().GetId() == user1.GetId() || node.GetUser().GetName() == "tagged-devices" { + initialNode = node + break + } + } + + assert.NotNil(c, initialNode, "Should find the registered node") + }, 30*time.Second, 500*time.Millisecond, "Waiting for initial node registration") + + // Verify initial tags + require.Contains(t, initialNode.GetValidTags(), "tag:server", "Initial node should have tag:server") + require.Contains(t, initialNode.GetValidTags(), "tag:prod", "Initial node should have tag:prod") + require.Len(t, initialNode.GetValidTags(), 2, "Initial node should have exactly 2 tags") + + initialNodeID := initialNode.GetId() + t.Logf("Initial node registered with ID %d and tags %v", initialNodeID, initialNode.GetValidTags()) + + // Simulate container restart by shutting down and restarting Tailscale client + allClients, err := scenario.ListTailscaleClients() + require.NoError(t, err) + require.Len(t, allClients, 1, "Should have exactly 1 client") + + client := allClients[0] + + // Stop the client (simulates container stop) + err = client.Down() + require.NoError(t, err) + + // Wait a bit to ensure the client is fully stopped + time.Sleep(2 * time.Second) + + // Restart the client with the SAME PreAuthKey (container restart scenario) + // This simulates what happens when a Docker container restarts with a reusable PreAuthKey + err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Wait for re-authentication + var nodeAfterRestart *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + + for _, node := range nodes { + if node.GetId() == initialNodeID { + nodeAfterRestart = node + break + } + } + + assert.NotNil(c, nodeAfterRestart, "Should find the same node after restart") + }, 30*time.Second, 500*time.Millisecond, "Waiting for node re-authentication") + + // CRITICAL ASSERTION: Tags should NOT be re-applied from PreAuthKey + // Tags are only applied during INITIAL authentication, not re-authentication + // The node should keep its existing tags (which happen to be the same in this case) + assert.Contains(t, nodeAfterRestart.GetValidTags(), "tag:server", "Node should still have tag:server after restart") + assert.Contains(t, nodeAfterRestart.GetValidTags(), "tag:prod", "Node should still have tag:prod after restart") + assert.Len(t, nodeAfterRestart.GetValidTags(), 2, "Node should still have exactly 2 tags after restart") + + // Verify it's the SAME node (same ID), not a new registration + assert.Equal(t, initialNodeID, nodeAfterRestart.GetId(), "Should be the same node, not a new registration") + + // Verify node count hasn't increased (no duplicate nodes) + finalNodes, err := headscale.ListNodes() + require.NoError(t, err) + assert.Len(t, finalNodes, 1, "Should still have exactly 1 node (no duplicates from restart)") + + t.Logf("Container restart validation complete - node %d maintained tags across restart", initialNodeID) } func TestNodeAdvertiseTagCommand(t *testing.T) {