diff --git a/hscontrol/app.go b/hscontrol/app.go
index 74b5f3a4..e82af703 100644
--- a/hscontrol/app.go
+++ b/hscontrol/app.go
@@ -271,7 +271,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 			return
 		case <-expireTicker.C:
-			var expiredNodeChanges []change.ChangeSet
+			var expiredNodeChanges []change.Change
 			var changed bool

 			lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
@@ -305,7 +305,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 			}

 			h.state.SetDERPMap(derpMap)
-			h.Change(change.DERPSet)
+			h.Change(change.DERPMap())

 		case records, ok := <-extraRecordsUpdate:
 			if !ok {
@@ -313,7 +313,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 			}

 			h.cfg.TailcfgDNSConfig.ExtraRecords = records
-			h.Change(change.ExtraRecordsSet)
+			h.Change(change.ExtraRecords())
 		}
 	}
 }
@@ -988,7 +988,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
 // Change is used to send changes to nodes.
 // All change should be enqueued here and empty will be automatically
 // ignored.
-func (h *Headscale) Change(cs ...change.ChangeSet) {
+func (h *Headscale) Change(cs ...change.Change) {
 	h.mapBatcher.AddWork(cs...)
 }

diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go
index 0cccf5ca..9573d1ea 100644
--- a/hscontrol/grpcv1.go
+++ b/hscontrol/grpcv1.go
@@ -58,14 +58,9 @@ func (api headscaleV1APIServer) CreateUser(
 		return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
 	}

-	c := change.UserAdded(types.UserID(user.ID))
-
-	// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
-	if !policyChanged.Empty() {
-		c.Change = change.Policy
-	}
-
-	api.h.Change(c)
+	// CreateUser returns a policy change response if the user creation affected policy.
+	// This triggers a full policy re-evaluation for all connected nodes.
+	api.h.Change(policyChanged)

 	return &v1.CreateUserResponse{User: user.Proto()}, nil
 }
@@ -109,7 +104,8 @@ func (api headscaleV1APIServer) DeleteUser(
 		return nil, err
 	}

-	api.h.Change(change.UserRemoved(types.UserID(user.ID)))
+	// User deletion may affect policy, so trigger a full policy re-evaluation.
+ api.h.Change(change.UserRemoved()) return &v1.DeleteUserResponse{}, nil } diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index d0ce2e1f..0e0a9b25 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -13,18 +13,13 @@ import ( "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) -var ( - mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: "headscale", - Name: "mapresponse_generated_total", - Help: "total count of mapresponses generated by response type and change type", - }, []string{"response_type", "change_type"}) - - errNodeNotFoundInNodeStore = errors.New("node not found in NodeStore") -) +var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "headscale", + Name: "mapresponse_generated_total", + Help: "total count of mapresponses generated by response type", +}, []string{"response_type"}) type batcherFunc func(cfg *types.Config, state *state.State) Batcher @@ -36,8 +31,8 @@ type Batcher interface { RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool IsConnected(id types.NodeID) bool ConnectedMap() *xsync.Map[types.NodeID, bool] - AddWork(c ...change.ChangeSet) - MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) + AddWork(r ...change.Change) + MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) } @@ -51,7 +46,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB workCh: make(chan work, workers*200), nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), connected: xsync.NewMap[types.NodeID, *time.Time](), - pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](), + pendingChanges: xsync.NewMap[types.NodeID, []change.Change](), } } @@ -69,15 +64,21 @@ type nodeConnection interface { nodeID() types.NodeID version() tailcfg.CapabilityVersion send(data *tailcfg.MapResponse) error + // computePeerDiff returns peers that were previously sent but are no longer in the current list. + computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID) + // updateSentPeers updates the tracking of which peers have been sent to this node. + updateSentPeers(resp *tailcfg.MapResponse) } -// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet]. -func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) { - if c.Empty() { - return nil, nil +// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change]. 
+func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) { + nodeID := nc.nodeID() + version := nc.version() + + if r.IsEmpty() { + return nil, nil //nolint:nilnil // Empty response means nothing to send } - // Validate inputs before processing if nodeID == 0 { return nil, fmt.Errorf("invalid nodeID: %d", nodeID) } @@ -86,141 +87,58 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID) } + // Handle self-only responses + if r.IsSelfOnly() && r.TargetNode != nodeID { + return nil, nil //nolint:nilnil // No response needed for other nodes when self-only + } + var ( - mapResp *tailcfg.MapResponse - err error - responseType string + mapResp *tailcfg.MapResponse + err error ) - // Record metric when function exits - defer func() { - if err == nil && mapResp != nil && responseType != "" { - mapResponseGenerated.WithLabelValues(responseType, c.Change.String()).Inc() - } - }() + // Track metric using categorized type, not free-form reason + mapResponseGenerated.WithLabelValues(r.Type()).Inc() - switch c.Change { - case change.DERP: - responseType = "derp" - mapResp, err = mapper.derpMapResponse(nodeID) + // Check if this requires runtime peer visibility computation (e.g., policy changes) + if r.RequiresRuntimePeerComputation { + currentPeers := mapper.state.ListPeers(nodeID) - case change.NodeCameOnline, change.NodeWentOffline: - if c.IsSubnetRouter { - // TODO(kradalby): This can potentially be a peer update of the old and new subnet router. - responseType = "full" - mapResp, err = mapper.fullMapResponse(nodeID, version) - } else { - // Trust the change type for online/offline status to avoid race conditions - // between NodeStore updates and change processing - responseType = string(patchResponseDebug) - onlineStatus := c.Change == change.NodeCameOnline - - mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ - { - NodeID: c.NodeID.NodeID(), - Online: ptr.To(onlineStatus), - }, - }) + currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len()) + for _, peer := range currentPeers.All() { + currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID()) } - case change.NodeNewOrUpdate: - // If the node is the one being updated, we send a self update that preserves peer information - // to ensure the node sees changes to its own properties (e.g., hostname/DNS name changes) - // without losing its view of peer status during rapid reconnection cycles - if c.IsSelfUpdate(nodeID) { - responseType = "self" - mapResp, err = mapper.selfMapResponse(nodeID, version) - } else { - responseType = "change" - mapResp, err = mapper.peerChangeResponse(nodeID, version, c.NodeID) - } - - case change.NodeRemove: - responseType = "remove" - mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID) - - case change.NodeKeyExpiry: - // If the node is the one whose key is expiring, we send a "full" self update - // as nodes will ignore patch updates about themselves (?). - if c.IsSelfUpdate(nodeID) { - responseType = "self" - mapResp, err = mapper.selfMapResponse(nodeID, version) - // mapResp, err = mapper.fullMapResponse(nodeID, version) - } else { - responseType = "patch" - mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{ - { - NodeID: c.NodeID.NodeID(), - KeyExpiry: c.NodeExpiry, - }, - }) - } - - case change.NodeEndpoint, change.NodeDERP: - // Endpoint or DERP changes can be sent as lightweight patches. 
- // Query the NodeStore for the current peer state to construct the PeerChange. - // Even if only endpoint or only DERP changed, we include both in the patch - // since they're often updated together and it's minimal overhead. - responseType = "patch" - - peer, found := mapper.state.GetNodeByID(c.NodeID) - if !found { - return nil, fmt.Errorf("%w: %d", errNodeNotFoundInNodeStore, c.NodeID) - } - - peerChange := &tailcfg.PeerChange{ - NodeID: c.NodeID.NodeID(), - Endpoints: peer.Endpoints().AsSlice(), - DERPRegion: 0, // Will be set below if available - } - - // Extract DERP region from Hostinfo if available - if hi := peer.AsStruct().Hostinfo; hi != nil && hi.NetInfo != nil { - peerChange.DERPRegion = hi.NetInfo.PreferredDERP - } - - mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{peerChange}) - - default: - // The following will always hit this: - // change.Full, change.Policy - responseType = "full" - mapResp, err = mapper.fullMapResponse(nodeID, version) + removedPeers := nc.computePeerDiff(currentPeerIDs) + mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers) + } else { + mapResp, err = mapper.buildFromChange(nodeID, version, &r) } if err != nil { return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err) } - // TODO(kradalby): Is this necessary? - // Validate the generated map response - only check for nil response - // Note: mapResp.Node can be nil for peer updates, which is valid - if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove { - return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String()) - } - return mapResp, nil } -// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet]. -func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error { +// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change]. +func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error { if nc == nil { return errors.New("nodeConnection is nil") } nodeID := nc.nodeID() - log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received") + log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received") - var data *tailcfg.MapResponse - var err error - data, err = generateMapResponse(nodeID, nc.version(), mapper, c) + data, err := generateMapResponse(nc, mapper, r) if err != nil { return fmt.Errorf("generating map response for node %d: %w", nodeID, err) } if data == nil { - // No data to send is valid for some change types + // No data to send is valid for some response types return nil } @@ -230,6 +148,9 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err return fmt.Errorf("sending map response to node %d: %w", nodeID, err) } + // Update peer tracking after successful send + nc.updateSentPeers(data) + return nil } @@ -241,7 +162,7 @@ type workResult struct { // work represents a unit of work to be processed by workers. 
type work struct { - c change.ChangeSet + r change.Change nodeID types.NodeID resultCh chan<- workResult // optional channel for synchronous operations } diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index c90cdc32..28e53426 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -33,7 +33,7 @@ type LockFreeBatcher struct { done chan struct{} // Batching state - pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet] + pendingChanges *xsync.Map[types.NodeID, []change.Change] // Metrics totalNodes atomic.Int64 @@ -141,8 +141,8 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo } // AddWork queues a change to be processed by the batcher. -func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) { - b.addWork(c...) +func (b *LockFreeBatcher) AddWork(r ...change.Change) { + b.addWork(r...) } func (b *LockFreeBatcher) Start() { @@ -211,15 +211,19 @@ func (b *LockFreeBatcher) worker(workerID int) { var result workResult if nc, exists := b.nodes.Load(w.nodeID); exists { var err error - result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c) + + result.mapResponse, err = generateMapResponse(nc, b.mapper, w.r) result.err = err if result.err != nil { b.workErrors.Add(1) log.Error().Err(result.err). Int("worker.id", workerID). Uint64("node.id", w.nodeID.Uint64()). - Str("change", w.c.Change.String()). + Str("reason", w.r.Reason). Msg("failed to generate map response for synchronous work") + } else if result.mapResponse != nil { + // Update peer tracking for synchronous responses too + nc.updateSentPeers(result.mapResponse) } } else { result.err = fmt.Errorf("node %d not found", w.nodeID) @@ -247,13 +251,13 @@ func (b *LockFreeBatcher) worker(workerID int) { if nc, exists := b.nodes.Load(w.nodeID); exists { // Apply change to node - this will handle offline nodes gracefully // and queue work for when they reconnect - err := nc.change(w.c) + err := nc.change(w.r) if err != nil { b.workErrors.Add(1) log.Error().Err(err). Int("worker.id", workerID). - Uint64("node.id", w.c.NodeID.Uint64()). - Str("change", w.c.Change.String()). + Uint64("node.id", w.nodeID.Uint64()). + Str("reason", w.r.Reason). Msg("failed to apply change") } } @@ -264,8 +268,8 @@ func (b *LockFreeBatcher) worker(workerID int) { } } -func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) { - b.addToBatch(c...) +func (b *LockFreeBatcher) addWork(r ...change.Change) { + b.addToBatch(r...) } // queueWork safely queues work. @@ -281,38 +285,43 @@ func (b *LockFreeBatcher) queueWork(w work) { } } -// addToBatch adds a change to the pending batch. -func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) { - // Short circuit if any of the changes is a full update, which +// addToBatch adds a response to the pending batch. +func (b *LockFreeBatcher) addToBatch(responses ...change.Change) { + // Short circuit if any of the responses is a full update, which // means we can skip sending individual changes. 
- if change.HasFull(c) { + if change.HasFull(responses) { b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { - b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}}) + b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()}) return true }) - return - } - - all, self := change.SplitAllAndSelf(c) - - for _, changeSet := range self { - changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{}) - changes = append(changes, changeSet) - b.pendingChanges.Store(changeSet.NodeID, changes) return } - b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { - rel := change.RemoveUpdatesForSelf(nodeID, all) + broadcast, targeted := change.SplitTargetedAndBroadcast(responses) - changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{}) - changes = append(changes, rel...) - b.pendingChanges.Store(nodeID, changes) + // Handle targeted responses - send only to the specific node + for _, resp := range targeted { + changes, _ := b.pendingChanges.LoadOrStore(resp.TargetNode, []change.Change{}) + changes = append(changes, resp) + b.pendingChanges.Store(resp.TargetNode, changes) + } - return true - }) + // Handle broadcast responses - send to all nodes, filtering as needed + if len(broadcast) > 0 { + b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { + filtered := change.FilterForNode(nodeID, broadcast) + + if len(filtered) > 0 { + changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{}) + changes = append(changes, filtered...) + b.pendingChanges.Store(nodeID, changes) + } + + return true + }) + } } // processBatchedChanges processes all pending batched changes. @@ -322,14 +331,14 @@ func (b *LockFreeBatcher) processBatchedChanges() { } // Process all pending changes - b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool { - if len(changes) == 0 { + b.pendingChanges.Range(func(nodeID types.NodeID, responses []change.Change) bool { + if len(responses) == 0 { return true } - // Send all batched changes for this node - for _, c := range changes { - b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil}) + // Send all batched responses for this node + for _, r := range responses { + b.queueWork(work{r: r, nodeID: nodeID, resultCh: nil}) } // Clear the pending changes for this node @@ -432,11 +441,11 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { // MapResponseFromChange queues work to generate a map response and waits for the result. // This allows synchronous map generation using the same worker pool. -func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) { +func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error) { resultCh := make(chan workResult, 1) // Queue the work with a result channel using the safe queueing method - b.queueWork(work{c: c, nodeID: id, resultCh: resultCh}) + b.queueWork(work{r: r, nodeID: id, resultCh: resultCh}) // Wait for the result select { @@ -466,6 +475,12 @@ type multiChannelNodeConn struct { connections []*connectionEntry updateCount atomic.Int64 + + // lastSentPeers tracks which peers were last sent to this node. + // This enables computing diffs for policy changes instead of sending + // full peer lists (which clients interpret as "no change" when empty). + // Using xsync.Map for lock-free concurrent access. 
+ lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}] } // generateConnectionID generates a unique connection identifier. @@ -478,8 +493,9 @@ func generateConnectionID() string { // newMultiChannelNodeConn creates a new multi-channel node connection. func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn { return &multiChannelNodeConn{ - id: id, - mapper: mapper, + id: id, + mapper: mapper, + lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](), } } @@ -662,9 +678,59 @@ func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion { return mc.connections[0].version } +// updateSentPeers updates the tracked peer state based on a sent MapResponse. +// This must be called after successfully sending a response to keep track of +// what the client knows about, enabling accurate diffs for future updates. +func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) { + if resp == nil { + return + } + + // Full peer list replaces tracked state entirely + if resp.Peers != nil { + mc.lastSentPeers.Clear() + + for _, peer := range resp.Peers { + mc.lastSentPeers.Store(peer.ID, struct{}{}) + } + } + + // Incremental additions + for _, peer := range resp.PeersChanged { + mc.lastSentPeers.Store(peer.ID, struct{}{}) + } + + // Incremental removals + for _, id := range resp.PeersRemoved { + mc.lastSentPeers.Delete(id) + } +} + +// computePeerDiff compares the current peer list against what was last sent +// and returns the peers that were removed (in lastSentPeers but not in current). +func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID { + currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers)) + for _, id := range currentPeers { + currentSet[id] = struct{}{} + } + + var removed []tailcfg.NodeID + + // Find removed: in lastSentPeers but not in current + mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool { + if _, exists := currentSet[id]; !exists { + removed = append(removed, id) + } + + return true + }) + + return removed +} + // change applies a change to all active connections for the node. -func (mc *multiChannelNodeConn) change(c change.ChangeSet) error { - return handleNodeChange(mc, mc.mapper, c) +func (mc *multiChannelNodeConn) change(r change.Change) error { + return handleNodeChange(mc, mc.mapper, r) } // DebugNodeInfo contains debug information about a node's connections. 
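A note on the sent-peer tracking added above: `updateSentPeers` and `computePeerDiff` together let a policy change be expressed as explicit `PeersRemoved` entries rather than an empty peer list. The following is a minimal, single-threaded sketch of that bookkeeping under simplified stand-in types (the real code uses `types.NodeID`, `tailcfg.MapResponse`, and a lock-free `xsync.Map`; this is an illustration, not the actual implementation):

```go
package main

import "fmt"

// Simplified stand-ins for types.NodeID and the parts of
// tailcfg.MapResponse that the tracker inspects.
type NodeID int64

type MapResponse struct {
	Peers        []NodeID // full list: replaces all tracked state
	PeersChanged []NodeID // incremental additions/updates
	PeersRemoved []NodeID // incremental removals
}

// sentPeerTracker mirrors the lastSentPeers bookkeeping on
// multiChannelNodeConn, minus the concurrent map.
type sentPeerTracker struct {
	lastSent map[NodeID]struct{}
}

func newTracker() *sentPeerTracker {
	return &sentPeerTracker{lastSent: make(map[NodeID]struct{})}
}

// update corresponds to updateSentPeers: a non-nil Peers slice resets
// the tracked state, then PeersChanged/PeersRemoved apply incrementally.
func (t *sentPeerTracker) update(resp *MapResponse) {
	if resp == nil {
		return
	}
	if resp.Peers != nil {
		t.lastSent = make(map[NodeID]struct{}, len(resp.Peers))
		for _, id := range resp.Peers {
			t.lastSent[id] = struct{}{}
		}
	}
	for _, id := range resp.PeersChanged {
		t.lastSent[id] = struct{}{}
	}
	for _, id := range resp.PeersRemoved {
		delete(t.lastSent, id)
	}
}

// diff corresponds to computePeerDiff: anything previously sent but
// absent from the current visibility set must be explicitly removed,
// because clients read an empty Peers slice as "no change", not
// "no peers".
func (t *sentPeerTracker) diff(current []NodeID) (removed []NodeID) {
	currentSet := make(map[NodeID]struct{}, len(current))
	for _, id := range current {
		currentSet[id] = struct{}{}
	}
	for id := range t.lastSent {
		if _, ok := currentSet[id]; !ok {
			removed = append(removed, id)
		}
	}
	return removed
}

func main() {
	tr := newTracker()
	tr.update(&MapResponse{Peers: []NodeID{1, 2, 3}})
	// A policy change shrinks visibility to {1, 3}; peer 2 must go
	// out as PeersRemoved.
	fmt.Println(tr.diff([]NodeID{1, 3})) // [2]
}
```

The key property is that removals are computed against what the client was actually sent, not against server-side state alone.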
diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index f43ea5a1..f67cb517 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -59,7 +59,7 @@ func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapRespo return fmt.Errorf("%w: %d", errNodeNotFoundAfterAdd, id) } - t.AddWork(change.NodeOnline(node)) + t.AddWork(change.NodeOnlineFor(node)) return nil } @@ -76,7 +76,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe // Do this BEFORE removing from batcher so the change can be processed node, ok := t.state.GetNodeByID(id) if ok { - t.AddWork(change.NodeOffline(node)) + t.AddWork(change.NodeOfflineFor(node)) } // Finally remove from the real batcher @@ -557,9 +557,9 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { }, time.Second, 10*time.Millisecond, "waiting for node connection") // Generate work and wait for updates to be processed - batcher.AddWork(change.FullSet) - batcher.AddWork(change.PolicySet) - batcher.AddWork(change.DERPSet) + batcher.AddWork(change.FullUpdate()) + batcher.AddWork(change.PolicyChange()) + batcher.AddWork(change.DERPMap()) // Wait for updates to be processed (at least 1 update received) assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -661,7 +661,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) // Issue full update after each join to ensure connectivity - batcher.AddWork(change.FullSet) + batcher.AddWork(change.FullUpdate()) // Yield to scheduler for large node counts to prevent overwhelming the work queue if tc.nodeCount > 100 && i%50 == 49 { @@ -832,7 +832,7 @@ func TestBatcherBasicOperations(t *testing.T) { } // Test work processing with DERP change - batcher.AddWork(change.DERPChange()) + batcher.AddWork(change.DERPMap()) // Wait for update and validate content select { @@ -959,31 +959,31 @@ func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout ti // }{ // { // name: "DERP change", -// changeSet: change.DERPSet, +// changeSet: change.DERPMapResponse(), // expectData: true, // description: "DERP changes should generate map updates", // }, // { // name: "Node key expiry", -// changeSet: change.KeyExpiry(testNodes[1].n.ID), +// changeSet: change.KeyExpiryFor(testNodes[1].n.ID), // expectData: true, // description: "Node key expiry with real node data", // }, // { // name: "Node new registration", -// changeSet: change.NodeAdded(testNodes[1].n.ID), +// changeSet: change.NodeAddedResponse(testNodes[1].n.ID), // expectData: true, // description: "New node registration with real data", // }, // { // name: "Full update", -// changeSet: change.FullSet, +// changeSet: change.FullUpdateResponse(), // expectData: true, // description: "Full updates with real node data", // }, // { // name: "Policy change", -// changeSet: change.PolicySet, +// changeSet: change.PolicyChangeResponse(), // expectData: true, // description: "Policy updates with real node data", // }, @@ -1057,13 +1057,13 @@ func TestBatcherWorkQueueBatching(t *testing.T) { var receivedUpdates []*tailcfg.MapResponse // Add multiple changes rapidly to test batching - batcher.AddWork(change.DERPSet) + batcher.AddWork(change.DERPMap()) // Use a valid expiry time for testing since test nodes don't have expiry set testExpiry := time.Now().Add(24 * time.Hour) - batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID, testExpiry)) - batcher.AddWork(change.DERPSet) + 
batcher.AddWork(change.KeyExpiryFor(testNodes[1].n.ID, testExpiry)) + batcher.AddWork(change.DERPMap()) batcher.AddWork(change.NodeAdded(testNodes[1].n.ID)) - batcher.AddWork(change.DERPSet) + batcher.AddWork(change.DERPMap()) // Collect updates with timeout updateCount := 0 @@ -1087,8 +1087,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) { t.Logf("Update %d: nil update", updateCount) } case <-timeout: - // Expected: 5 changes should generate 6 updates (no batching in current implementation) - expectedUpdates := 6 + // Expected: 5 explicit changes + 1 initial from AddNode + 1 NodeOnline from wrapper = 7 updates + expectedUpdates := 7 t.Logf("Received %d updates from %d changes (expected %d)", updateCount, 5, expectedUpdates) @@ -1160,7 +1160,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { // Add real work during connection chaos if i%10 == 0 { - batcher.AddWork(change.DERPSet) + batcher.AddWork(change.DERPMap()) } // Rapid second connection - should replace ch1 @@ -1260,7 +1260,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // Add node and immediately queue real work batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) - batcher.AddWork(change.DERPSet) + batcher.AddWork(change.DERPMap()) // Consumer goroutine to validate data and detect channel issues go func() { @@ -1302,7 +1302,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { if i%10 == 0 { // Use a valid expiry time for testing since test nodes don't have expiry set testExpiry := time.Now().Add(24 * time.Hour) - batcher.AddWork(change.KeyExpiry(testNode.n.ID, testExpiry)) + batcher.AddWork(change.KeyExpiryFor(testNode.n.ID, testExpiry)) } // Rapid removal creates race between worker and removal @@ -1510,12 +1510,12 @@ func TestBatcherConcurrentClients(t *testing.T) { // Generate various types of work during racing if i%3 == 0 { // DERP changes - batcher.AddWork(change.DERPSet) + batcher.AddWork(change.DERPMap()) } if i%5 == 0 { // Full updates using real node data - batcher.AddWork(change.FullSet) + batcher.AddWork(change.FullUpdate()) } if i%7 == 0 && len(allNodes) > 0 { @@ -1523,7 +1523,7 @@ func TestBatcherConcurrentClients(t *testing.T) { node := allNodes[i%len(allNodes)] // Use a valid expiry time for testing since test nodes don't have expiry set testExpiry := time.Now().Add(24 * time.Hour) - batcher.AddWork(change.KeyExpiry(node.n.ID, testExpiry)) + batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry)) } // Yield to allow some batching @@ -1778,7 +1778,7 @@ func XTestBatcherScalability(t *testing.T) { } }, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect") - batcher.AddWork(change.FullSet) + batcher.AddWork(change.FullUpdate()) // Wait for initial update to propagate assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -1887,7 +1887,7 @@ func XTestBatcherScalability(t *testing.T) { // Add work to create load if index%5 == 0 { - batcher.AddWork(change.FullSet) + batcher.AddWork(change.FullUpdate()) } }( node.n.ID, @@ -1914,11 +1914,11 @@ func XTestBatcherScalability(t *testing.T) { // Generate different types of work to ensure updates are sent switch index % 4 { case 0: - batcher.AddWork(change.FullSet) + batcher.AddWork(change.FullUpdate()) case 1: - batcher.AddWork(change.PolicySet) + batcher.AddWork(change.PolicyChange()) case 2: - batcher.AddWork(change.DERPSet) + batcher.AddWork(change.DERPMap()) default: // Pick a random node and generate a node change if len(testNodes) > 0 { @@ -1927,7 +1927,7 @@ func XTestBatcherScalability(t *testing.T) { 
 					change.NodeAdded(testNodes[nodeIdx].n.ID),
 				)
 			} else {
-				batcher.AddWork(change.FullSet)
+				batcher.AddWork(change.FullUpdate())
 			}
 		}(i)
@@ -2165,7 +2165,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {

 	// Send a full update - this should generate full peer lists
 	t.Logf("Sending FullSet update...")
-	batcher.AddWork(change.FullSet)
+	batcher.AddWork(change.FullUpdate())

 	// Wait for FullSet work items to be processed
 	t.Logf("Waiting for FullSet to be processed...")
@@ -2261,7 +2261,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
 	t.Logf("Total updates received across all nodes: %d", totalUpdates)

 	if !foundFullUpdate {
-		t.Errorf("CRITICAL: No FULL updates received despite sending change.FullSet!")
+		t.Errorf("CRITICAL: No FULL updates received despite sending change.FullUpdate()!")
 		t.Errorf(
 			"This confirms the bug - FullSet updates are not generating full peer responses",
 		)
@@ -2372,7 +2372,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
 	t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")

 	// Send a change that should reach all nodes
-	batcher.AddWork(change.DERPChange())
+	batcher.AddWork(change.DERPMap())

 	receivedCount := 0
 	timeout := time.After(500 * time.Millisecond)
@@ -2508,11 +2508,7 @@ func TestBatcherMultiConnection(t *testing.T) {
 	clearChannel(node2.ch)

 	// Send a change notification from node2 (so node1 should receive it on all connections)
-	testChangeSet := change.ChangeSet{
-		NodeID:         node2.n.ID,
-		Change:         change.NodeNewOrUpdate,
-		SelfUpdateOnly: false,
-	}
+	testChangeSet := change.NodeAdded(node2.n.ID)

 	batcher.AddWork(testChangeSet)
@@ -2591,11 +2587,7 @@ func TestBatcherMultiConnection(t *testing.T) {
 	clearChannel(node1.ch)
 	clearChannel(thirdChannel)

-	testChangeSet2 := change.ChangeSet{
-		NodeID:         node2.n.ID,
-		Change:         change.NodeNewOrUpdate,
-		SelfUpdateOnly: false,
-	}
+	testChangeSet2 := change.NodeAdded(node2.n.ID)

 	batcher.AddWork(testChangeSet2)
@@ -2629,7 +2621,11 @@ func TestBatcherMultiConnection(t *testing.T) {
 			remaining1Received, remaining3Received)
 	}

-	// Verify second channel no longer receives updates (should be closed/removed)
+	// Drain secondChannel of any messages received before removal
+	// (the test wrapper sends NodeOffline before removal, which may have reached this channel)
+	clearChannel(secondChannel)
+
+	// Verify second channel no longer receives new updates after being removed
 	select {
 	case <-secondChannel:
 		t.Errorf("Removed connection still received update - this should not happen")
diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go
index 69aca6d8..c666ff24 100644
--- a/hscontrol/mapper/builder.go
+++ b/hscontrol/mapper/builder.go
@@ -29,10 +29,8 @@ type debugType string
 const (
 	fullResponseDebug   debugType = "full"
 	selfResponseDebug   debugType = "self"
-	patchResponseDebug  debugType = "patch"
-	removeResponseDebug debugType = "remove"
 	changeResponseDebug debugType = "change"
-	derpResponseDebug   debugType = "derp"
+	policyResponseDebug debugType = "policy"
 )

 // NewMapResponseBuilder creates a new builder with basic fields set.
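Before the mapper changes below, it helps to pin down the batching contract used by `addToBatch` in batcher_lockfree.go above. `change.SplitTargetedAndBroadcast` and `change.FilterForNode` are not defined in this diff, so the sketch below shows one plausible contract inferred from their call sites: targeted changes go only to `TargetNode`, broadcasts fan out to every connected node minus anything filtered for the recipient. The field names and filtering rule here are assumptions, not the actual `change` package:

```go
package main

import "fmt"

// NodeID and Change are reduced stand-ins; the real change.Change has
// more fields. TargetNode/OriginNode are the two visible in this diff.
type NodeID int64

type Change struct {
	TargetNode NodeID // non-zero: deliver only to this node
	OriginNode NodeID // node that caused the change, if any
	Reason     string
}

// splitTargetedAndBroadcast mimics change.SplitTargetedAndBroadcast:
// changes addressed to one node are separated from those that fan out
// to every connected node. Return order matches the diff's call site.
func splitTargetedAndBroadcast(cs []Change) (broadcast, targeted []Change) {
	for _, c := range cs {
		if c.TargetNode != 0 {
			targeted = append(targeted, c)
		} else {
			broadcast = append(broadcast, c)
		}
	}
	return broadcast, targeted
}

// filterForNode mimics change.FilterForNode: drop broadcast changes a
// recipient should not see, e.g. updates that originated from itself.
func filterForNode(id NodeID, cs []Change) []Change {
	var out []Change
	for _, c := range cs {
		if c.OriginNode == id {
			continue
		}
		out = append(out, c)
	}
	return out
}

func main() {
	changes := []Change{
		{Reason: "derp"},                      // broadcast to everyone
		{TargetNode: 7, Reason: "self"},       // only node 7
		{OriginNode: 3, Reason: "node-added"}, // broadcast, skipped for node 3
	}

	broadcast, targeted := splitTargetedAndBroadcast(changes)
	fmt.Println(len(targeted), len(broadcast))    // 1 2
	fmt.Println(len(filterForNode(3, broadcast))) // 1
}
```

Whatever the exact semantics, the design point stands: routing decisions move out of the batcher's per-node loop and into the `change` package, so the batcher only stores and dispatches.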
diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index c2951c45..bb2c4d6d 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -14,6 +14,7 @@ import ( "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/rs/zerolog/log" "tailscale.com/envknob" "tailscale.com/tailcfg" @@ -179,52 +180,108 @@ func (m *mapper) selfMapResponse( return ma, err } -func (m *mapper) derpMapResponse( - nodeID types.NodeID, -) (*tailcfg.MapResponse, error) { - return m.NewMapResponseBuilder(nodeID). - WithDebugType(derpResponseDebug). - WithDERPMap(). - Build() -} - -// PeerChangedPatchResponse creates a patch MapResponse with -// incoming update from a state change. -func (m *mapper) peerChangedPatchResponse( - nodeID types.NodeID, - changed []*tailcfg.PeerChange, -) (*tailcfg.MapResponse, error) { - return m.NewMapResponseBuilder(nodeID). - WithDebugType(patchResponseDebug). - WithPeerChangedPatch(changed). - Build() -} - -// peerChangeResponse returns a MapResponse with changed or added nodes. -func (m *mapper) peerChangeResponse( +// policyChangeResponse creates a MapResponse for policy changes. +// It sends: +// - PeersRemoved for peers that are no longer visible after the policy change +// - PeersChanged for remaining peers (their AllowedIPs may have changed due to policy) +// - Updated PacketFilters +// - Updated SSHPolicy (SSH rules may reference users/groups that changed) +// This avoids the issue where an empty Peers slice is interpreted by Tailscale +// clients as "no change" rather than "no peers". +func (m *mapper) policyChangeResponse( nodeID types.NodeID, capVer tailcfg.CapabilityVersion, - changedNodeID types.NodeID, + removedPeers []tailcfg.NodeID, + currentPeers views.Slice[types.NodeView], ) (*tailcfg.MapResponse, error) { - peers := m.state.ListPeers(nodeID, changedNodeID) - - return m.NewMapResponseBuilder(nodeID). - WithDebugType(changeResponseDebug). + builder := m.NewMapResponseBuilder(nodeID). + WithDebugType(policyResponseDebug). WithCapabilityVersion(capVer). - WithUserProfiles(peers). - WithPeerChanges(peers). - Build() + WithPacketFilters(). + WithSSHPolicy() + + if len(removedPeers) > 0 { + // Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved + removedIDs := make([]types.NodeID, len(removedPeers)) + for i, id := range removedPeers { + removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent + } + + builder.WithPeersRemoved(removedIDs...) + } + + // Send remaining peers in PeersChanged - their AllowedIPs may have + // changed due to the policy update (e.g., different routes allowed). + if currentPeers.Len() > 0 { + builder.WithPeerChanges(currentPeers) + } + + return builder.Build() } -// peerRemovedResponse creates a MapResponse indicating that a peer has been removed. -func (m *mapper) peerRemovedResponse( +// buildFromChange builds a MapResponse from a change.Change specification. +// This provides fine-grained control over what gets included in the response. +func (m *mapper) buildFromChange( nodeID types.NodeID, - removedNodeID types.NodeID, + capVer tailcfg.CapabilityVersion, + resp *change.Change, ) (*tailcfg.MapResponse, error) { - return m.NewMapResponseBuilder(nodeID). - WithDebugType(removeResponseDebug). - WithPeersRemoved(removedNodeID). 
- Build() + if resp.IsEmpty() { + return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error + } + + // If this is a self-update (the changed node is the receiving node), + // send a self-update response to ensure the node sees its own changes. + if resp.OriginNode != 0 && resp.OriginNode == nodeID { + return m.selfMapResponse(nodeID, capVer) + } + + builder := m.NewMapResponseBuilder(nodeID). + WithCapabilityVersion(capVer). + WithDebugType(changeResponseDebug) + + if resp.IncludeSelf { + builder.WithSelfNode() + } + + if resp.IncludeDERPMap { + builder.WithDERPMap() + } + + if resp.IncludeDNS { + builder.WithDNSConfig() + } + + if resp.IncludeDomain { + builder.WithDomain() + } + + if resp.IncludePolicy { + builder.WithPacketFilters() + builder.WithSSHPolicy() + } + + if resp.SendAllPeers { + peers := m.state.ListPeers(nodeID) + builder.WithUserProfiles(peers) + builder.WithPeers(peers) + } else { + if len(resp.PeersChanged) > 0 { + peers := m.state.ListPeers(nodeID, resp.PeersChanged...) + builder.WithUserProfiles(peers) + builder.WithPeerChanges(peers) + } + + if len(resp.PeersRemoved) > 0 { + builder.WithPeersRemoved(resp.PeersRemoved...) + } + } + + if len(resp.PeerPatches) > 0 { + builder.WithPeerChangedPatch(resp.PeerPatches) + } + + return builder.Build() } func writeDebugMapResponse( diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 371c307d..5059d07c 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -473,14 +473,16 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( claims *types.OIDCClaims, -) (*types.User, change.ChangeSet, error) { - var user *types.User - var err error - var newUser bool - var c change.ChangeSet +) (*types.User, change.Change, error) { + var ( + user *types.User + err error + newUser bool + c change.Change + ) user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { - return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err) + return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err) } // if the user is still not found, create a new empty user. @@ -496,7 +498,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( if newUser { user, c, err = a.h.state.CreateUser(*user) if err != nil { - return nil, change.EmptySet, fmt.Errorf("creating user: %w", err) + return nil, change.Change{}, fmt.Errorf("creating user: %w", err) } } else { _, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error { @@ -504,7 +506,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return nil }) if err != nil { - return nil, change.EmptySet, fmt.Errorf("updating user: %w", err) + return nil, change.Change{}, fmt.Errorf("updating user: %w", err) } } @@ -545,7 +547,7 @@ func (a *AuthProviderOIDC) handleRegistration( // Send both changes. Empty changes are ignored by Change(). a.h.Change(nodeChange, routesChange) - return !nodeChange.Empty(), nil + return !nodeChange.IsEmpty(), nil } func renderOIDCCallbackTemplate( diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 9dbe8374..43ce155d 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -57,6 +57,15 @@ var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode") // ErrNodeNotFound is returned when a node cannot be found by its ID. 
var ErrNodeNotFound = errors.New("node not found") +// ErrInvalidNodeView is returned when an invalid node view is provided. +var ErrInvalidNodeView = errors.New("invalid node view provided") + +// ErrNodeNotInNodeStore is returned when a node no longer exists in the NodeStore. +var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore") + +// ErrNodeNameNotUnique is returned when a node name is not unique. +var ErrNodeNameNotUnique = errors.New("node name is not unique") + // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { @@ -243,7 +252,7 @@ func (s *State) DERPMap() tailcfg.DERPMapView { // ReloadPolicy reloads the access control policy and triggers auto-approval if changed. // Returns true if the policy changed. -func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { +func (s *State) ReloadPolicy() ([]change.Change, error) { pol, err := policyBytes(s.db, s.cfg) if err != nil { return nil, fmt.Errorf("loading policy: %w", err) @@ -260,7 +269,7 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { // propagate correctly when switching between policy types. s.nodeStore.RebuildPeerMaps() - cs := []change.ChangeSet{change.PolicyChange()} + cs := []change.Change{change.PolicyChange()} // Always call autoApproveNodes during policy reload, regardless of whether // the policy content has changed. This ensures that routes are re-evaluated @@ -289,16 +298,16 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) { // CreateUser creates a new user and updates the policy manager. // Returns the created user, change set, and any error. -func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) { +func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) { if err := s.db.DB.Save(&user).Error; err != nil { - return nil, change.EmptySet, fmt.Errorf("creating user: %w", err) + return nil, change.Change{}, fmt.Errorf("creating user: %w", err) } // Check if policy manager needs updating c, err := s.updatePolicyManagerUsers() if err != nil { // Log the error but don't fail the user creation - return &user, change.EmptySet, fmt.Errorf("failed to update policy manager after user creation: %w", err) + return &user, change.Change{}, fmt.Errorf("failed to update policy manager after user creation: %w", err) } // Even if the policy manager doesn't detect a filter change, SSH policies @@ -306,7 +315,7 @@ func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, erro // nodes, we should send a policy change to ensure they get updated SSH policies. // TODO(kradalby): detect this, or rebuild all SSH policies so we can determine // this upstream. - if c.Empty() { + if c.IsEmpty() { c = change.PolicyChange() } @@ -317,7 +326,7 @@ func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, erro // UpdateUser modifies an existing user using the provided update function within a transaction. // Returns the updated user, change set, and any error. 
-func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) { +func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.Change, error) { user, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.User, error) { user, err := hsdb.GetUserByID(tx, userID) if err != nil { @@ -337,13 +346,13 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return user, nil }) if err != nil { - return nil, change.EmptySet, err + return nil, change.Change{}, err } // Check if policy manager needs updating c, err := s.updatePolicyManagerUsers() if err != nil { - return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err) + return user, change.Change{}, fmt.Errorf("failed to update policy manager after user update: %w", err) } // TODO(kradalby): We might want to update nodestore with the user data @@ -358,7 +367,7 @@ func (s *State) DeleteUser(userID types.UserID) error { } // RenameUser changes a user's name. The new name must be unique. -func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.ChangeSet, error) { +func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.Change, error) { return s.UpdateUser(userID, func(user *types.User) error { user.Name = newName return nil @@ -395,9 +404,9 @@ func (s *State) ListAllUsers() ([]types.User, error) { // NodeStore and the database. It verifies the node still exists in NodeStore to prevent // race conditions where a node might be deleted between UpdateNode returning and // persistNodeToDB being called. -func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.ChangeSet, error) { +func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Change, error) { if !node.Valid() { - return types.NodeView{}, change.EmptySet, fmt.Errorf("invalid node view provided") + return types.NodeView{}, change.Change{}, ErrInvalidNodeView } // Verify the node still exists in NodeStore before persisting to database. @@ -411,7 +420,8 @@ func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Cha Str("node.name", node.Hostname()). Bool("is_ephemeral", node.IsEphemeral()). 
Msg("Node no longer exists in NodeStore, skipping database persist to prevent race condition") - return types.NodeView{}, change.EmptySet, fmt.Errorf("node %d no longer exists in NodeStore, skipping database persist", node.ID()) + + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, node.ID()) } nodePtr := node.AsStruct() @@ -421,23 +431,23 @@ func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Cha // See: https://github.com/juanfont/headscale/issues/2862 err := s.db.DB.Omit("expiry").Updates(nodePtr).Error if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err) + return types.NodeView{}, change.Change{}, fmt.Errorf("saving node: %w", err) } // Check if policy manager needs updating c, err := s.updatePolicyManagerNodes() if err != nil { - return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err) + return nodePtr.View(), change.Change{}, fmt.Errorf("failed to update policy manager after node save: %w", err) } - if c.Empty() { + if c.IsEmpty() { c = change.NodeAdded(node.ID()) } return node, c, nil } -func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) { +func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.Change, error) { // Update NodeStore first nodePtr := node.AsStruct() @@ -449,12 +459,12 @@ func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, // DeleteNode permanently removes a node and cleans up associated resources. // Returns whether policies changed and any error. This operation is irreversible. -func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) { +func (s *State) DeleteNode(node types.NodeView) (change.Change, error) { s.nodeStore.DeleteNode(node.ID()) err := s.db.DeleteNode(node.AsStruct()) if err != nil { - return change.EmptySet, err + return change.Change{}, err } s.ipAlloc.FreeIPs(node.IPs()) @@ -464,10 +474,10 @@ func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) { // Check if policy manager needs updating after node deletion policyChange, err := s.updatePolicyManagerNodes() if err != nil { - return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err) + return change.Change{}, fmt.Errorf("failed to update policy manager after node deletion: %w", err) } - if !policyChange.Empty() { + if !policyChange.IsEmpty() { c = policyChange } @@ -475,7 +485,7 @@ func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) { } // Connect marks a node as connected and updates its primary routes in the state. -func (s *State) Connect(id types.NodeID) []change.ChangeSet { +func (s *State) Connect(id types.NodeID) []change.Change { // CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, // the NodeStore already reflects the correct online status for full map generation. @@ -488,7 +498,7 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet { return nil } - c := []change.ChangeSet{change.NodeOnline(node)} + c := []change.Change{change.NodeOnlineFor(node)} log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected") @@ -505,7 +515,7 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet { } // Disconnect marks a node as disconnected and updates its primary routes in the state. 
-func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) { +func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { @@ -527,14 +537,15 @@ func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) { // Log error but don't fail the disconnection - NodeStore is already updated // and we need to send change notifications to peers log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Failed to update last seen in database") - c = change.EmptySet + + c = change.Change{} } // The node is disconnecting so make sure that none of the routes it // announced are served to any nodes. routeChange := s.primaryRoutes.SetRoutes(id) - cs := []change.ChangeSet{change.NodeOffline(node), c} + cs := []change.Change{change.NodeOfflineFor(node), c} // If we have a policy change or route change, return that as it's more comprehensive // Otherwise, return the NodeOffline change to ensure nodes are notified @@ -637,7 +648,7 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { } // SetNodeExpiry updates the expiration time for a node. -func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) { +func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.Change, error) { // Update NodeStore before database to ensure consistency. The NodeStore update is // blocking and will be the source of truth for the batcher. The database update must // make the exact same change. If the database update fails, the NodeStore change will @@ -649,7 +660,7 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node }) if !ok { - return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) } return s.persistNodeToDB(n) @@ -658,16 +669,16 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node // SetNodeTags assigns tags to a node, making it a "tagged node". // Once a node is tagged, it cannot be un-tagged (only tags can be changed). // The UserID is preserved as "created by" information. 
-func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) { +func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.Change, error) { // CANNOT REMOVE ALL TAGS if len(tags) == 0 { - return types.NodeView{}, change.EmptySet, types.ErrCannotRemoveAllTags + return types.NodeView{}, change.Change{}, types.ErrCannotRemoveAllTags } // Get node for validation existingNode, exists := s.nodeStore.GetNode(nodeID) if !exists { - return types.NodeView{}, change.EmptySet, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID) } // Validate tags: must have correct format and exist in policy @@ -685,7 +696,7 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, } if len(invalidTags) > 0 { - return types.NodeView{}, change.EmptySet, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, invalidTags) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, invalidTags) } slices.Sort(validatedTags) @@ -703,14 +714,14 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, }) if !ok { - return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) } return s.persistNodeToDB(n) } // SetApprovedRoutes sets the network routes that a node is approved to advertise. -func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) { +func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.Change, error) { // TODO(kradalby): In principle we should call the AutoApprove logic here // because even if the CLI removes an auto-approved route, it will be added // back automatically. @@ -719,13 +730,13 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t }) if !ok { - return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) } // Persist the node changes to the database nodeView, c, err := s.persistNodeToDB(n) if err != nil { - return types.NodeView{}, change.EmptySet, err + return types.NodeView{}, change.Change{}, err } // Update primary routes table based on SubnetRoutes (intersection of announced and approved). @@ -743,9 +754,9 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t } // RenameNode changes the display name of a node. 
-func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) { +func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.Change, error) { if err := util.ValidateHostname(newName); err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err) + return types.NodeView{}, change.Change{}, fmt.Errorf("renaming node: %w", err) } // Check name uniqueness against NodeStore @@ -753,7 +764,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, for i := 0; i < allNodes.Len(); i++ { node := allNodes.At(i) if node.ID() != nodeID && node.AsStruct().GivenName == newName { - return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName) } } @@ -765,7 +776,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, }) if !ok { - return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) } return s.persistNodeToDB(n) @@ -810,12 +821,12 @@ func (s *State) BackfillNodeIPs() ([]string, error) { // ExpireExpiredNodes finds and processes expired nodes since the last check. // Returns next check time, state update with expired nodes, and whether any were found. -func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) { +func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Change, bool) { // Why capture start time: We need to ensure we don't miss nodes that expire // while this function is running by using a consistent timestamp for the next check started := time.Now() - var updates []change.ChangeSet + var updates []change.Change for _, node := range s.nodeStore.ListNodes().All() { if !node.Valid() { @@ -825,7 +836,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha // Why check After(lastCheck): We only want to notify about nodes that // expired since the last check to avoid duplicate notifications if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) { - updates = append(updates, change.KeyExpiry(node.ID(), node.Expiry().Get())) + updates = append(updates, change.KeyExpiryFor(node.ID(), node.Expiry().Get())) } } @@ -868,7 +879,7 @@ func (s *State) SetPolicy(pol []byte) (bool, error) { // AutoApproveRoutes checks if a node's routes should be auto-approved. // AutoApproveRoutes checks if any routes should be auto-approved for a node and updates them. -func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) { +func (s *State) AutoApproveRoutes(nv types.NodeView) (change.Change, error) { approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) if changed { log.Debug(). @@ -889,7 +900,7 @@ func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) { Err(err). 
Msg("Failed to persist auto-approved routes") - return change.EmptySet, err + return change.Change{}, err } log.Info().Uint64("node.id", nv.ID().Uint64()).Str("node.name", nv.Hostname()).Strs("routes.approved", util.PrefixesToString(approved)).Msg("Routes approved") @@ -897,7 +908,7 @@ func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) { return c, nil } - return change.EmptySet, nil + return change.Change{}, nil } // GetPolicy retrieves the current policy from the database. @@ -911,14 +922,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) { } // SetNodeRoutes sets the primary routes for a node. -func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet { +func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.Change { if s.primaryRoutes.SetRoutes(nodeID, routes...) { // Route changes affect packet filters for all nodes, so trigger a policy change // to ensure filters are regenerated across the entire network return change.PolicyChange() } - return change.EmptySet + return change.Change{} } // GetNodePrimaryRoutes returns the primary routes for a node. @@ -1232,17 +1243,17 @@ func (s *State) HandleNodeFromAuthPath( userID types.UserID, expiry *time.Time, registrationMethod string, -) (types.NodeView, change.ChangeSet, error) { +) (types.NodeView, change.Change, error) { // Get the registration entry from cache regEntry, ok := s.GetRegistrationCacheEntry(registrationID) if !ok { - return types.NodeView{}, change.EmptySet, hsdb.ErrNodeNotFoundRegistrationCache + return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache } // Get the user user, err := s.db.GetUserByID(userID) if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err) + return types.NodeView{}, change.Change{}, fmt.Errorf("failed to find user: %w", err) } // Ensure we have a valid hostname from the registration cache entry @@ -1306,7 +1317,7 @@ func (s *State) HandleNodeFromAuthPath( }) if !ok { - return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID()) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID()) } _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { @@ -1318,7 +1329,7 @@ func (s *State) HandleNodeFromAuthPath( return nil, nil }) if err != nil { - return types.NodeView{}, change.EmptySet, err + return types.NodeView{}, change.Change{}, err } log.Trace(). 
@@ -1376,7 +1387,7 @@ func (s *State) HandleNodeFromAuthPath( ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}), }) if err != nil { - return types.NodeView{}, change.EmptySet, err + return types.NodeView{}, change.Change{}, err } } @@ -1397,8 +1408,8 @@ func (s *State) HandleNodeFromAuthPath( return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err) } - var c change.ChangeSet - if !usersChange.Empty() || !nodesChange.Empty() { + var c change.Change + if !usersChange.IsEmpty() || !nodesChange.IsEmpty() { c = change.PolicyChange() } else { c = change.NodeAdded(finalNode.ID()) @@ -1411,10 +1422,10 @@ func (s *State) HandleNodeFromAuthPath( func (s *State) HandleNodeFromPreAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, -) (types.NodeView, change.ChangeSet, error) { +) (types.NodeView, change.Change, error) { pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey) if err != nil { - return types.NodeView{}, change.EmptySet, err + return types.NodeView{}, change.Change{}, err } // Check if node exists with same machine key before validating the key. @@ -1461,7 +1472,7 @@ func (s *State) HandleNodeFromPreAuthKey( // New node or NodeKey rotation: require valid auth key. err = pak.Validate() if err != nil { - return types.NodeView{}, change.EmptySet, err + return types.NodeView{}, change.Change{}, err } } @@ -1535,7 +1546,7 @@ func (s *State) HandleNodeFromPreAuthKey( }) if !ok { - return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID()) + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID()) } _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { @@ -1555,7 +1566,7 @@ func (s *State) HandleNodeFromPreAuthKey( return nil, nil }) if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err) + return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err) } log.Trace(). @@ -1607,7 +1618,7 @@ func (s *State) HandleNodeFromPreAuthKey( ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}), }) if err != nil { - return types.NodeView{}, change.EmptySet, fmt.Errorf("creating new node: %w", err) + return types.NodeView{}, change.Change{}, fmt.Errorf("creating new node: %w", err) } } @@ -1622,8 +1633,8 @@ func (s *State) HandleNodeFromPreAuthKey( return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err) } - var c change.ChangeSet - if !usersChange.Empty() || !nodesChange.Empty() { + var c change.Change + if !usersChange.IsEmpty() || !nodesChange.IsEmpty() { c = change.PolicyChange() } else { c = change.NodeAdded(finalNode.ID()) @@ -1638,17 +1649,17 @@ func (s *State) HandleNodeFromPreAuthKey( // have the list already available so it could go much quicker. Alternatively // the policy manager could have a remove or add list for users. // updatePolicyManagerUsers refreshes the policy manager with current user data. 
-func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { +func (s *State) updatePolicyManagerUsers() (change.Change, error) { users, err := s.ListAllUsers() if err != nil { - return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err) + return change.Change{}, fmt.Errorf("listing users for policy update: %w", err) } log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected") changed, err := s.polMan.SetUsers(users) if err != nil { - return change.EmptySet, fmt.Errorf("updating policy manager users: %w", err) + return change.Change{}, fmt.Errorf("updating policy manager users: %w", err) } log.Debug().Caller().Bool("policy.changed", changed).Msg("Policy manager user update completed because SetUsers operation finished") @@ -1657,7 +1668,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { return change.PolicyChange(), nil } - return change.EmptySet, nil + return change.Change{}, nil } // updatePolicyManagerNodes updates the policy manager with current nodes. @@ -1666,19 +1677,22 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) { // have the list already available so it could go much quicker. Alternatively // the policy manager could have a remove or add list for nodes. // updatePolicyManagerNodes refreshes the policy manager with current node data. -func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) { +func (s *State) updatePolicyManagerNodes() (change.Change, error) { nodes := s.ListNodes() changed, err := s.polMan.SetNodes(nodes) if err != nil { - return change.EmptySet, fmt.Errorf("updating policy manager nodes: %w", err) + return change.Change{}, fmt.Errorf("updating policy manager nodes: %w", err) } if changed { + // Rebuild peer maps because policy-affecting node changes (tags, user, IPs) + // affect ACL visibility. Without this, cached peer relationships use stale data. + s.nodeStore.RebuildPeerMaps() return change.PolicyChange(), nil } - return change.EmptySet, nil + return change.Change{}, nil } // PingDB checks if the database connection is healthy. @@ -1692,14 +1706,16 @@ func (s *State) PingDB(ctx context.Context) error { // TODO(kradalby): This is kind of messy, maybe this is another +1 // for an event bus. See example comments here. // autoApproveNodes automatically approves nodes based on policy rules. -func (s *State) autoApproveNodes() ([]change.ChangeSet, error) { +func (s *State) autoApproveNodes() ([]change.Change, error) { nodes := s.ListNodes() // Approve routes concurrently, this should make it likely // that the writes end in the same batch in the nodestore write. - var errg errgroup.Group - var cs []change.ChangeSet - var mu sync.Mutex + var ( + errg errgroup.Group + cs []change.Change + mu sync.Mutex + ) for _, nv := range nodes.All() { errg.Go(func() error { approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) @@ -1740,7 +1756,7 @@ func (s *State) autoApproveNodes() ([]change.ChangeSet, error) { // - node.PeerChangeFromMapRequest // - node.ApplyPeerChange // - logTracePeerChange in poll.go. -func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.ChangeSet, error) { +func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.Change, error) { log.Trace(). Caller(). Uint64("node.id", id.Uint64()). 
@@ -1853,7 +1869,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest }) if !ok { - return change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", id) + return change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, id) } if routeChange { @@ -1865,80 +1881,67 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest // SetApprovedRoutes will update both database and PrimaryRoutes table _, c, err := s.SetApprovedRoutes(id, autoApprovedRoutes) if err != nil { - return change.EmptySet, fmt.Errorf("persisting auto-approved routes: %w", err) + return change.Change{}, fmt.Errorf("persisting auto-approved routes: %w", err) } // If SetApprovedRoutes resulted in a policy change, return it - if !c.Empty() { + if !c.IsEmpty() { return c, nil } } // Continue with the rest of the processing using the updated node - nodeRouteChange := change.EmptySet - - // Handle route changes after NodeStore update - // We need to update node routes if either: - // 1. The approved routes changed (routeChange is true), OR - // 2. The announced routes changed (even if approved routes stayed the same) - // This is because SubnetRoutes is the intersection of announced AND approved routes. - needsRouteUpdate := false - var routesChangedButNotApproved bool - if hostinfoChanged && needsRouteApproval && !routeChange { - if hi := req.Hostinfo; hi != nil { - routesChangedButNotApproved = true - } - } - - if routesChangedButNotApproved { - needsRouteUpdate = true - log.Debug(). - Caller(). - Uint64("node.id", id.Uint64()). - Msg("updating routes because announced routes changed but approved routes did not") - } - - if needsRouteUpdate { - // SetNodeRoutes sets the active/distributed routes, so we must use AllApprovedRoutes() - // which returns only the intersection of announced AND approved routes. - // Using AnnouncedRoutes() would bypass the security model and auto-approve everything. - log.Debug(). - Caller(). - Uint64("node.id", id.Uint64()). - Strs("announcedRoutes", util.PrefixesToString(updatedNode.AnnouncedRoutes())). - Strs("approvedRoutes", util.PrefixesToString(updatedNode.ApprovedRoutes().AsSlice())). - Strs("allApprovedRoutes", util.PrefixesToString(updatedNode.AllApprovedRoutes())). - Msg("updating node routes for distribution") - nodeRouteChange = s.SetNodeRoutes(id, updatedNode.AllApprovedRoutes()...) - } + // Handle route changes after NodeStore update. + // Update routes if announced routes changed (even if approved routes stayed the same) + // because SubnetRoutes is the intersection of announced AND approved routes. + nodeRouteChange := s.maybeUpdateNodeRoutes(id, updatedNode, hostinfoChanged, needsRouteApproval, routeChange, req.Hostinfo) _, policyChange, err := s.persistNodeToDB(updatedNode) if err != nil { - return change.EmptySet, fmt.Errorf("saving to database: %w", err) + return change.Change{}, fmt.Errorf("saving to database: %w", err) } if policyChange.IsFull() { return policyChange, nil } - if !nodeRouteChange.Empty() { + + if !nodeRouteChange.IsEmpty() { return nodeRouteChange, nil } // Determine the most specific change type based on what actually changed. // This allows us to send lightweight patch updates instead of full map responses. + return buildMapRequestChangeResponse(id, updatedNode, hostinfoChanged, endpointChanged, derpChanged) +} + +// buildMapRequestChangeResponse determines the appropriate response type for a MapRequest update. 
+// Hostinfo changes require a full update, while endpoint/DERP changes can use lightweight patches. +func buildMapRequestChangeResponse( + id types.NodeID, + node types.NodeView, + hostinfoChanged, endpointChanged, derpChanged bool, +) (change.Change, error) { // Hostinfo changes require NodeAdded (full update) as they may affect many fields. if hostinfoChanged { return change.NodeAdded(id), nil } // Return specific change types for endpoint and/or DERP updates. - // The batcher will query NodeStore for current state and include both in PeerChange if both changed. - // Prioritize endpoint changes as they're more common and important for connectivity. - if endpointChanged { - return change.EndpointUpdate(id), nil - } + if endpointChanged || derpChanged { + patch := &tailcfg.PeerChange{NodeID: id.NodeID()} - if derpChanged { - return change.DERPUpdate(id), nil + if endpointChanged { + patch.Endpoints = node.Endpoints().AsSlice() + } + + if derpChanged { + if hi := node.Hostinfo(); hi.Valid() { + if ni := hi.NetInfo(); ni.Valid() { + patch.DERPRegion = ni.PreferredDERP() + } + } + } + + return change.EndpointOrDERPUpdate(id, patch), nil } return change.NodeAdded(id), nil @@ -1983,3 +1986,34 @@ func peerChangeEmpty(peerChange tailcfg.PeerChange) bool { peerChange.LastSeen == nil && peerChange.KeyExpiry == nil } + +// maybeUpdateNodeRoutes updates node routes if announced routes changed but approved routes didn't. +// This is needed because SubnetRoutes is the intersection of announced AND approved routes. +func (s *State) maybeUpdateNodeRoutes( + id types.NodeID, + node types.NodeView, + hostinfoChanged, needsRouteApproval, routeChange bool, + hostinfo *tailcfg.Hostinfo, +) change.Change { + // Only update if announced routes changed without approval change + if !hostinfoChanged || !needsRouteApproval || routeChange || hostinfo == nil { + return change.Change{} + } + + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Msg("updating routes because announced routes changed but approved routes did not") + + // SetNodeRoutes sets the active/distributed routes using AllApprovedRoutes() + // which returns only the intersection of announced AND approved routes. + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Strs("announcedRoutes", util.PrefixesToString(node.AnnouncedRoutes())). + Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())). + Strs("allApprovedRoutes", util.PrefixesToString(node.AllApprovedRoutes())). + Msg("updating node routes for distribution") + + return s.SetNodeRoutes(id, node.AllApprovedRoutes()...) +} diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 4c02c1f5..307ef690 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -1,241 +1,445 @@ -//go:generate go tool stringer -type=Change package change import ( - "errors" + "slices" "time" "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" ) -type ( - NodeID = types.NodeID - UserID = types.UserID -) +// Change declares what should be included in a MapResponse. +// The mapper uses this to build the response without guessing. +type Change struct { + // Reason is a human-readable description for logging/debugging. + Reason string -type Change int + // TargetNode, if set, means this response should only be sent to this node. + TargetNode types.NodeID -const ( - ChangeUnknown Change = 0 + // OriginNode is the node that triggered this change. + // Used for self-update detection and filtering. 
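+	// For example, NodeAdded and KeyExpiryFor set OriginNode to the node
+	// that changed, so the mapper can recognize the update as a self
+	// update when building that node's response.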
+ OriginNode types.NodeID - // Deprecated: Use specific change instead - // Full is a legacy change to ensure places where we - // have not yet determined the specific update, can send. - Full Change = 9 + // Content flags - what to include in the MapResponse. + IncludeSelf bool + IncludeDERPMap bool + IncludeDNS bool + IncludeDomain bool + IncludePolicy bool // PacketFilters and SSHPolicy - always sent together - // Server changes. - Policy Change = 11 - DERP Change = 12 - ExtraRecords Change = 13 + // Peer changes. + PeersChanged []types.NodeID + PeersRemoved []types.NodeID + PeerPatches []*tailcfg.PeerChange + SendAllPeers bool - // Node changes. - NodeCameOnline Change = 21 - NodeWentOffline Change = 22 - NodeRemove Change = 23 - NodeKeyExpiry Change = 24 - NodeNewOrUpdate Change = 25 - NodeEndpoint Change = 26 - NodeDERP Change = 27 + // RequiresRuntimePeerComputation indicates that peer visibility + // must be computed at runtime per-node. Used for policy changes + // where each node may have different peer visibility. + RequiresRuntimePeerComputation bool +} - // User changes. - UserNewOrUpdate Change = 51 - UserRemove Change = 52 -) +// boolFieldNames returns all boolean field names for exhaustive testing. +// When adding a new boolean field to Change, add it here. +// Tests use reflection to verify this matches the struct. +func (r Change) boolFieldNames() []string { + return []string{ + "IncludeSelf", + "IncludeDERPMap", + "IncludeDNS", + "IncludeDomain", + "IncludePolicy", + "SendAllPeers", + "RequiresRuntimePeerComputation", + } +} -// AlsoSelf reports whether this change should also be sent to the node itself. -func (c Change) AlsoSelf() bool { - switch c { - case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate: - return true +func (r Change) Merge(other Change) Change { + merged := r + + merged.IncludeSelf = r.IncludeSelf || other.IncludeSelf + merged.IncludeDERPMap = r.IncludeDERPMap || other.IncludeDERPMap + merged.IncludeDNS = r.IncludeDNS || other.IncludeDNS + merged.IncludeDomain = r.IncludeDomain || other.IncludeDomain + merged.IncludePolicy = r.IncludePolicy || other.IncludePolicy + merged.SendAllPeers = r.SendAllPeers || other.SendAllPeers + merged.RequiresRuntimePeerComputation = r.RequiresRuntimePeerComputation || other.RequiresRuntimePeerComputation + + merged.PeersChanged = uniqueNodeIDs(append(r.PeersChanged, other.PeersChanged...)) + merged.PeersRemoved = uniqueNodeIDs(append(r.PeersRemoved, other.PeersRemoved...)) + merged.PeerPatches = append(r.PeerPatches, other.PeerPatches...) + + if r.Reason != "" && other.Reason != "" && r.Reason != other.Reason { + merged.Reason = r.Reason + "; " + other.Reason + } else if other.Reason != "" { + merged.Reason = other.Reason } - return false + return merged } -type ChangeSet struct { - Change Change - - // SelfUpdateOnly indicates that this change should only be sent - // to the node itself, and not to other nodes. - // This is used for changes that are not relevant to other nodes. - // NodeID must be set if this is true. - SelfUpdateOnly bool - - // NodeID if set, is the ID of the node that is being changed. - // It must be set if this is a node change. - NodeID types.NodeID - - // UserID if set, is the ID of the user that is being changed. - // It must be set if this is a user change. - UserID types.UserID - - // IsSubnetRouter indicates whether the node is a subnet router. - IsSubnetRouter bool - - // NodeExpiry is set if the change is NodeKeyExpiry. 
-	NodeExpiry *time.Time
-}
-
-func (c *ChangeSet) Validate() error {
-	if c.Change >= NodeCameOnline || c.Change <= NodeNewOrUpdate {
-		if c.NodeID == 0 {
-			return errors.New("ChangeSet.NodeID must be set for node updates")
-		}
+func (r Change) IsEmpty() bool {
+	if r.IncludeSelf || r.IncludeDERPMap || r.IncludeDNS ||
+		r.IncludeDomain || r.IncludePolicy || r.SendAllPeers {
+		return false
 	}
 
-	if c.Change >= UserNewOrUpdate || c.Change <= UserRemove {
-		if c.UserID == 0 {
-			return errors.New("ChangeSet.UserID must be set for user updates")
-		}
+	if r.RequiresRuntimePeerComputation {
+		return false
 	}
 
-	return nil
+	return len(r.PeersChanged) == 0 &&
+		len(r.PeersRemoved) == 0 &&
+		len(r.PeerPatches) == 0
 }
 
-// Empty reports whether the ChangeSet is empty, meaning it does not
-// represent any change.
-func (c ChangeSet) Empty() bool {
-	return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
+func (r Change) IsSelfOnly() bool {
+	if r.TargetNode == 0 || !r.IncludeSelf {
+		return false
+	}
+
+	if r.SendAllPeers || len(r.PeersChanged) > 0 || len(r.PeersRemoved) > 0 || len(r.PeerPatches) > 0 {
+		return false
+	}
+
+	return true
 }
 
-// IsFull reports whether the ChangeSet represents a full update.
-func (c ChangeSet) IsFull() bool {
-	return c.Change == Full || c.Change == Policy
+// IsTargetedToNode reports whether this response should only be sent to TargetNode.
+func (r Change) IsTargetedToNode() bool {
+	return r.TargetNode != 0
 }
 
-func HasFull(cs []ChangeSet) bool {
-	for _, c := range cs {
-		if c.IsFull() {
+// IsFull reports whether this is a full update response.
+func (r Change) IsFull() bool {
+	return r.SendAllPeers && r.IncludeSelf && r.IncludeDERPMap &&
+		r.IncludeDNS && r.IncludeDomain && r.IncludePolicy
+}
+
+// Type returns a categorized type string for metrics.
+// This provides a bounded set of values suitable for Prometheus labels,
+// unlike Reason which is free-form text for logging.
+func (r Change) Type() string {
+	if r.IsFull() {
+		return "full"
+	}
+
+	if r.IsSelfOnly() {
+		return "self"
+	}
+
+	if r.RequiresRuntimePeerComputation {
+		return "policy"
+	}
+
+	if len(r.PeerPatches) > 0 && len(r.PeersChanged) == 0 && len(r.PeersRemoved) == 0 && !r.SendAllPeers {
+		return "patch"
+	}
+
+	if len(r.PeersChanged) > 0 || len(r.PeersRemoved) > 0 || r.SendAllPeers {
+		return "peers"
+	}
+
+	if r.IncludeDERPMap || r.IncludeDNS || r.IncludeDomain || r.IncludePolicy {
+		return "config"
+	}
+
+	return "unknown"
+}
+
+// ShouldSendToNode reports whether this response should be sent to nodeID.
+// Responses targeted at a specific node (TargetNode set) are delivered only
+// to that node; all other responses are broadcast to every connected node.
+func (r Change) ShouldSendToNode(nodeID types.NodeID) bool {
+	// If targeted to a specific node, only send to that node
+	if r.TargetNode != 0 {
+		return r.TargetNode == nodeID
+	}
+
+	return true
+}
+
+// HasFull returns true if any response in the slice is a full update.
+func HasFull(rs []Change) bool {
+	for _, r := range rs {
+		if r.IsFull() {
 			return true
 		}
 	}
+
 	return false
 }
 
-func SplitAllAndSelf(cs []ChangeSet) (all []ChangeSet, self []ChangeSet) {
-	for _, c := range cs {
-		if c.SelfUpdateOnly {
-			self = append(self, c)
+// SplitTargetedAndBroadcast separates responses into targeted (to a specific node) and broadcast.
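+//
+// A minimal sketch of the intended behavior (node IDs illustrative):
+//
+//	broadcast, targeted := SplitTargetedAndBroadcast([]Change{
+//		DERPMap(),     // TargetNode unset: broadcast
+//		SelfUpdate(1), // TargetNode == 1: targeted
+//	})
+//	// broadcast holds the DERP map change, targeted holds the self update.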
+func SplitTargetedAndBroadcast(rs []Change) ([]Change, []Change) { + var broadcast, targeted []Change + + for _, r := range rs { + if r.IsTargetedToNode() { + targeted = append(targeted, r) } else { - all = append(all, c) + broadcast = append(broadcast, r) } } - return all, self + + return broadcast, targeted } -func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) { - for _, c := range cs { - if c.NodeID != id || c.Change.AlsoSelf() { - ret = append(ret, c) +// FilterForNode returns responses that should be sent to the given node. +func FilterForNode(nodeID types.NodeID, rs []Change) []Change { + var result []Change + + for _, r := range rs { + if r.ShouldSendToNode(nodeID) { + result = append(result, r) } } - return ret + + return result } -// IsSelfUpdate reports whether this ChangeSet represents an update to the given node itself. -func (c ChangeSet) IsSelfUpdate(nodeID types.NodeID) bool { - return c.NodeID == nodeID -} - -func (c ChangeSet) AlsoSelf() bool { - // If NodeID is 0, it means this ChangeSet is not related to a specific node, - // so we consider it as a change that should be sent to all nodes. - if c.NodeID == 0 { - return true +func uniqueNodeIDs(ids []types.NodeID) []types.NodeID { + if len(ids) == 0 { + return nil } - return c.Change.AlsoSelf() || c.SelfUpdateOnly + + slices.Sort(ids) + + return slices.Compact(ids) } -var ( - EmptySet = ChangeSet{Change: ChangeUnknown} - FullSet = ChangeSet{Change: Full} - DERPSet = ChangeSet{Change: DERP} - PolicySet = ChangeSet{Change: Policy} - ExtraRecordsSet = ChangeSet{Change: ExtraRecords} -) +// Constructor functions -func FullSelf(id types.NodeID) ChangeSet { - return ChangeSet{ - Change: Full, - SelfUpdateOnly: true, - NodeID: id, +func FullUpdate() Change { + return Change{ + Reason: "full update", + IncludeSelf: true, + IncludeDERPMap: true, + IncludeDNS: true, + IncludeDomain: true, + IncludePolicy: true, + SendAllPeers: true, } } -func NodeAdded(id types.NodeID) ChangeSet { - return ChangeSet{ - Change: NodeNewOrUpdate, - NodeID: id, +// FullSelf returns a full update targeted at a specific node. 
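+//
+// Illustrative use: FullSelf(id).IsFull() and FullSelf(id).IsTargetedToNode()
+// both report true, so the node receives a complete map while no other node
+// sees the change.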
+func FullSelf(nodeID types.NodeID) Change { + return Change{ + Reason: "full self update", + TargetNode: nodeID, + IncludeSelf: true, + IncludeDERPMap: true, + IncludeDNS: true, + IncludeDomain: true, + IncludePolicy: true, + SendAllPeers: true, } } -func NodeRemoved(id types.NodeID) ChangeSet { - return ChangeSet{ - Change: NodeRemove, - NodeID: id, +func SelfUpdate(nodeID types.NodeID) Change { + return Change{ + Reason: "self update", + TargetNode: nodeID, + IncludeSelf: true, } } -func NodeOnline(node types.NodeView) ChangeSet { - return ChangeSet{ - Change: NodeCameOnline, - NodeID: node.ID(), - IsSubnetRouter: node.IsSubnetRouter(), +func PolicyOnly() Change { + return Change{ + Reason: "policy update", + IncludePolicy: true, } } -func NodeOffline(node types.NodeView) ChangeSet { - return ChangeSet{ - Change: NodeWentOffline, - NodeID: node.ID(), - IsSubnetRouter: node.IsSubnetRouter(), +func PolicyAndPeers(changedPeers ...types.NodeID) Change { + return Change{ + Reason: "policy and peers update", + IncludePolicy: true, + PeersChanged: changedPeers, } } -func KeyExpiry(id types.NodeID, expiry time.Time) ChangeSet { - return ChangeSet{ - Change: NodeKeyExpiry, - NodeID: id, - NodeExpiry: &expiry, +func VisibilityChange(reason string, added, removed []types.NodeID) Change { + return Change{ + Reason: reason, + IncludePolicy: true, + PeersChanged: added, + PeersRemoved: removed, } } -func EndpointUpdate(id types.NodeID) ChangeSet { - return ChangeSet{ - Change: NodeEndpoint, - NodeID: id, +func PeersChanged(reason string, peerIDs ...types.NodeID) Change { + return Change{ + Reason: reason, + PeersChanged: peerIDs, } } -func DERPUpdate(id types.NodeID) ChangeSet { - return ChangeSet{ - Change: NodeDERP, - NodeID: id, +func PeersRemoved(peerIDs ...types.NodeID) Change { + return Change{ + Reason: "peers removed", + PeersRemoved: peerIDs, } } -func UserAdded(id types.UserID) ChangeSet { - return ChangeSet{ - Change: UserNewOrUpdate, - UserID: id, +func PeerPatched(reason string, patches ...*tailcfg.PeerChange) Change { + return Change{ + Reason: reason, + PeerPatches: patches, } } -func UserRemoved(id types.UserID) ChangeSet { - return ChangeSet{ - Change: UserRemove, - UserID: id, +func DERPMap() Change { + return Change{ + Reason: "DERP map update", + IncludeDERPMap: true, } } -func PolicyChange() ChangeSet { - return ChangeSet{ - Change: Policy, +// PolicyChange creates a response for policy changes. +// Policy changes require runtime peer visibility computation. +func PolicyChange() Change { + return Change{ + Reason: "policy change", + IncludePolicy: true, + RequiresRuntimePeerComputation: true, } } -func DERPChange() ChangeSet { - return ChangeSet{ - Change: DERP, +// DNSConfig creates a response for DNS configuration updates. +func DNSConfig() Change { + return Change{ + Reason: "DNS config update", + IncludeDNS: true, } } + +// NodeOnline creates a patch response for a node coming online. +func NodeOnline(nodeID types.NodeID) Change { + return Change{ + Reason: "node online", + PeerPatches: []*tailcfg.PeerChange{ + { + NodeID: nodeID.NodeID(), + Online: ptrTo(true), + }, + }, + } +} + +// NodeOffline creates a patch response for a node going offline. +func NodeOffline(nodeID types.NodeID) Change { + return Change{ + Reason: "node offline", + PeerPatches: []*tailcfg.PeerChange{ + { + NodeID: nodeID.NodeID(), + Online: ptrTo(false), + }, + }, + } +} + +// KeyExpiry creates a patch response for a node's key expiry change. 
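+//
+// Sketch: the resulting Change carries one tailcfg.PeerChange with only
+// NodeID and KeyExpiry set, letting peers learn the new expiry without a
+// full map response.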
+func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change { + return Change{ + Reason: "key expiry", + PeerPatches: []*tailcfg.PeerChange{ + { + NodeID: nodeID.NodeID(), + KeyExpiry: expiry, + }, + }, + } +} + +// ptrTo returns a pointer to the given value. +func ptrTo[T any](v T) *T { + return &v +} + +// High-level change constructors + +// NodeAdded returns a Change for when a node is added or updated. +// The OriginNode field enables self-update detection by the mapper. +func NodeAdded(id types.NodeID) Change { + c := PeersChanged("node added", id) + c.OriginNode = id + + return c +} + +// NodeRemoved returns a Change for when a node is removed. +func NodeRemoved(id types.NodeID) Change { + return PeersRemoved(id) +} + +// NodeOnlineFor returns a Change for when a node comes online. +// If the node is a subnet router, a full update is sent instead of a patch. +func NodeOnlineFor(node types.NodeView) Change { + if node.IsSubnetRouter() { + c := FullUpdate() + c.Reason = "subnet router online" + + return c + } + + return NodeOnline(node.ID()) +} + +// NodeOfflineFor returns a Change for when a node goes offline. +// If the node is a subnet router, a full update is sent instead of a patch. +func NodeOfflineFor(node types.NodeView) Change { + if node.IsSubnetRouter() { + c := FullUpdate() + c.Reason = "subnet router offline" + + return c + } + + return NodeOffline(node.ID()) +} + +// KeyExpiryFor returns a Change for when a node's key expiry changes. +// The OriginNode field enables self-update detection by the mapper. +func KeyExpiryFor(id types.NodeID, expiry time.Time) Change { + c := KeyExpiry(id, &expiry) + c.OriginNode = id + + return c +} + +// EndpointOrDERPUpdate returns a Change for when a node's endpoints or DERP region changes. +// The OriginNode field enables self-update detection by the mapper. +func EndpointOrDERPUpdate(id types.NodeID, patch *tailcfg.PeerChange) Change { + c := PeerPatched("endpoint/DERP update", patch) + c.OriginNode = id + + return c +} + +// UserAdded returns a Change for when a user is added or updated. +// A full update is sent to refresh user profiles on all nodes. +func UserAdded() Change { + c := FullUpdate() + c.Reason = "user added" + + return c +} + +// UserRemoved returns a Change for when a user is removed. +// A full update is sent to refresh user profiles on all nodes. +func UserRemoved() Change { + c := FullUpdate() + c.Reason = "user removed" + + return c +} + +// ExtraRecords returns a Change for when DNS extra records change. +func ExtraRecords() Change { + c := DNSConfig() + c.Reason = "extra records update" + + return c +} diff --git a/hscontrol/types/change/change_string.go b/hscontrol/types/change/change_string.go deleted file mode 100644 index fd6059d5..00000000 --- a/hscontrol/types/change/change_string.go +++ /dev/null @@ -1,59 +0,0 @@ -// Code generated by "stringer -type=Change"; DO NOT EDIT. - -package change - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. 
- var x [1]struct{} - _ = x[ChangeUnknown-0] - _ = x[Full-9] - _ = x[Policy-11] - _ = x[DERP-12] - _ = x[ExtraRecords-13] - _ = x[NodeCameOnline-21] - _ = x[NodeWentOffline-22] - _ = x[NodeRemove-23] - _ = x[NodeKeyExpiry-24] - _ = x[NodeNewOrUpdate-25] - _ = x[NodeEndpoint-26] - _ = x[NodeDERP-27] - _ = x[UserNewOrUpdate-51] - _ = x[UserRemove-52] -} - -const ( - _Change_name_0 = "ChangeUnknown" - _Change_name_1 = "Full" - _Change_name_2 = "PolicyDERPExtraRecords" - _Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdateNodeEndpointNodeDERP" - _Change_name_4 = "UserNewOrUpdateUserRemove" -) - -var ( - _Change_index_2 = [...]uint8{0, 6, 10, 22} - _Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67, 79, 87} - _Change_index_4 = [...]uint8{0, 15, 25} -) - -func (i Change) String() string { - switch { - case i == 0: - return _Change_name_0 - case i == 9: - return _Change_name_1 - case 11 <= i && i <= 13: - i -= 11 - return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]] - case 21 <= i && i <= 27: - i -= 21 - return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]] - case 51 <= i && i <= 52: - i -= 51 - return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]] - default: - return "Change(" + strconv.FormatInt(int64(i), 10) + ")" - } -} diff --git a/hscontrol/types/change/change_test.go b/hscontrol/types/change/change_test.go new file mode 100644 index 00000000..30330584 --- /dev/null +++ b/hscontrol/types/change/change_test.go @@ -0,0 +1,449 @@ +package change + +import ( + "reflect" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "tailscale.com/tailcfg" +) + +func TestChange_FieldSync(t *testing.T) { + r := Change{} + fieldNames := r.boolFieldNames() + + typ := reflect.TypeFor[Change]() + boolCount := 0 + + for i := range typ.NumField() { + if typ.Field(i).Type.Kind() == reflect.Bool { + boolCount++ + } + } + + if len(fieldNames) != boolCount { + t.Fatalf("boolFieldNames() returns %d fields but struct has %d bool fields; "+ + "update boolFieldNames() when adding new bool fields", len(fieldNames), boolCount) + } +} + +func TestChange_IsEmpty(t *testing.T) { + tests := []struct { + name string + response Change + want bool + }{ + { + name: "zero value is empty", + response: Change{}, + want: true, + }, + { + name: "only reason is still empty", + response: Change{Reason: "test"}, + want: true, + }, + { + name: "IncludeSelf not empty", + response: Change{IncludeSelf: true}, + want: false, + }, + { + name: "IncludeDERPMap not empty", + response: Change{IncludeDERPMap: true}, + want: false, + }, + { + name: "IncludeDNS not empty", + response: Change{IncludeDNS: true}, + want: false, + }, + { + name: "IncludeDomain not empty", + response: Change{IncludeDomain: true}, + want: false, + }, + { + name: "IncludePolicy not empty", + response: Change{IncludePolicy: true}, + want: false, + }, + { + name: "SendAllPeers not empty", + response: Change{SendAllPeers: true}, + want: false, + }, + { + name: "PeersChanged not empty", + response: Change{PeersChanged: []types.NodeID{1}}, + want: false, + }, + { + name: "PeersRemoved not empty", + response: Change{PeersRemoved: []types.NodeID{1}}, + want: false, + }, + { + name: "PeerPatches not empty", + response: Change{PeerPatches: []*tailcfg.PeerChange{{}}}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.response.IsEmpty() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestChange_IsSelfOnly(t 
*testing.T) { + tests := []struct { + name string + response Change + want bool + }{ + { + name: "empty is not self only", + response: Change{}, + want: false, + }, + { + name: "IncludeSelf without TargetNode is not self only", + response: Change{IncludeSelf: true}, + want: false, + }, + { + name: "TargetNode without IncludeSelf is not self only", + response: Change{TargetNode: 1}, + want: false, + }, + { + name: "TargetNode with IncludeSelf is self only", + response: Change{TargetNode: 1, IncludeSelf: true}, + want: true, + }, + { + name: "self only with SendAllPeers is not self only", + response: Change{TargetNode: 1, IncludeSelf: true, SendAllPeers: true}, + want: false, + }, + { + name: "self only with PeersChanged is not self only", + response: Change{TargetNode: 1, IncludeSelf: true, PeersChanged: []types.NodeID{2}}, + want: false, + }, + { + name: "self only with PeersRemoved is not self only", + response: Change{TargetNode: 1, IncludeSelf: true, PeersRemoved: []types.NodeID{2}}, + want: false, + }, + { + name: "self only with PeerPatches is not self only", + response: Change{TargetNode: 1, IncludeSelf: true, PeerPatches: []*tailcfg.PeerChange{{}}}, + want: false, + }, + { + name: "self only with other include flags is still self only", + response: Change{ + TargetNode: 1, + IncludeSelf: true, + IncludePolicy: true, + IncludeDNS: true, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.response.IsSelfOnly() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestChange_Merge(t *testing.T) { + tests := []struct { + name string + r1 Change + r2 Change + want Change + }{ + { + name: "empty merge", + r1: Change{}, + r2: Change{}, + want: Change{}, + }, + { + name: "bool fields OR together", + r1: Change{IncludeSelf: true, IncludePolicy: true}, + r2: Change{IncludeDERPMap: true, IncludePolicy: true}, + want: Change{IncludeSelf: true, IncludeDERPMap: true, IncludePolicy: true}, + }, + { + name: "all bool fields merge", + r1: Change{IncludeSelf: true, IncludeDNS: true, IncludePolicy: true}, + r2: Change{IncludeDERPMap: true, IncludeDomain: true, SendAllPeers: true}, + want: Change{ + IncludeSelf: true, + IncludeDERPMap: true, + IncludeDNS: true, + IncludeDomain: true, + IncludePolicy: true, + SendAllPeers: true, + }, + }, + { + name: "peers deduplicated and sorted", + r1: Change{PeersChanged: []types.NodeID{3, 1}}, + r2: Change{PeersChanged: []types.NodeID{2, 1}}, + want: Change{PeersChanged: []types.NodeID{1, 2, 3}}, + }, + { + name: "peers removed deduplicated", + r1: Change{PeersRemoved: []types.NodeID{1, 2}}, + r2: Change{PeersRemoved: []types.NodeID{2, 3}}, + want: Change{PeersRemoved: []types.NodeID{1, 2, 3}}, + }, + { + name: "peer patches concatenated", + r1: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 1}}}, + r2: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 2}}}, + want: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 1}, {NodeID: 2}}}, + }, + { + name: "reasons combined when different", + r1: Change{Reason: "route change"}, + r2: Change{Reason: "tag change"}, + want: Change{Reason: "route change; tag change"}, + }, + { + name: "same reason not duplicated", + r1: Change{Reason: "policy"}, + r2: Change{Reason: "policy"}, + want: Change{Reason: "policy"}, + }, + { + name: "empty reason takes other", + r1: Change{}, + r2: Change{Reason: "update"}, + want: Change{Reason: "update"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.r1.Merge(tt.r2) + assert.Equal(t, tt.want, 
got) + }) + } +} + +func TestChange_Constructors(t *testing.T) { + tests := []struct { + name string + constructor func() Change + wantReason string + want Change + }{ + { + name: "FullUpdateResponse", + constructor: FullUpdate, + wantReason: "full update", + want: Change{ + Reason: "full update", + IncludeSelf: true, + IncludeDERPMap: true, + IncludeDNS: true, + IncludeDomain: true, + IncludePolicy: true, + SendAllPeers: true, + }, + }, + { + name: "PolicyOnlyResponse", + constructor: PolicyOnly, + wantReason: "policy update", + want: Change{ + Reason: "policy update", + IncludePolicy: true, + }, + }, + { + name: "DERPMapResponse", + constructor: DERPMap, + wantReason: "DERP map update", + want: Change{ + Reason: "DERP map update", + IncludeDERPMap: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := tt.constructor() + assert.Equal(t, tt.wantReason, r.Reason) + assert.Equal(t, tt.want, r) + }) + } +} + +func TestSelfUpdate(t *testing.T) { + r := SelfUpdate(42) + assert.Equal(t, "self update", r.Reason) + assert.Equal(t, types.NodeID(42), r.TargetNode) + assert.True(t, r.IncludeSelf) + assert.True(t, r.IsSelfOnly()) +} + +func TestPolicyAndPeers(t *testing.T) { + r := PolicyAndPeers(1, 2, 3) + assert.Equal(t, "policy and peers update", r.Reason) + assert.True(t, r.IncludePolicy) + assert.Equal(t, []types.NodeID{1, 2, 3}, r.PeersChanged) +} + +func TestVisibilityChange(t *testing.T) { + r := VisibilityChange("tag change", []types.NodeID{1}, []types.NodeID{2, 3}) + assert.Equal(t, "tag change", r.Reason) + assert.True(t, r.IncludePolicy) + assert.Equal(t, []types.NodeID{1}, r.PeersChanged) + assert.Equal(t, []types.NodeID{2, 3}, r.PeersRemoved) +} + +func TestPeersChanged(t *testing.T) { + r := PeersChanged("routes approved", 1, 2) + assert.Equal(t, "routes approved", r.Reason) + assert.Equal(t, []types.NodeID{1, 2}, r.PeersChanged) + assert.False(t, r.IncludePolicy) +} + +func TestPeersRemoved(t *testing.T) { + r := PeersRemoved(1, 2, 3) + assert.Equal(t, "peers removed", r.Reason) + assert.Equal(t, []types.NodeID{1, 2, 3}, r.PeersRemoved) +} + +func TestPeerPatched(t *testing.T) { + patch := &tailcfg.PeerChange{NodeID: 1} + r := PeerPatched("endpoint change", patch) + assert.Equal(t, "endpoint change", r.Reason) + assert.Equal(t, []*tailcfg.PeerChange{patch}, r.PeerPatches) +} + +func TestChange_Type(t *testing.T) { + tests := []struct { + name string + response Change + want string + }{ + { + name: "full update", + response: FullUpdate(), + want: "full", + }, + { + name: "self only", + response: SelfUpdate(1), + want: "self", + }, + { + name: "policy with runtime computation", + response: PolicyChange(), + want: "policy", + }, + { + name: "patch only", + response: PeerPatched("test", &tailcfg.PeerChange{NodeID: 1}), + want: "patch", + }, + { + name: "peers changed", + response: PeersChanged("test", 1, 2), + want: "peers", + }, + { + name: "peers removed", + response: PeersRemoved(1, 2), + want: "peers", + }, + { + name: "config - DERP map", + response: DERPMap(), + want: "config", + }, + { + name: "config - DNS", + response: DNSConfig(), + want: "config", + }, + { + name: "config - policy only (no runtime)", + response: PolicyOnly(), + want: "config", + }, + { + name: "empty is unknown", + response: Change{}, + want: "unknown", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.response.Type() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUniqueNodeIDs(t *testing.T) { + tests := []struct { + name 
string + input []types.NodeID + want []types.NodeID + }{ + { + name: "nil input", + input: nil, + want: nil, + }, + { + name: "empty input", + input: []types.NodeID{}, + want: nil, + }, + { + name: "single element", + input: []types.NodeID{1}, + want: []types.NodeID{1}, + }, + { + name: "no duplicates", + input: []types.NodeID{1, 2, 3}, + want: []types.NodeID{1, 2, 3}, + }, + { + name: "with duplicates", + input: []types.NodeID{3, 1, 2, 1, 3}, + want: []types.NodeID{1, 2, 3}, + }, + { + name: "all same", + input: []types.NodeID{5, 5, 5, 5}, + want: []types.NodeID{5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := uniqueNodeIDs(tt.input) + assert.Equal(t, tt.want, got) + }) + } +}
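+
+// TestChange_MergeThenType is an illustrative sketch added in editing (not
+// part of the original test set): it shows how a merged Change is
+// categorized. Node IDs and reasons are arbitrary.
+func TestChange_MergeThenType(t *testing.T) {
+	online := NodeOnline(1)                      // patch: node 1 came online
+	routes := PeersChanged("routes approved", 2) // full peer update for node 2
+
+	merged := online.Merge(routes)
+
+	// Both the patch and the changed-peer list survive the merge, and the
+	// distinct reasons are joined for logging.
+	assert.Len(t, merged.PeerPatches, 1)
+	assert.Equal(t, []types.NodeID{2}, merged.PeersChanged)
+	assert.Equal(t, "node online; routes approved", merged.Reason)
+
+	// With PeersChanged populated, the merged change is no longer a pure
+	// patch, so Type() reports the broader "peers" category.
+	assert.Equal(t, "peers", merged.Type())
+}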