package mapper import ( "errors" "fmt" "time" "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" ) var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "headscale", Name: "mapresponse_generated_total", Help: "total count of mapresponses generated by response type", }, []string{"response_type"}) type batcherFunc func(cfg *types.Config, state *state.State) Batcher // Batcher defines the common interface for all batcher implementations. type Batcher interface { Start() Close() AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool IsConnected(id types.NodeID) bool ConnectedMap() *xsync.Map[types.NodeID, bool] AddWork(r ...change.Change) MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) } func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher { return &LockFreeBatcher{ mapper: mapper, workers: workers, tick: time.NewTicker(batchTime), // The size of this channel is arbitrary chosen, the sizing should be revisited. workCh: make(chan work, workers*200), nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), connected: xsync.NewMap[types.NodeID, *time.Time](), pendingChanges: xsync.NewMap[types.NodeID, []change.Change](), } } // NewBatcherAndMapper creates a Batcher implementation. func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher { m := newMapper(cfg, state) b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m) m.batcher = b return b } // nodeConnection interface for different connection implementations. type nodeConnection interface { nodeID() types.NodeID version() tailcfg.CapabilityVersion send(data *tailcfg.MapResponse) error // computePeerDiff returns peers that were previously sent but are no longer in the current list. computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID) // updateSentPeers updates the tracking of which peers have been sent to this node. updateSentPeers(resp *tailcfg.MapResponse) } // generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change]. func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) { nodeID := nc.nodeID() version := nc.version() if r.IsEmpty() { return nil, nil //nolint:nilnil // Empty response means nothing to send } if nodeID == 0 { return nil, fmt.Errorf("invalid nodeID: %d", nodeID) } if mapper == nil { return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID) } // Handle self-only responses if r.IsSelfOnly() && r.TargetNode != nodeID { return nil, nil //nolint:nilnil // No response needed for other nodes when self-only } var ( mapResp *tailcfg.MapResponse err error ) // Track metric using categorized type, not free-form reason mapResponseGenerated.WithLabelValues(r.Type()).Inc() // Check if this requires runtime peer visibility computation (e.g., policy changes) if r.RequiresRuntimePeerComputation { currentPeers := mapper.state.ListPeers(nodeID) currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len()) for _, peer := range currentPeers.All() { currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID()) } removedPeers := nc.computePeerDiff(currentPeerIDs) mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers) } else { mapResp, err = mapper.buildFromChange(nodeID, version, &r) } if err != nil { return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err) } return mapResp, nil } // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change]. func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error { if nc == nil { return errors.New("nodeConnection is nil") } nodeID := nc.nodeID() log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received") data, err := generateMapResponse(nc, mapper, r) if err != nil { return fmt.Errorf("generating map response for node %d: %w", nodeID, err) } if data == nil { // No data to send is valid for some response types return nil } // Send the map response err = nc.send(data) if err != nil { return fmt.Errorf("sending map response to node %d: %w", nodeID, err) } // Update peer tracking after successful send nc.updateSentPeers(data) return nil } // workResult represents the result of processing a change. type workResult struct { mapResponse *tailcfg.MapResponse err error } // work represents a unit of work to be processed by workers. type work struct { c change.Change nodeID types.NodeID resultCh chan<- workResult // optional channel for synchronous operations }