Mirror of https://github.com/juanfont/headscale.git
change: smarter change notifications
This commit replaces the ChangeSet with a simpler bool-based change model that the map builder can use directly to build the appropriate map response for the change that has occurred. Previously, we fell back to sending full maps for many changes, as that was considered the "safe" thing to do to ensure no updates were missed. This was slightly problematic: a node that already has a list of peers only does a full replacement of its peers if the new list is non-empty, which meant it was not possible to remove all nodes (for example, when the policy changed). Now we keep track of the last seen nodes, so we can send explicit remove IDs, and we are also much smarter about sending smaller, partial maps when they suffice.

Fixes #2389

Signed-off-by: Kristoffer Dalby <kristoffer@dalby.cc>
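The removal-tracking idea behind this can be sketched roughly as follows. This is an illustrative snippet, not the code from this commit (the actual implementation below uses an xsync.Map inside multiChannelNodeConn; the names sentPeers and removedSince are hypothetical): by remembering which peer IDs were last sent to a node, the removals to announce after a policy change become a simple set difference.

package batcher

import "tailscale.com/tailcfg"

// sentPeers is a hypothetical tracker of the peer IDs last sent to a node.
type sentPeers map[tailcfg.NodeID]struct{}

// removedSince returns the IDs that were previously sent to the node but are
// absent from the current peer list, i.e. the PeersRemoved entries to emit.
func (s sentPeers) removedSince(current []tailcfg.NodeID) []tailcfg.NodeID {
	currentSet := make(map[tailcfg.NodeID]struct{}, len(current))
	for _, id := range current {
		currentSet[id] = struct{}{}
	}

	var removed []tailcfg.NodeID
	for id := range s {
		if _, ok := currentSet[id]; !ok {
			removed = append(removed, id)
		}
	}

	return removed
}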
@@ -271,7 +271,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
return
case <-expireTicker.C:
var expiredNodeChanges []change.ChangeSet
var expiredNodeChanges []change.Change
var changed bool
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)

@@ -305,7 +305,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
h.state.SetDERPMap(derpMap)
h.Change(change.DERPSet)
h.Change(change.DERPMap())
case records, ok := <-extraRecordsUpdate:
if !ok {

@@ -313,7 +313,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
h.Change(change.ExtraRecordsSet)
h.Change(change.ExtraRecords())
}
}
}

@@ -988,7 +988,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
// Change is used to send changes to nodes.
// All change should be enqueued here and empty will be automatically
// ignored.
func (h *Headscale) Change(cs ...change.ChangeSet) {
func (h *Headscale) Change(cs ...change.Change) {
h.mapBatcher.AddWork(cs...)
}
@@ -58,14 +58,9 @@ func (api headscaleV1APIServer) CreateUser(
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
}
c := change.UserAdded(types.UserID(user.ID))
// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
if !policyChanged.Empty() {
c.Change = change.Policy
}
api.h.Change(c)
// CreateUser returns a policy change response if the user creation affected policy.
// This triggers a full policy re-evaluation for all connected nodes.
api.h.Change(policyChanged)
return &v1.CreateUserResponse{User: user.Proto()}, nil
}

@@ -109,7 +104,8 @@ func (api headscaleV1APIServer) DeleteUser(
return nil, err
}
api.h.Change(change.UserRemoved(types.UserID(user.ID)))
// User deletion may affect policy, trigger a full policy re-evaluation.
api.h.Change(change.UserRemoved())
return &v1.DeleteUserResponse{}, nil
}
@@ -13,18 +13,13 @@ import (
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var (
mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
Help: "total count of mapresponses generated by response type and change type",
}, []string{"response_type", "change_type"})
errNodeNotFoundInNodeStore = errors.New("node not found in NodeStore")
)
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
Help: "total count of mapresponses generated by response type",
}, []string{"response_type"})
type batcherFunc func(cfg *types.Config, state *state.State) Batcher

@@ -36,8 +31,8 @@ type Batcher interface {
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool]
AddWork(c ...change.ChangeSet)
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
AddWork(r ...change.Change)
MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error)
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
}

@@ -51,7 +46,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
}
}
@@ -69,15 +64,21 @@ type nodeConnection interface {
nodeID() types.NodeID
version() tailcfg.CapabilityVersion
send(data *tailcfg.MapResponse) error
// computePeerDiff returns peers that were previously sent but are no longer in the current list.
computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID)
// updateSentPeers updates the tracking of which peers have been sent to this node.
updateSentPeers(resp *tailcfg.MapResponse)
}
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
if c.Empty() {
return nil, nil
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change].
func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) {
nodeID := nc.nodeID()
version := nc.version()
if r.IsEmpty() {
return nil, nil //nolint:nilnil // Empty response means nothing to send
}
// Validate inputs before processing
if nodeID == 0 {
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
}
@@ -86,141 +87,58 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
}
// Handle self-only responses
if r.IsSelfOnly() && r.TargetNode != nodeID {
return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
}
var (
mapResp *tailcfg.MapResponse
err error
responseType string
mapResp *tailcfg.MapResponse
err error
)
// Record metric when function exits
defer func() {
if err == nil && mapResp != nil && responseType != "" {
mapResponseGenerated.WithLabelValues(responseType, c.Change.String()).Inc()
}
}()
// Track metric using categorized type, not free-form reason
mapResponseGenerated.WithLabelValues(r.Type()).Inc()
switch c.Change {
case change.DERP:
responseType = "derp"
mapResp, err = mapper.derpMapResponse(nodeID)
// Check if this requires runtime peer visibility computation (e.g., policy changes)
if r.RequiresRuntimePeerComputation {
currentPeers := mapper.state.ListPeers(nodeID)
case change.NodeCameOnline, change.NodeWentOffline:
if c.IsSubnetRouter {
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
responseType = "full"
mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
// Trust the change type for online/offline status to avoid race conditions
// between NodeStore updates and change processing
responseType = string(patchResponseDebug)
onlineStatus := c.Change == change.NodeCameOnline
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
Online: ptr.To(onlineStatus),
},
})
currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len())
for _, peer := range currentPeers.All() {
currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID())
}
case change.NodeNewOrUpdate:
// If the node is the one being updated, we send a self update that preserves peer information
// to ensure the node sees changes to its own properties (e.g., hostname/DNS name changes)
// without losing its view of peer status during rapid reconnection cycles
if c.IsSelfUpdate(nodeID) {
responseType = "self"
mapResp, err = mapper.selfMapResponse(nodeID, version)
} else {
responseType = "change"
mapResp, err = mapper.peerChangeResponse(nodeID, version, c.NodeID)
}
case change.NodeRemove:
responseType = "remove"
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
case change.NodeKeyExpiry:
// If the node is the one whose key is expiring, we send a "full" self update
// as nodes will ignore patch updates about themselves (?).
if c.IsSelfUpdate(nodeID) {
responseType = "self"
mapResp, err = mapper.selfMapResponse(nodeID, version)
// mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
responseType = "patch"
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
KeyExpiry: c.NodeExpiry,
},
})
}
case change.NodeEndpoint, change.NodeDERP:
// Endpoint or DERP changes can be sent as lightweight patches.
// Query the NodeStore for the current peer state to construct the PeerChange.
// Even if only endpoint or only DERP changed, we include both in the patch
// since they're often updated together and it's minimal overhead.
responseType = "patch"
peer, found := mapper.state.GetNodeByID(c.NodeID)
if !found {
return nil, fmt.Errorf("%w: %d", errNodeNotFoundInNodeStore, c.NodeID)
}
peerChange := &tailcfg.PeerChange{
NodeID: c.NodeID.NodeID(),
Endpoints: peer.Endpoints().AsSlice(),
DERPRegion: 0, // Will be set below if available
}
// Extract DERP region from Hostinfo if available
if hi := peer.AsStruct().Hostinfo; hi != nil && hi.NetInfo != nil {
peerChange.DERPRegion = hi.NetInfo.PreferredDERP
}
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{peerChange})
default:
// The following will always hit this:
// change.Full, change.Policy
responseType = "full"
mapResp, err = mapper.fullMapResponse(nodeID, version)
removedPeers := nc.computePeerDiff(currentPeerIDs)
mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers)
} else {
mapResp, err = mapper.buildFromChange(nodeID, version, &r)
}
if err != nil {
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
}
// TODO(kradalby): Is this necessary?
// Validate the generated map response - only check for nil response
// Note: mapResp.Node can be nil for peer updates, which is valid
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
}
return mapResp, nil
}
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
if nc == nil {
return errors.New("nodeConnection is nil")
}
nodeID := nc.nodeID()
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received")
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received")
var data *tailcfg.MapResponse
var err error
data, err = generateMapResponse(nodeID, nc.version(), mapper, c)
data, err := generateMapResponse(nc, mapper, r)
if err != nil {
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
}
if data == nil {
// No data to send is valid for some change types
// No data to send is valid for some response types
return nil
}
@@ -230,6 +148,9 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
}
// Update peer tracking after successful send
nc.updateSentPeers(data)
return nil
}

@@ -241,7 +162,7 @@ type workResult struct {
// work represents a unit of work to be processed by workers.
type work struct {
c change.ChangeSet
r change.Change
nodeID types.NodeID
resultCh chan<- workResult // optional channel for synchronous operations
}
@@ -33,7 +33,7 @@ type LockFreeBatcher struct {
done chan struct{}
// Batching state
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
pendingChanges *xsync.Map[types.NodeID, []change.Change]
// Metrics
totalNodes atomic.Int64

@@ -141,8 +141,8 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
}
// AddWork queues a change to be processed by the batcher.
func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
b.addWork(c...)
func (b *LockFreeBatcher) AddWork(r ...change.Change) {
b.addWork(r...)
}
func (b *LockFreeBatcher) Start() {
@@ -211,15 +211,19 @@ func (b *LockFreeBatcher) worker(workerID int) {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
var err error
result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.r)
result.err = err
if result.err != nil {
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("worker.id", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Str("reason", w.r.Reason).
Msg("failed to generate map response for synchronous work")
} else if result.mapResponse != nil {
// Update peer tracking for synchronous responses too
nc.updateSentPeers(result.mapResponse)
}
} else {
result.err = fmt.Errorf("node %d not found", w.nodeID)

@@ -247,13 +251,13 @@ func (b *LockFreeBatcher) worker(workerID int) {
if nc, exists := b.nodes.Load(w.nodeID); exists {
// Apply change to node - this will handle offline nodes gracefully
// and queue work for when they reconnect
err := nc.change(w.c)
err := nc.change(w.r)
if err != nil {
b.workErrors.Add(1)
log.Error().Err(err).
Int("worker.id", workerID).
Uint64("node.id", w.c.NodeID.Uint64()).
Str("change", w.c.Change.String()).
Uint64("node.id", w.nodeID.Uint64()).
Str("reason", w.r.Reason).
Msg("failed to apply change")
}
}

@@ -264,8 +268,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
}
}
func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) {
b.addToBatch(c...)
func (b *LockFreeBatcher) addWork(r ...change.Change) {
b.addToBatch(r...)
}
// queueWork safely queues work.
@@ -281,38 +285,43 @@ func (b *LockFreeBatcher) queueWork(w work) {
}
}
// addToBatch adds a change to the pending batch.
func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
// Short circuit if any of the changes is a full update, which
// addToBatch adds a response to the pending batch.
func (b *LockFreeBatcher) addToBatch(responses ...change.Change) {
// Short circuit if any of the responses is a full update, which
// means we can skip sending individual changes.
if change.HasFull(c) {
if change.HasFull(responses) {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}})
b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()})
return true
})
return
}
all, self := change.SplitAllAndSelf(c)
for _, changeSet := range self {
changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{})
changes = append(changes, changeSet)
b.pendingChanges.Store(changeSet.NodeID, changes)
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
rel := change.RemoveUpdatesForSelf(nodeID, all)
broadcast, targeted := change.SplitTargetedAndBroadcast(responses)
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, rel...)
b.pendingChanges.Store(nodeID, changes)
// Handle targeted responses - send only to the specific node
for _, resp := range targeted {
changes, _ := b.pendingChanges.LoadOrStore(resp.TargetNode, []change.Change{})
changes = append(changes, resp)
b.pendingChanges.Store(resp.TargetNode, changes)
}
return true
})
// Handle broadcast responses - send to all nodes, filtering as needed
if len(broadcast) > 0 {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
filtered := change.FilterForNode(nodeID, broadcast)
if len(filtered) > 0 {
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{})
changes = append(changes, filtered...)
b.pendingChanges.Store(nodeID, changes)
}
return true
})
}
}
// processBatchedChanges processes all pending batched changes.
@@ -322,14 +331,14 @@ func (b *LockFreeBatcher) processBatchedChanges() {
}
// Process all pending changes
b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
if len(changes) == 0 {
b.pendingChanges.Range(func(nodeID types.NodeID, responses []change.Change) bool {
if len(responses) == 0 {
return true
}
// Send all batched changes for this node
for _, c := range changes {
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
// Send all batched responses for this node
for _, r := range responses {
b.queueWork(work{r: r, nodeID: nodeID, resultCh: nil})
}
// Clear the pending changes for this node

@@ -432,11 +441,11 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
// MapResponseFromChange queues work to generate a map response and waits for the result.
// This allows synchronous map generation using the same worker pool.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error) {
resultCh := make(chan workResult, 1)
// Queue the work with a result channel using the safe queueing method
b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
b.queueWork(work{r: r, nodeID: id, resultCh: resultCh})
// Wait for the result
select {
@@ -466,6 +475,12 @@ type multiChannelNodeConn struct {
connections []*connectionEntry
updateCount atomic.Int64
// lastSentPeers tracks which peers were last sent to this node.
// This enables computing diffs for policy changes instead of sending
// full peer lists (which clients interpret as "no change" when empty).
// Using xsync.Map for lock-free concurrent access.
lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
}
// generateConnectionID generates a unique connection identifier.

@@ -478,8 +493,9 @@ func generateConnectionID() string {
// newMultiChannelNodeConn creates a new multi-channel node connection.
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
return &multiChannelNodeConn{
id: id,
mapper: mapper,
id: id,
mapper: mapper,
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
}
}
@@ -662,9 +678,59 @@ func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
return mc.connections[0].version
}
// updateSentPeers updates the tracked peer state based on a sent MapResponse.
// This must be called after successfully sending a response to keep track of
// what the client knows about, enabling accurate diffs for future updates.
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
if resp == nil {
return
}
// Full peer list replaces tracked state entirely
if resp.Peers != nil {
mc.lastSentPeers.Clear()
for _, peer := range resp.Peers {
mc.lastSentPeers.Store(peer.ID, struct{}{})
}
}
// Incremental additions
for _, peer := range resp.PeersChanged {
mc.lastSentPeers.Store(peer.ID, struct{}{})
}
// Incremental removals
for _, id := range resp.PeersRemoved {
mc.lastSentPeers.Delete(id)
}
}
// computePeerDiff compares the current peer list against what was last sent
// and returns the peers that were removed (in lastSentPeers but not in current).
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
for _, id := range currentPeers {
currentSet[id] = struct{}{}
}
var removed []tailcfg.NodeID
// Find removed: in lastSentPeers but not in current
mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool {
if _, exists := currentSet[id]; !exists {
removed = append(removed, id)
}
return true
})
return removed
}
// change applies a change to all active connections for the node.
func (mc *multiChannelNodeConn) change(c change.ChangeSet) error {
return handleNodeChange(mc, mc.mapper, c)
func (mc *multiChannelNodeConn) change(r change.Change) error {
return handleNodeChange(mc, mc.mapper, r)
}
// DebugNodeInfo contains debug information about a node's connections.
@@ -59,7 +59,7 @@ func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapRespo
return fmt.Errorf("%w: %d", errNodeNotFoundAfterAdd, id)
}
t.AddWork(change.NodeOnline(node))
t.AddWork(change.NodeOnlineFor(node))
return nil
}

@@ -76,7 +76,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
// Do this BEFORE removing from batcher so the change can be processed
node, ok := t.state.GetNodeByID(id)
if ok {
t.AddWork(change.NodeOffline(node))
t.AddWork(change.NodeOfflineFor(node))
}
// Finally remove from the real batcher

@@ -557,9 +557,9 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
}, time.Second, 10*time.Millisecond, "waiting for node connection")
// Generate work and wait for updates to be processed
batcher.AddWork(change.FullSet)
batcher.AddWork(change.PolicySet)
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.FullUpdate())
batcher.AddWork(change.PolicyChange())
batcher.AddWork(change.DERPMap())
// Wait for updates to be processed (at least 1 update received)
assert.EventuallyWithT(t, func(c *assert.CollectT) {

@@ -661,7 +661,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
// Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
// Yield to scheduler for large node counts to prevent overwhelming the work queue
if tc.nodeCount > 100 && i%50 == 49 {

@@ -832,7 +832,7 @@ func TestBatcherBasicOperations(t *testing.T) {
}
// Test work processing with DERP change
batcher.AddWork(change.DERPChange())
batcher.AddWork(change.DERPMap())
// Wait for update and validate content
select {
@@ -959,31 +959,31 @@ func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout ti
// }{
// {
// name: "DERP change",
// changeSet: change.DERPSet,
// changeSet: change.DERPMapResponse(),
// expectData: true,
// description: "DERP changes should generate map updates",
// },
// {
// name: "Node key expiry",
// changeSet: change.KeyExpiry(testNodes[1].n.ID),
// changeSet: change.KeyExpiryFor(testNodes[1].n.ID),
// expectData: true,
// description: "Node key expiry with real node data",
// },
// {
// name: "Node new registration",
// changeSet: change.NodeAdded(testNodes[1].n.ID),
// changeSet: change.NodeAddedResponse(testNodes[1].n.ID),
// expectData: true,
// description: "New node registration with real data",
// },
// {
// name: "Full update",
// changeSet: change.FullSet,
// changeSet: change.FullUpdateResponse(),
// expectData: true,
// description: "Full updates with real node data",
// },
// {
// name: "Policy change",
// changeSet: change.PolicySet,
// changeSet: change.PolicyChangeResponse(),
// expectData: true,
// description: "Policy updates with real node data",
// },
@@ -1057,13 +1057,13 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
var receivedUpdates []*tailcfg.MapResponse
// Add multiple changes rapidly to test batching
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID, testExpiry))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.KeyExpiryFor(testNodes[1].n.ID, testExpiry))
batcher.AddWork(change.DERPMap())
batcher.AddWork(change.NodeAdded(testNodes[1].n.ID))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Collect updates with timeout
updateCount := 0

@@ -1087,8 +1087,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
t.Logf("Update %d: nil update", updateCount)
}
case <-timeout:
// Expected: 5 changes should generate 6 updates (no batching in current implementation)
expectedUpdates := 6
// Expected: 5 explicit changes + 1 initial from AddNode + 1 NodeOnline from wrapper = 7 updates
expectedUpdates := 7
t.Logf("Received %d updates from %d changes (expected %d)",
updateCount, 5, expectedUpdates)
@@ -1160,7 +1160,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
// Add real work during connection chaos
if i%10 == 0 {
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
}
// Rapid second connection - should replace ch1

@@ -1260,7 +1260,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Add node and immediately queue real work
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Consumer goroutine to validate data and detect channel issues
go func() {

@@ -1302,7 +1302,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
if i%10 == 0 {
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(testNode.n.ID, testExpiry))
batcher.AddWork(change.KeyExpiryFor(testNode.n.ID, testExpiry))
}
// Rapid removal creates race between worker and removal
@@ -1510,12 +1510,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Generate various types of work during racing
if i%3 == 0 {
// DERP changes
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
}
if i%5 == 0 {
// Full updates using real node data
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
if i%7 == 0 && len(allNodes) > 0 {

@@ -1523,7 +1523,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
node := allNodes[i%len(allNodes)]
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(node.n.ID, testExpiry))
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
}
// Yield to allow some batching
@@ -1778,7 +1778,7 @@ func XTestBatcherScalability(t *testing.T) {
}
}, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect")
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
// Wait for initial update to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {

@@ -1887,7 +1887,7 @@ func XTestBatcherScalability(t *testing.T) {
// Add work to create load
if index%5 == 0 {
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
}(
node.n.ID,

@@ -1914,11 +1914,11 @@ func XTestBatcherScalability(t *testing.T) {
// Generate different types of work to ensure updates are sent
switch index % 4 {
case 0:
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
case 1:
batcher.AddWork(change.PolicySet)
batcher.AddWork(change.PolicyChange())
case 2:
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
default:
// Pick a random node and generate a node change
if len(testNodes) > 0 {

@@ -1927,7 +1927,7 @@ func XTestBatcherScalability(t *testing.T) {
change.NodeAdded(testNodes[nodeIdx].n.ID),
)
} else {
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
}
}(i)
@@ -2165,7 +2165,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Send a full update - this should generate full peer lists
t.Logf("Sending FullSet update...")
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
// Wait for FullSet work items to be processed
t.Logf("Waiting for FullSet to be processed...")

@@ -2261,7 +2261,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
t.Logf("Total updates received across all nodes: %d", totalUpdates)
if !foundFullUpdate {
t.Errorf("CRITICAL: No FULL updates received despite sending change.FullSet!")
t.Errorf("CRITICAL: No FULL updates received despite sending change.FullUpdateResponse()!")
t.Errorf(
"This confirms the bug - FullSet updates are not generating full peer responses",
)

@@ -2372,7 +2372,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
// Send a change that should reach all nodes
batcher.AddWork(change.DERPChange())
batcher.AddWork(change.DERPMap())
receivedCount := 0
timeout := time.After(500 * time.Millisecond)
@@ -2508,11 +2508,7 @@ func TestBatcherMultiConnection(t *testing.T) {
clearChannel(node2.ch)
// Send a change notification from node2 (so node1 should receive it on all connections)
testChangeSet := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
testChangeSet := change.NodeAdded(node2.n.ID)
batcher.AddWork(testChangeSet)

@@ -2591,11 +2587,7 @@ func TestBatcherMultiConnection(t *testing.T) {
clearChannel(node1.ch)
clearChannel(thirdChannel)
testChangeSet2 := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
testChangeSet2 := change.NodeAdded(node2.n.ID)
batcher.AddWork(testChangeSet2)

@@ -2629,7 +2621,11 @@ func TestBatcherMultiConnection(t *testing.T) {
remaining1Received, remaining3Received)
}
// Verify second channel no longer receives updates (should be closed/removed)
// Drain secondChannel of any messages received before removal
// (the test wrapper sends NodeOffline before removal, which may have reached this channel)
clearChannel(secondChannel)
// Verify second channel no longer receives new updates after being removed
select {
case <-secondChannel:
t.Errorf("Removed connection still received update - this should not happen")
@@ -29,10 +29,8 @@ type debugType string
const (
fullResponseDebug debugType = "full"
selfResponseDebug debugType = "self"
patchResponseDebug debugType = "patch"
removeResponseDebug debugType = "remove"
changeResponseDebug debugType = "change"
derpResponseDebug debugType = "derp"
policyResponseDebug debugType = "policy"
)
// NewMapResponseBuilder creates a new builder with basic fields set.
@@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/rs/zerolog/log"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
@@ -179,52 +180,108 @@ func (m *mapper) selfMapResponse(
return ma, err
}
func (m *mapper) derpMapResponse(
nodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(derpResponseDebug).
WithDERPMap().
Build()
}
// PeerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *mapper) peerChangedPatchResponse(
nodeID types.NodeID,
changed []*tailcfg.PeerChange,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(patchResponseDebug).
WithPeerChangedPatch(changed).
Build()
}
// peerChangeResponse returns a MapResponse with changed or added nodes.
func (m *mapper) peerChangeResponse(
// policyChangeResponse creates a MapResponse for policy changes.
// It sends:
// - PeersRemoved for peers that are no longer visible after the policy change
// - PeersChanged for remaining peers (their AllowedIPs may have changed due to policy)
// - Updated PacketFilters
// - Updated SSHPolicy (SSH rules may reference users/groups that changed)
// This avoids the issue where an empty Peers slice is interpreted by Tailscale
// clients as "no change" rather than "no peers".
func (m *mapper) policyChangeResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
removedPeers []tailcfg.NodeID,
currentPeers views.Slice[types.NodeView],
) (*tailcfg.MapResponse, error) {
peers := m.state.ListPeers(nodeID, changedNodeID)
return m.NewMapResponseBuilder(nodeID).
WithDebugType(changeResponseDebug).
builder := m.NewMapResponseBuilder(nodeID).
WithDebugType(policyResponseDebug).
WithCapabilityVersion(capVer).
WithUserProfiles(peers).
WithPeerChanges(peers).
Build()
WithPacketFilters().
WithSSHPolicy()
if len(removedPeers) > 0 {
// Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved
removedIDs := make([]types.NodeID, len(removedPeers))
for i, id := range removedPeers {
removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent
}
builder.WithPeersRemoved(removedIDs...)
}
// Send remaining peers in PeersChanged - their AllowedIPs may have
// changed due to the policy update (e.g., different routes allowed).
if currentPeers.Len() > 0 {
builder.WithPeerChanges(currentPeers)
}
return builder.Build()
}
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
func (m *mapper) peerRemovedResponse(
// buildFromChange builds a MapResponse from a change.Change specification.
// This provides fine-grained control over what gets included in the response.
func (m *mapper) buildFromChange(
nodeID types.NodeID,
removedNodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
resp *change.Change,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(removeResponseDebug).
WithPeersRemoved(removedNodeID).
Build()
if resp.IsEmpty() {
return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error
}
// If this is a self-update (the changed node is the receiving node),
// send a self-update response to ensure the node sees its own changes.
if resp.OriginNode != 0 && resp.OriginNode == nodeID {
return m.selfMapResponse(nodeID, capVer)
}
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithDebugType(changeResponseDebug)
if resp.IncludeSelf {
builder.WithSelfNode()
}
if resp.IncludeDERPMap {
builder.WithDERPMap()
}
if resp.IncludeDNS {
builder.WithDNSConfig()
}
if resp.IncludeDomain {
builder.WithDomain()
}
if resp.IncludePolicy {
builder.WithPacketFilters()
builder.WithSSHPolicy()
}
if resp.SendAllPeers {
peers := m.state.ListPeers(nodeID)
builder.WithUserProfiles(peers)
builder.WithPeers(peers)
} else {
if len(resp.PeersChanged) > 0 {
peers := m.state.ListPeers(nodeID, resp.PeersChanged...)
builder.WithUserProfiles(peers)
builder.WithPeerChanges(peers)
}
if len(resp.PeersRemoved) > 0 {
builder.WithPeersRemoved(resp.PeersRemoved...)
}
}
if len(resp.PeerPatches) > 0 {
builder.WithPeerChangedPatch(resp.PeerPatches)
}
return builder.Build()
}
func writeDebugMapResponse(
@@ -473,14 +473,16 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, change.ChangeSet, error) {
var user *types.User
var err error
var newUser bool
var c change.ChangeSet
) (*types.User, change.Change, error) {
var (
user *types.User
err error
newUser bool
c change.Change
)
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err)
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
}
// if the user is still not found, create a new empty user.

@@ -496,7 +498,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
if newUser {
user, c, err = a.h.state.CreateUser(*user)
if err != nil {
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
return nil, change.Change{}, fmt.Errorf("creating user: %w", err)
}
} else {
_, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {

@@ -504,7 +506,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return nil
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
return nil, change.Change{}, fmt.Errorf("updating user: %w", err)
}
}

@@ -545,7 +547,7 @@ func (a *AuthProviderOIDC) handleRegistration(
// Send both changes. Empty changes are ignored by Change().
a.h.Change(nodeChange, routesChange)
return !nodeChange.Empty(), nil
return !nodeChange.IsEmpty(), nil
}
func renderOIDCCallbackTemplate(
@@ -57,6 +57,15 @@ var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode")
// ErrNodeNotFound is returned when a node cannot be found by its ID.
var ErrNodeNotFound = errors.New("node not found")
// ErrInvalidNodeView is returned when an invalid node view is provided.
var ErrInvalidNodeView = errors.New("invalid node view provided")
// ErrNodeNotInNodeStore is returned when a node no longer exists in the NodeStore.
var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore")
// ErrNodeNameNotUnique is returned when a node name is not unique.
var ErrNodeNameNotUnique = errors.New("node name is not unique")
// State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe.
type State struct {
@@ -243,7 +252,7 @@ func (s *State) DERPMap() tailcfg.DERPMapView {
// ReloadPolicy reloads the access control policy and triggers auto-approval if changed.
// Returns true if the policy changed.
func (s *State) ReloadPolicy() ([]change.ChangeSet, error) {
func (s *State) ReloadPolicy() ([]change.Change, error) {
pol, err := policyBytes(s.db, s.cfg)
if err != nil {
return nil, fmt.Errorf("loading policy: %w", err)

@@ -260,7 +269,7 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) {
// propagate correctly when switching between policy types.
s.nodeStore.RebuildPeerMaps()
cs := []change.ChangeSet{change.PolicyChange()}
cs := []change.Change{change.PolicyChange()}
// Always call autoApproveNodes during policy reload, regardless of whether
// the policy content has changed. This ensures that routes are re-evaluated

@@ -289,16 +298,16 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) {
// CreateUser creates a new user and updates the policy manager.
// Returns the created user, change set, and any error.
func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) {
func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) {
if err := s.db.DB.Save(&user).Error; err != nil {
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
return nil, change.Change{}, fmt.Errorf("creating user: %w", err)
}
// Check if policy manager needs updating
c, err := s.updatePolicyManagerUsers()
if err != nil {
// Log the error but don't fail the user creation
return &user, change.EmptySet, fmt.Errorf("failed to update policy manager after user creation: %w", err)
return &user, change.Change{}, fmt.Errorf("failed to update policy manager after user creation: %w", err)
}
// Even if the policy manager doesn't detect a filter change, SSH policies

@@ -306,7 +315,7 @@ func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, erro
// nodes, we should send a policy change to ensure they get updated SSH policies.
// TODO(kradalby): detect this, or rebuild all SSH policies so we can determine
// this upstream.
if c.Empty() {
if c.IsEmpty() {
c = change.PolicyChange()
}
@@ -317,7 +326,7 @@ func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, erro
// UpdateUser modifies an existing user using the provided update function within a transaction.
// Returns the updated user, change set, and any error.
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) {
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.Change, error) {
user, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.User, error) {
user, err := hsdb.GetUserByID(tx, userID)
if err != nil {

@@ -337,13 +346,13 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
return user, nil
})
if err != nil {
return nil, change.EmptySet, err
return nil, change.Change{}, err
}
// Check if policy manager needs updating
c, err := s.updatePolicyManagerUsers()
if err != nil {
return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err)
return user, change.Change{}, fmt.Errorf("failed to update policy manager after user update: %w", err)
}
// TODO(kradalby): We might want to update nodestore with the user data

@@ -358,7 +367,7 @@ func (s *State) DeleteUser(userID types.UserID) error {
}
// RenameUser changes a user's name. The new name must be unique.
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.ChangeSet, error) {
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.Change, error) {
return s.UpdateUser(userID, func(user *types.User) error {
user.Name = newName
return nil

@@ -395,9 +404,9 @@ func (s *State) ListAllUsers() ([]types.User, error) {
// NodeStore and the database. It verifies the node still exists in NodeStore to prevent
// race conditions where a node might be deleted between UpdateNode returning and
// persistNodeToDB being called.
func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.ChangeSet, error) {
func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Change, error) {
if !node.Valid() {
return types.NodeView{}, change.EmptySet, fmt.Errorf("invalid node view provided")
return types.NodeView{}, change.Change{}, ErrInvalidNodeView
}
// Verify the node still exists in NodeStore before persisting to database.

@@ -411,7 +420,8 @@ func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Cha
Str("node.name", node.Hostname()).
Bool("is_ephemeral", node.IsEphemeral()).
Msg("Node no longer exists in NodeStore, skipping database persist to prevent race condition")
return types.NodeView{}, change.EmptySet, fmt.Errorf("node %d no longer exists in NodeStore, skipping database persist", node.ID())
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, node.ID())
}
nodePtr := node.AsStruct()
@@ -421,23 +431,23 @@ func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Cha
// See: https://github.com/juanfont/headscale/issues/2862
err := s.db.DB.Omit("expiry").Updates(nodePtr).Error
if err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err)
return types.NodeView{}, change.Change{}, fmt.Errorf("saving node: %w", err)
}
// Check if policy manager needs updating
c, err := s.updatePolicyManagerNodes()
if err != nil {
return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
return nodePtr.View(), change.Change{}, fmt.Errorf("failed to update policy manager after node save: %w", err)
}
if c.Empty() {
if c.IsEmpty() {
c = change.NodeAdded(node.ID())
}
return node, c, nil
}
func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) {
func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.Change, error) {
// Update NodeStore first
nodePtr := node.AsStruct()

@@ -449,12 +459,12 @@ func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet,
// DeleteNode permanently removes a node and cleans up associated resources.
// Returns whether policies changed and any error. This operation is irreversible.
func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) {
func (s *State) DeleteNode(node types.NodeView) (change.Change, error) {
s.nodeStore.DeleteNode(node.ID())
err := s.db.DeleteNode(node.AsStruct())
if err != nil {
return change.EmptySet, err
return change.Change{}, err
}
s.ipAlloc.FreeIPs(node.IPs())

@@ -464,10 +474,10 @@ func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) {
// Check if policy manager needs updating after node deletion
policyChange, err := s.updatePolicyManagerNodes()
if err != nil {
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
return change.Change{}, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
}
if !policyChange.Empty() {
if !policyChange.IsEmpty() {
c = policyChange
}

@@ -475,7 +485,7 @@ func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) {
}
// Connect marks a node as connected and updates its primary routes in the state.
func (s *State) Connect(id types.NodeID) []change.ChangeSet {
func (s *State) Connect(id types.NodeID) []change.Change {
// CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification
// This ensures that when the NodeCameOnline change is distributed and processed by other nodes,
// the NodeStore already reflects the correct online status for full map generation.
||||
@@ -488,7 +498,7 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := []change.ChangeSet{change.NodeOnline(node)}
|
||||
c := []change.Change{change.NodeOnlineFor(node)}
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected")
|
||||
|
||||
@@ -505,7 +515,7 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet {
|
||||
}
|
||||
|
||||
// Disconnect marks a node as disconnected and updates its primary routes in the state.
|
||||
func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) {
|
||||
func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
|
||||
now := time.Now()
|
||||
|
||||
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
@@ -527,14 +537,15 @@ func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) {
|
||||
// Log error but don't fail the disconnection - NodeStore is already updated
|
||||
// and we need to send change notifications to peers
|
||||
log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Failed to update last seen in database")
|
||||
c = change.EmptySet
|
||||
|
||||
c = change.Change{}
|
||||
}
|
||||
|
||||
// The node is disconnecting so make sure that none of the routes it
|
||||
// announced are served to any nodes.
|
||||
routeChange := s.primaryRoutes.SetRoutes(id)
|
||||
|
||||
cs := []change.ChangeSet{change.NodeOffline(node), c}
|
||||
cs := []change.Change{change.NodeOfflineFor(node), c}
|
||||
|
||||
// If we have a policy change or route change, return that as it's more comprehensive
|
||||
// Otherwise, return the NodeOffline change to ensure nodes are notified
|
||||
@@ -637,7 +648,7 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] {
|
||||
}
|
||||
|
||||
// SetNodeExpiry updates the expiration time for a node.
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) {
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.Change, error) {
|
||||
// Update NodeStore before database to ensure consistency. The NodeStore update is
|
||||
// blocking and will be the source of truth for the batcher. The database update must
|
||||
// make the exact same change. If the database update fails, the NodeStore change will
|
||||
@@ -649,7 +660,7 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
|
||||
}
|
||||
|
||||
return s.persistNodeToDB(n)
|
||||
@@ -658,16 +669,16 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node
|
||||
// SetNodeTags assigns tags to a node, making it a "tagged node".
|
||||
// Once a node is tagged, it cannot be un-tagged (only tags can be changed).
|
||||
// The UserID is preserved as "created by" information.
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) {
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.Change, error) {
|
||||
// CANNOT REMOVE ALL TAGS
|
||||
if len(tags) == 0 {
|
||||
return types.NodeView{}, change.EmptySet, types.ErrCannotRemoveAllTags
|
||||
return types.NodeView{}, change.Change{}, types.ErrCannotRemoveAllTags
|
||||
}
|
||||
|
||||
// Get node for validation
|
||||
existingNode, exists := s.nodeStore.GetNode(nodeID)
|
||||
if !exists {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID)
|
||||
}
|
||||
|
||||
// Validate tags: must have correct format and exist in policy
|
||||
@@ -685,7 +696,7 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView,
|
||||
}
|
||||
|
||||
if len(invalidTags) > 0 {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, invalidTags)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, invalidTags)
|
||||
}
|
||||
|
||||
slices.Sort(validatedTags)
|
||||
@@ -703,14 +714,14 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView,
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
|
||||
}
|
||||
|
||||
return s.persistNodeToDB(n)
|
||||
}
|
||||
|
||||
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) {
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.Change, error) {
|
||||
// TODO(kradalby): In principle we should call the AutoApprove logic here
|
||||
// because even if the CLI removes an auto-approved route, it will be added
|
||||
// back automatically.
|
||||
@@ -719,13 +730,13 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
|
||||
}
|
||||
|
||||
// Persist the node changes to the database
|
||||
nodeView, c, err := s.persistNodeToDB(n)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
return types.NodeView{}, change.Change{}, err
|
||||
}
|
||||
|
||||
// Update primary routes table based on SubnetRoutes (intersection of announced and approved).
|
||||
@@ -743,9 +754,9 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t
|
||||
}
|
||||
|
||||
// RenameNode changes the display name of a node.
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) {
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.Change, error) {
|
||||
if err := util.ValidateHostname(newName); err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
// Check name uniqueness against NodeStore
|
||||
@@ -753,7 +764,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView,
|
||||
for i := 0; i < allNodes.Len(); i++ {
|
||||
node := allNodes.At(i)
|
||||
if node.ID() != nodeID && node.AsStruct().GivenName == newName {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -765,7 +776,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView,
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
|
||||
}
|
||||
|
||||
return s.persistNodeToDB(n)
|
||||
@@ -810,12 +821,12 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
|
||||
|
||||
// ExpireExpiredNodes finds and processes expired nodes since the last check.
|
||||
// Returns next check time, state update with expired nodes, and whether any were found.
|
||||
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) {
|
||||
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Change, bool) {
|
||||
// Why capture start time: We need to ensure we don't miss nodes that expire
|
||||
// while this function is running by using a consistent timestamp for the next check
|
||||
started := time.Now()
|
||||
|
||||
var updates []change.ChangeSet
|
||||
var updates []change.Change
|
||||
|
||||
for _, node := range s.nodeStore.ListNodes().All() {
|
||||
if !node.Valid() {
|
||||
@@ -825,7 +836,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
|
||||
// Why check After(lastCheck): We only want to notify about nodes that
|
||||
// expired since the last check to avoid duplicate notifications
|
||||
if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) {
|
||||
updates = append(updates, change.KeyExpiry(node.ID(), node.Expiry().Get()))
|
||||
updates = append(updates, change.KeyExpiryFor(node.ID(), node.Expiry().Get()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -868,7 +879,7 @@ func (s *State) SetPolicy(pol []byte) (bool, error) {
|
||||
|
||||
// AutoApproveRoutes checks if a node's routes should be auto-approved.
|
||||
// AutoApproveRoutes checks if any routes should be auto-approved for a node and updates them.
|
||||
func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) {
|
||||
func (s *State) AutoApproveRoutes(nv types.NodeView) (change.Change, error) {
|
||||
approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes())
|
||||
if changed {
|
||||
log.Debug().
|
||||
@@ -889,7 +900,7 @@ func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) {
|
||||
Err(err).
|
||||
Msg("Failed to persist auto-approved routes")
|
||||
|
||||
return change.EmptySet, err
|
||||
return change.Change{}, err
|
||||
}
|
||||
|
||||
log.Info().Uint64("node.id", nv.ID().Uint64()).Str("node.name", nv.Hostname()).Strs("routes.approved", util.PrefixesToString(approved)).Msg("Routes approved")
|
||||
@@ -897,7 +908,7 @@ func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
return change.EmptySet, nil
|
||||
return change.Change{}, nil
|
||||
}
|
||||
|
||||
// GetPolicy retrieves the current policy from the database.
|
||||
@@ -911,14 +922,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
|
||||
}
|
||||
|
||||
// SetNodeRoutes sets the primary routes for a node.
|
||||
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet {
|
||||
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.Change {
|
||||
if s.primaryRoutes.SetRoutes(nodeID, routes...) {
|
||||
// Route changes affect packet filters for all nodes, so trigger a policy change
|
||||
// to ensure filters are regenerated across the entire network
|
||||
return change.PolicyChange()
|
||||
}
|
||||
|
||||
return change.EmptySet
|
||||
return change.Change{}
|
||||
}
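For orientation, a sketch of the intended call pattern; the notifier and routeSetter interfaces are stand-ins invented for this example, while SetNodeRoutes and Change.IsEmpty come from this change:

package example

import (
	"net/netip"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
)

// notifier is a stand-in for whatever fans changes out to connected nodes.
type notifier interface {
	Change(...change.Change)
}

// routeSetter is a stand-in for the State method above.
type routeSetter interface {
	SetNodeRoutes(id types.NodeID, routes ...netip.Prefix) change.Change
}

// applyRoutes forwards the resulting change only when the primary routes
// table actually changed; empty changes are dropped.
func applyRoutes(s routeSetter, n notifier, id types.NodeID, routes []netip.Prefix) {
	if c := s.SetNodeRoutes(id, routes...); !c.IsEmpty() {
		n.Change(c)
	}
}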
|
||||
|
||||
// GetNodePrimaryRoutes returns the primary routes for a node.
|
||||
@@ -1232,17 +1243,17 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
userID types.UserID,
|
||||
expiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (types.NodeView, change.ChangeSet, error) {
|
||||
) (types.NodeView, change.Change, error) {
|
||||
// Get the registration entry from cache
|
||||
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, hsdb.ErrNodeNotFoundRegistrationCache
|
||||
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
|
||||
}
|
||||
|
||||
// Get the user
|
||||
user, err := s.db.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("failed to find user: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we have a valid hostname from the registration cache entry
|
||||
@@ -1306,7 +1317,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID())
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID())
|
||||
}
|
||||
|
||||
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
@@ -1318,7 +1329,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
return types.NodeView{}, change.Change{}, err
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
@@ -1376,7 +1387,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}),
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
return types.NodeView{}, change.Change{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1397,8 +1408,8 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err)
|
||||
}
|
||||
|
||||
var c change.ChangeSet
|
||||
if !usersChange.Empty() || !nodesChange.Empty() {
|
||||
var c change.Change
|
||||
if !usersChange.IsEmpty() || !nodesChange.IsEmpty() {
|
||||
c = change.PolicyChange()
|
||||
} else {
|
||||
c = change.NodeAdded(finalNode.ID())
|
||||
@@ -1411,10 +1422,10 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
func (s *State) HandleNodeFromPreAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (types.NodeView, change.ChangeSet, error) {
|
||||
) (types.NodeView, change.Change, error) {
|
||||
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
return types.NodeView{}, change.Change{}, err
|
||||
}
|
||||
|
||||
// Check if node exists with same machine key before validating the key.
|
||||
@@ -1461,7 +1472,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
// New node or NodeKey rotation: require valid auth key.
|
||||
err = pak.Validate()
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
return types.NodeView{}, change.Change{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1535,7 +1546,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID())
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID())
|
||||
}
|
||||
|
||||
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
@@ -1555,7 +1566,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
@@ -1607,7 +1618,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}),
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("creating new node: %w", err)
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("creating new node: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1622,8 +1633,8 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err)
|
||||
}
|
||||
|
||||
var c change.ChangeSet
|
||||
if !usersChange.Empty() || !nodesChange.Empty() {
|
||||
var c change.Change
|
||||
if !usersChange.IsEmpty() || !nodesChange.IsEmpty() {
|
||||
c = change.PolicyChange()
|
||||
} else {
|
||||
c = change.NodeAdded(finalNode.ID())
|
||||
@@ -1638,17 +1649,17 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
// have the list already available so it could go much quicker. Alternatively
|
||||
// the policy manager could have a remove or add list for users.
|
||||
// updatePolicyManagerUsers refreshes the policy manager with current user data.
|
||||
func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
|
||||
func (s *State) updatePolicyManagerUsers() (change.Change, error) {
|
||||
users, err := s.ListAllUsers()
|
||||
if err != nil {
|
||||
return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err)
|
||||
return change.Change{}, fmt.Errorf("listing users for policy update: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected")
|
||||
|
||||
changed, err := s.polMan.SetUsers(users)
|
||||
if err != nil {
|
||||
return change.EmptySet, fmt.Errorf("updating policy manager users: %w", err)
|
||||
return change.Change{}, fmt.Errorf("updating policy manager users: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Caller().Bool("policy.changed", changed).Msg("Policy manager user update completed because SetUsers operation finished")
|
||||
@@ -1657,7 +1668,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
|
||||
return change.PolicyChange(), nil
|
||||
}
|
||||
|
||||
return change.EmptySet, nil
|
||||
return change.Change{}, nil
|
||||
}
|
||||
|
||||
// updatePolicyManagerNodes updates the policy manager with current nodes.
|
||||
@@ -1666,19 +1677,22 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
|
||||
// have the list already available so it could go much quicker. Alternatively
|
||||
// the policy manager could have a remove or add list for nodes.
|
||||
// updatePolicyManagerNodes refreshes the policy manager with current node data.
|
||||
func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) {
|
||||
func (s *State) updatePolicyManagerNodes() (change.Change, error) {
|
||||
nodes := s.ListNodes()
|
||||
|
||||
changed, err := s.polMan.SetNodes(nodes)
|
||||
if err != nil {
|
||||
return change.EmptySet, fmt.Errorf("updating policy manager nodes: %w", err)
|
||||
return change.Change{}, fmt.Errorf("updating policy manager nodes: %w", err)
|
||||
}
|
||||
|
||||
if changed {
|
||||
// Rebuild peer maps because policy-affecting node changes (tags, user, IPs)
|
||||
// affect ACL visibility. Without this, cached peer relationships use stale data.
|
||||
s.nodeStore.RebuildPeerMaps()
|
||||
return change.PolicyChange(), nil
|
||||
}
|
||||
|
||||
return change.EmptySet, nil
|
||||
return change.Change{}, nil
|
||||
}
|
||||
|
||||
// PingDB checks if the database connection is healthy.
|
||||
@@ -1692,14 +1706,16 @@ func (s *State) PingDB(ctx context.Context) error {
|
||||
// TODO(kradalby): This is kind of messy, maybe this is another +1
|
||||
// for an event bus. See example comments here.
|
||||
// autoApproveNodes automatically approves nodes based on policy rules.
|
||||
func (s *State) autoApproveNodes() ([]change.ChangeSet, error) {
|
||||
func (s *State) autoApproveNodes() ([]change.Change, error) {
|
||||
nodes := s.ListNodes()
|
||||
|
||||
// Approve routes concurrently; this should make it likely
// that the writes end up in the same batch in the NodeStore write.
|
||||
var errg errgroup.Group
|
||||
var cs []change.ChangeSet
|
||||
var mu sync.Mutex
|
||||
var (
|
||||
errg errgroup.Group
|
||||
cs []change.Change
|
||||
mu sync.Mutex
|
||||
)
|
||||
for _, nv := range nodes.All() {
|
||||
errg.Go(func() error {
|
||||
approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes())
|
||||
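The concurrent approval above uses a fan-out/collect shape; here is a generic, standalone sketch of that pattern (not headscale code, values invented):

package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

func main() {
	var (
		errg errgroup.Group
		mu   sync.Mutex
		out  []int
	)

	for i := 0; i < 5; i++ {
		i := i // pin the loop variable for the goroutine (needed before Go 1.22)
		errg.Go(func() error {
			// Collect results under the mutex so concurrent appends don't race.
			mu.Lock()
			defer mu.Unlock()
			out = append(out, i*i)

			return nil
		})
	}

	if err := errg.Wait(); err != nil {
		panic(err)
	}

	fmt.Println(out)
}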
@@ -1740,7 +1756,7 @@ func (s *State) autoApproveNodes() ([]change.ChangeSet, error) {
|
||||
// - node.PeerChangeFromMapRequest
|
||||
// - node.ApplyPeerChange
|
||||
// - logTracePeerChange in poll.go.
|
||||
func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.ChangeSet, error) {
|
||||
func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.Change, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Uint64("node.id", id.Uint64()).
|
||||
@@ -1853,7 +1869,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", id)
|
||||
return change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, id)
|
||||
}
|
||||
|
||||
if routeChange {
|
||||
@@ -1865,80 +1881,67 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
|
||||
// SetApprovedRoutes will update both database and PrimaryRoutes table
|
||||
_, c, err := s.SetApprovedRoutes(id, autoApprovedRoutes)
|
||||
if err != nil {
|
||||
return change.EmptySet, fmt.Errorf("persisting auto-approved routes: %w", err)
|
||||
return change.Change{}, fmt.Errorf("persisting auto-approved routes: %w", err)
|
||||
}
|
||||
|
||||
// If SetApprovedRoutes resulted in a policy change, return it
|
||||
if !c.Empty() {
|
||||
if !c.IsEmpty() {
|
||||
return c, nil
|
||||
}
|
||||
} // Continue with the rest of the processing using the updated node
|
||||
|
||||
nodeRouteChange := change.EmptySet
|
||||
|
||||
// Handle route changes after NodeStore update
|
||||
// We need to update node routes if either:
|
||||
// 1. The approved routes changed (routeChange is true), OR
|
||||
// 2. The announced routes changed (even if approved routes stayed the same)
|
||||
// This is because SubnetRoutes is the intersection of announced AND approved routes.
|
||||
needsRouteUpdate := false
|
||||
var routesChangedButNotApproved bool
|
||||
if hostinfoChanged && needsRouteApproval && !routeChange {
|
||||
if hi := req.Hostinfo; hi != nil {
|
||||
routesChangedButNotApproved = true
|
||||
}
|
||||
}
|
||||
|
||||
if routesChangedButNotApproved {
|
||||
needsRouteUpdate = true
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", id.Uint64()).
|
||||
Msg("updating routes because announced routes changed but approved routes did not")
|
||||
}
|
||||
|
||||
if needsRouteUpdate {
|
||||
// SetNodeRoutes sets the active/distributed routes, so we must use AllApprovedRoutes()
|
||||
// which returns only the intersection of announced AND approved routes.
|
||||
// Using AnnouncedRoutes() would bypass the security model and auto-approve everything.
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", id.Uint64()).
|
||||
Strs("announcedRoutes", util.PrefixesToString(updatedNode.AnnouncedRoutes())).
|
||||
Strs("approvedRoutes", util.PrefixesToString(updatedNode.ApprovedRoutes().AsSlice())).
|
||||
Strs("allApprovedRoutes", util.PrefixesToString(updatedNode.AllApprovedRoutes())).
|
||||
Msg("updating node routes for distribution")
|
||||
nodeRouteChange = s.SetNodeRoutes(id, updatedNode.AllApprovedRoutes()...)
|
||||
}
|
||||
// Handle route changes after NodeStore update.
|
||||
// Update routes if announced routes changed (even if approved routes stayed the same)
|
||||
// because SubnetRoutes is the intersection of announced AND approved routes.
|
||||
nodeRouteChange := s.maybeUpdateNodeRoutes(id, updatedNode, hostinfoChanged, needsRouteApproval, routeChange, req.Hostinfo)
|
||||
|
||||
_, policyChange, err := s.persistNodeToDB(updatedNode)
|
||||
if err != nil {
|
||||
return change.EmptySet, fmt.Errorf("saving to database: %w", err)
|
||||
return change.Change{}, fmt.Errorf("saving to database: %w", err)
|
||||
}
|
||||
|
||||
if policyChange.IsFull() {
|
||||
return policyChange, nil
|
||||
}
|
||||
if !nodeRouteChange.Empty() {
|
||||
|
||||
if !nodeRouteChange.IsEmpty() {
|
||||
return nodeRouteChange, nil
|
||||
}
|
||||
|
||||
// Determine the most specific change type based on what actually changed.
|
||||
// This allows us to send lightweight patch updates instead of full map responses.
|
||||
return buildMapRequestChangeResponse(id, updatedNode, hostinfoChanged, endpointChanged, derpChanged)
|
||||
}
|
||||
|
||||
// buildMapRequestChangeResponse determines the appropriate response type for a MapRequest update.
|
||||
// Hostinfo changes require a full update, while endpoint/DERP changes can use lightweight patches.
|
||||
func buildMapRequestChangeResponse(
|
||||
id types.NodeID,
|
||||
node types.NodeView,
|
||||
hostinfoChanged, endpointChanged, derpChanged bool,
|
||||
) (change.Change, error) {
|
||||
// Hostinfo changes require NodeAdded (full update) as they may affect many fields.
|
||||
if hostinfoChanged {
|
||||
return change.NodeAdded(id), nil
|
||||
}
|
||||
|
||||
// Return specific change types for endpoint and/or DERP updates.
|
||||
// The batcher will query NodeStore for current state and include both in PeerChange if both changed.
|
||||
// Prioritize endpoint changes as they're more common and important for connectivity.
|
||||
if endpointChanged {
|
||||
return change.EndpointUpdate(id), nil
|
||||
}
|
||||
if endpointChanged || derpChanged {
|
||||
patch := &tailcfg.PeerChange{NodeID: id.NodeID()}
|
||||
|
||||
if derpChanged {
|
||||
return change.DERPUpdate(id), nil
|
||||
if endpointChanged {
|
||||
patch.Endpoints = node.Endpoints().AsSlice()
|
||||
}
|
||||
|
||||
if derpChanged {
|
||||
if hi := node.Hostinfo(); hi.Valid() {
|
||||
if ni := hi.NetInfo(); ni.Valid() {
|
||||
patch.DERPRegion = ni.PreferredDERP()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return change.EndpointOrDERPUpdate(id, patch), nil
|
||||
}
|
||||
|
||||
return change.NodeAdded(id), nil
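For reference, a small sketch of the patch such an update carries, using the constructors introduced later in this commit; the node ID and DERP region values are invented:

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
	"tailscale.com/tailcfg"
)

func main() {
	id := types.NodeID(42)

	// Pretend only the preferred DERP region moved for node 42.
	patch := &tailcfg.PeerChange{
		NodeID:     id.NodeID(),
		DERPRegion: 2,
	}

	c := change.EndpointOrDERPUpdate(id, patch)

	// OriginNode lets the mapper avoid echoing the patch back to the sender.
	fmt.Println(c.Type(), c.Reason, c.OriginNode) // patch endpoint/DERP update 42
}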
|
||||
@@ -1983,3 +1986,34 @@ func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
|
||||
peerChange.LastSeen == nil &&
|
||||
peerChange.KeyExpiry == nil
|
||||
}
|
||||
|
||||
// maybeUpdateNodeRoutes updates node routes if announced routes changed but approved routes didn't.
|
||||
// This is needed because SubnetRoutes is the intersection of announced AND approved routes.
|
||||
func (s *State) maybeUpdateNodeRoutes(
|
||||
id types.NodeID,
|
||||
node types.NodeView,
|
||||
hostinfoChanged, needsRouteApproval, routeChange bool,
|
||||
hostinfo *tailcfg.Hostinfo,
|
||||
) change.Change {
|
||||
// Only update if announced routes changed without approval change
|
||||
if !hostinfoChanged || !needsRouteApproval || routeChange || hostinfo == nil {
|
||||
return change.Change{}
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", id.Uint64()).
|
||||
Msg("updating routes because announced routes changed but approved routes did not")
|
||||
|
||||
// SetNodeRoutes sets the active/distributed routes using AllApprovedRoutes()
|
||||
// which returns only the intersection of announced AND approved routes.
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", id.Uint64()).
|
||||
Strs("announcedRoutes", util.PrefixesToString(node.AnnouncedRoutes())).
|
||||
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
|
||||
Strs("allApprovedRoutes", util.PrefixesToString(node.AllApprovedRoutes())).
|
||||
Msg("updating node routes for distribution")
|
||||
|
||||
return s.SetNodeRoutes(id, node.AllApprovedRoutes()...)
|
||||
}
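The announced-AND-approved intersection can be pictured with a tiny standalone sketch; intersect is an illustrative helper, not the actual AllApprovedRoutes implementation:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// intersect keeps only the announced prefixes that are also approved.
func intersect(announced, approved []netip.Prefix) []netip.Prefix {
	var out []netip.Prefix
	for _, p := range announced {
		if slices.Contains(approved, p) {
			out = append(out, p)
		}
	}

	return out
}

func main() {
	announced := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("192.168.1.0/24"),
	}
	approved := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
	}

	// Only 10.0.0.0/24 is both announced and approved, so only it is distributed.
	fmt.Println(intersect(announced, approved))
}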
|
||||
|
||||
@@ -1,241 +1,445 @@
|
||||
//go:generate go tool stringer -type=Change
|
||||
package change
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
type (
|
||||
NodeID = types.NodeID
|
||||
UserID = types.UserID
|
||||
)
|
||||
// Change declares what should be included in a MapResponse.
|
||||
// The mapper uses this to build the response without guessing.
|
||||
type Change struct {
|
||||
// Reason is a human-readable description for logging/debugging.
|
||||
Reason string
|
||||
|
||||
type Change int
|
||||
// TargetNode, if set, means this response should only be sent to this node.
|
||||
TargetNode types.NodeID
|
||||
|
||||
const (
|
||||
ChangeUnknown Change = 0
|
||||
// OriginNode is the node that triggered this change.
|
||||
// Used for self-update detection and filtering.
|
||||
OriginNode types.NodeID
|
||||
|
||||
// Deprecated: Use specific change instead
|
||||
// Full is a legacy change to ensure places where we
|
||||
// have not yet determined the specific update, can send.
|
||||
Full Change = 9
|
||||
// Content flags - what to include in the MapResponse.
|
||||
IncludeSelf bool
|
||||
IncludeDERPMap bool
|
||||
IncludeDNS bool
|
||||
IncludeDomain bool
|
||||
IncludePolicy bool // PacketFilters and SSHPolicy - always sent together
|
||||
|
||||
// Server changes.
|
||||
Policy Change = 11
|
||||
DERP Change = 12
|
||||
ExtraRecords Change = 13
|
||||
// Peer changes.
|
||||
PeersChanged []types.NodeID
|
||||
PeersRemoved []types.NodeID
|
||||
PeerPatches []*tailcfg.PeerChange
|
||||
SendAllPeers bool
|
||||
|
||||
// Node changes.
|
||||
NodeCameOnline Change = 21
|
||||
NodeWentOffline Change = 22
|
||||
NodeRemove Change = 23
|
||||
NodeKeyExpiry Change = 24
|
||||
NodeNewOrUpdate Change = 25
|
||||
NodeEndpoint Change = 26
|
||||
NodeDERP Change = 27
|
||||
// RequiresRuntimePeerComputation indicates that peer visibility
|
||||
// must be computed at runtime per-node. Used for policy changes
|
||||
// where each node may have different peer visibility.
|
||||
RequiresRuntimePeerComputation bool
|
||||
}
|
||||
|
||||
// User changes.
|
||||
UserNewOrUpdate Change = 51
|
||||
UserRemove Change = 52
|
||||
)
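To make the flag model concrete, a short sketch that composes a Change by hand and inspects it with the helpers defined later in this file; the field values are invented:

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
)

func main() {
	// A change asking the mapper to resend the DERP map, the DNS config and a
	// partial peer list to every connected node.
	c := change.Change{
		Reason:         "example: derp and dns refresh",
		IncludeDERPMap: true,
		IncludeDNS:     true,
		PeersChanged:   []types.NodeID{1, 2},
	}

	fmt.Println(c.IsEmpty()) // false: there is something to send
	fmt.Println(c.IsFull())  // false: not every section is included
	fmt.Println(c.Type())    // "peers": changed peers dominate the categorization
}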
|
||||
// boolFieldNames returns all boolean field names for exhaustive testing.
|
||||
// When adding a new boolean field to Change, add it here.
|
||||
// Tests use reflection to verify this matches the struct.
|
||||
func (r Change) boolFieldNames() []string {
|
||||
return []string{
|
||||
"IncludeSelf",
|
||||
"IncludeDERPMap",
|
||||
"IncludeDNS",
|
||||
"IncludeDomain",
|
||||
"IncludePolicy",
|
||||
"SendAllPeers",
|
||||
"RequiresRuntimePeerComputation",
|
||||
}
|
||||
}
|
||||
|
||||
// AlsoSelf reports whether this change should also be sent to the node itself.
|
||||
func (c Change) AlsoSelf() bool {
|
||||
switch c {
|
||||
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
|
||||
return true
|
||||
func (r Change) Merge(other Change) Change {
|
||||
merged := r
|
||||
|
||||
merged.IncludeSelf = r.IncludeSelf || other.IncludeSelf
|
||||
merged.IncludeDERPMap = r.IncludeDERPMap || other.IncludeDERPMap
|
||||
merged.IncludeDNS = r.IncludeDNS || other.IncludeDNS
|
||||
merged.IncludeDomain = r.IncludeDomain || other.IncludeDomain
|
||||
merged.IncludePolicy = r.IncludePolicy || other.IncludePolicy
|
||||
merged.SendAllPeers = r.SendAllPeers || other.SendAllPeers
|
||||
merged.RequiresRuntimePeerComputation = r.RequiresRuntimePeerComputation || other.RequiresRuntimePeerComputation
|
||||
|
||||
merged.PeersChanged = uniqueNodeIDs(append(r.PeersChanged, other.PeersChanged...))
|
||||
merged.PeersRemoved = uniqueNodeIDs(append(r.PeersRemoved, other.PeersRemoved...))
|
||||
merged.PeerPatches = append(r.PeerPatches, other.PeerPatches...)
|
||||
|
||||
if r.Reason != "" && other.Reason != "" && r.Reason != other.Reason {
|
||||
merged.Reason = r.Reason + "; " + other.Reason
|
||||
} else if other.Reason != "" {
|
||||
merged.Reason = other.Reason
|
||||
}
|
||||
|
||||
return false
|
||||
return merged
|
||||
}
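A quick illustration of Merge semantics with invented values:

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
)

func main() {
	a := change.Change{Reason: "route change", PeersChanged: []types.NodeID{3, 1}}
	b := change.Change{Reason: "tag change", PeersChanged: []types.NodeID{2, 1}, IncludePolicy: true}

	m := a.Merge(b)

	fmt.Println(m.Reason)        // "route change; tag change"
	fmt.Println(m.PeersChanged)  // [1 2 3]: deduplicated and sorted
	fmt.Println(m.IncludePolicy) // true: boolean flags OR together
}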
|
||||
|
||||
type ChangeSet struct {
|
||||
Change Change
|
||||
|
||||
// SelfUpdateOnly indicates that this change should only be sent
|
||||
// to the node itself, and not to other nodes.
|
||||
// This is used for changes that are not relevant to other nodes.
|
||||
// NodeID must be set if this is true.
|
||||
SelfUpdateOnly bool
|
||||
|
||||
// NodeID if set, is the ID of the node that is being changed.
|
||||
// It must be set if this is a node change.
|
||||
NodeID types.NodeID
|
||||
|
||||
// UserID if set, is the ID of the user that is being changed.
|
||||
// It must be set if this is a user change.
|
||||
UserID types.UserID
|
||||
|
||||
// IsSubnetRouter indicates whether the node is a subnet router.
|
||||
IsSubnetRouter bool
|
||||
|
||||
// NodeExpiry is set if the change is NodeKeyExpiry.
|
||||
NodeExpiry *time.Time
|
||||
}
|
||||
|
||||
func (c *ChangeSet) Validate() error {
|
||||
if c.Change >= NodeCameOnline || c.Change <= NodeNewOrUpdate {
|
||||
if c.NodeID == 0 {
|
||||
return errors.New("ChangeSet.NodeID must be set for node updates")
|
||||
}
|
||||
func (r Change) IsEmpty() bool {
|
||||
if r.IncludeSelf || r.IncludeDERPMap || r.IncludeDNS ||
|
||||
r.IncludeDomain || r.IncludePolicy || r.SendAllPeers {
|
||||
return false
|
||||
}
|
||||
|
||||
if c.Change >= UserNewOrUpdate || c.Change <= UserRemove {
|
||||
if c.UserID == 0 {
|
||||
return errors.New("ChangeSet.UserID must be set for user updates")
|
||||
}
|
||||
if r.RequiresRuntimePeerComputation {
|
||||
return false
|
||||
}
|
||||
|
||||
return nil
|
||||
return len(r.PeersChanged) == 0 &&
|
||||
len(r.PeersRemoved) == 0 &&
|
||||
len(r.PeerPatches) == 0
|
||||
}
|
||||
|
||||
// Empty reports whether the ChangeSet is empty, meaning it does not
|
||||
// represent any change.
|
||||
func (c ChangeSet) Empty() bool {
|
||||
return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
|
||||
func (r Change) IsSelfOnly() bool {
|
||||
if r.TargetNode == 0 || !r.IncludeSelf {
|
||||
return false
|
||||
}
|
||||
|
||||
if r.SendAllPeers || len(r.PeersChanged) > 0 || len(r.PeersRemoved) > 0 || len(r.PeerPatches) > 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsFull reports whether the ChangeSet represents a full update.
|
||||
func (c ChangeSet) IsFull() bool {
|
||||
return c.Change == Full || c.Change == Policy
|
||||
// IsTargetedToNode returns true if this response should only be sent to TargetNode.
|
||||
func (r Change) IsTargetedToNode() bool {
|
||||
return r.TargetNode != 0
|
||||
}
|
||||
|
||||
func HasFull(cs []ChangeSet) bool {
|
||||
for _, c := range cs {
|
||||
if c.IsFull() {
|
||||
// IsFull reports whether this is a full update response.
|
||||
func (r Change) IsFull() bool {
|
||||
return r.SendAllPeers && r.IncludeSelf && r.IncludeDERPMap &&
|
||||
r.IncludeDNS && r.IncludeDomain && r.IncludePolicy
|
||||
}
|
||||
|
||||
// Type returns a categorized type string for metrics.
|
||||
// This provides a bounded set of values suitable for Prometheus labels,
|
||||
// unlike Reason which is free-form text for logging.
|
||||
func (r Change) Type() string {
|
||||
if r.IsFull() {
|
||||
return "full"
|
||||
}
|
||||
|
||||
if r.IsSelfOnly() {
|
||||
return "self"
|
||||
}
|
||||
|
||||
if r.RequiresRuntimePeerComputation {
|
||||
return "policy"
|
||||
}
|
||||
|
||||
if len(r.PeerPatches) > 0 && len(r.PeersChanged) == 0 && len(r.PeersRemoved) == 0 && !r.SendAllPeers {
|
||||
return "patch"
|
||||
}
|
||||
|
||||
if len(r.PeersChanged) > 0 || len(r.PeersRemoved) > 0 || r.SendAllPeers {
|
||||
return "peers"
|
||||
}
|
||||
|
||||
if r.IncludeDERPMap || r.IncludeDNS || r.IncludeDomain || r.IncludePolicy {
|
||||
return "config"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
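Since Type() is intended for bounded metric labels, a sketch of how a caller might wire it to a Prometheus counter; the metric name and the record helper are hypothetical:

package main

import (
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"

	"github.com/juanfont/headscale/hscontrol/types/change"
)

// changesProcessed is a hypothetical counter keyed by the bounded Type() value.
var changesProcessed = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "example_changes_processed_total",
	Help: "Changes processed, by category.",
}, []string{"type"})

// record keeps the free-form Reason out of the label set and uses the bounded
// Type() string instead, keeping label cardinality under control.
func record(c change.Change) {
	changesProcessed.WithLabelValues(c.Type()).Inc()
}

func main() {
	record(change.DERPMap()) // increments the "config" bucket
}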
|
||||
|
||||
// ShouldSendToNode determines if this response should be sent to nodeID.
|
||||
// It handles self-only targeting and filtering out self-updates for non-origin nodes.
|
||||
func (r Change) ShouldSendToNode(nodeID types.NodeID) bool {
|
||||
// If targeted to a specific node, only send to that node
|
||||
if r.TargetNode != 0 {
|
||||
return r.TargetNode == nodeID
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// HasFull returns true if any response in the slice is a full update.
|
||||
func HasFull(rs []Change) bool {
|
||||
for _, r := range rs {
|
||||
if r.IsFull() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func SplitAllAndSelf(cs []ChangeSet) (all []ChangeSet, self []ChangeSet) {
|
||||
for _, c := range cs {
|
||||
if c.SelfUpdateOnly {
|
||||
self = append(self, c)
|
||||
// SplitTargetedAndBroadcast separates responses into targeted (to specific node) and broadcast.
|
||||
func SplitTargetedAndBroadcast(rs []Change) ([]Change, []Change) {
|
||||
var broadcast, targeted []Change
|
||||
|
||||
for _, r := range rs {
|
||||
if r.IsTargetedToNode() {
|
||||
targeted = append(targeted, r)
|
||||
} else {
|
||||
all = append(all, c)
|
||||
broadcast = append(broadcast, r)
|
||||
}
|
||||
}
|
||||
return all, self
|
||||
|
||||
return broadcast, targeted
|
||||
}
|
||||
|
||||
func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) {
|
||||
for _, c := range cs {
|
||||
if c.NodeID != id || c.Change.AlsoSelf() {
|
||||
ret = append(ret, c)
|
||||
// FilterForNode returns responses that should be sent to the given node.
|
||||
func FilterForNode(nodeID types.NodeID, rs []Change) []Change {
|
||||
var result []Change
|
||||
|
||||
for _, r := range rs {
|
||||
if r.ShouldSendToNode(nodeID) {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
|
||||
return result
|
||||
}
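A short sketch of filtering a batch for individual nodes with FilterForNode; the node IDs are invented:

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
)

func main() {
	batch := []change.Change{
		change.DERPMap(),       // broadcast: no TargetNode set
		change.SelfUpdate(7),   // targeted: only node 7 should receive it
		change.PeersRemoved(3), // broadcast
	}

	for _, nodeID := range []types.NodeID{7, 8} {
		forNode := change.FilterForNode(nodeID, batch)
		fmt.Println(nodeID, len(forNode)) // node 7 gets 3 changes, node 8 gets 2
	}
}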
|
||||
|
||||
// IsSelfUpdate reports whether this ChangeSet represents an update to the given node itself.
|
||||
func (c ChangeSet) IsSelfUpdate(nodeID types.NodeID) bool {
|
||||
return c.NodeID == nodeID
|
||||
}
|
||||
|
||||
func (c ChangeSet) AlsoSelf() bool {
|
||||
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
|
||||
// so we consider it as a change that should be sent to all nodes.
|
||||
if c.NodeID == 0 {
|
||||
return true
|
||||
func uniqueNodeIDs(ids []types.NodeID) []types.NodeID {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.Change.AlsoSelf() || c.SelfUpdateOnly
|
||||
|
||||
slices.Sort(ids)
|
||||
|
||||
return slices.Compact(ids)
|
||||
}
|
||||
|
||||
var (
|
||||
EmptySet = ChangeSet{Change: ChangeUnknown}
|
||||
FullSet = ChangeSet{Change: Full}
|
||||
DERPSet = ChangeSet{Change: DERP}
|
||||
PolicySet = ChangeSet{Change: Policy}
|
||||
ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
|
||||
)
|
||||
// Constructor functions
|
||||
|
||||
func FullSelf(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: Full,
|
||||
SelfUpdateOnly: true,
|
||||
NodeID: id,
|
||||
func FullUpdate() Change {
|
||||
return Change{
|
||||
Reason: "full update",
|
||||
IncludeSelf: true,
|
||||
IncludeDERPMap: true,
|
||||
IncludeDNS: true,
|
||||
IncludeDomain: true,
|
||||
IncludePolicy: true,
|
||||
SendAllPeers: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeAdded(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeNewOrUpdate,
|
||||
NodeID: id,
|
||||
// FullSelf returns a full update targeted at a specific node.
|
||||
func FullSelf(nodeID types.NodeID) Change {
|
||||
return Change{
|
||||
Reason: "full self update",
|
||||
TargetNode: nodeID,
|
||||
IncludeSelf: true,
|
||||
IncludeDERPMap: true,
|
||||
IncludeDNS: true,
|
||||
IncludeDomain: true,
|
||||
IncludePolicy: true,
|
||||
SendAllPeers: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeRemoved(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeRemove,
|
||||
NodeID: id,
|
||||
func SelfUpdate(nodeID types.NodeID) Change {
|
||||
return Change{
|
||||
Reason: "self update",
|
||||
TargetNode: nodeID,
|
||||
IncludeSelf: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeOnline(node types.NodeView) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeCameOnline,
|
||||
NodeID: node.ID(),
|
||||
IsSubnetRouter: node.IsSubnetRouter(),
|
||||
func PolicyOnly() Change {
|
||||
return Change{
|
||||
Reason: "policy update",
|
||||
IncludePolicy: true,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeOffline(node types.NodeView) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeWentOffline,
|
||||
NodeID: node.ID(),
|
||||
IsSubnetRouter: node.IsSubnetRouter(),
|
||||
func PolicyAndPeers(changedPeers ...types.NodeID) Change {
|
||||
return Change{
|
||||
Reason: "policy and peers update",
|
||||
IncludePolicy: true,
|
||||
PeersChanged: changedPeers,
|
||||
}
|
||||
}
|
||||
|
||||
func KeyExpiry(id types.NodeID, expiry time.Time) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeKeyExpiry,
|
||||
NodeID: id,
|
||||
NodeExpiry: &expiry,
|
||||
func VisibilityChange(reason string, added, removed []types.NodeID) Change {
|
||||
return Change{
|
||||
Reason: reason,
|
||||
IncludePolicy: true,
|
||||
PeersChanged: added,
|
||||
PeersRemoved: removed,
|
||||
}
|
||||
}
|
||||
|
||||
func EndpointUpdate(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeEndpoint,
|
||||
NodeID: id,
|
||||
func PeersChanged(reason string, peerIDs ...types.NodeID) Change {
|
||||
return Change{
|
||||
Reason: reason,
|
||||
PeersChanged: peerIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func DERPUpdate(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeDERP,
|
||||
NodeID: id,
|
||||
func PeersRemoved(peerIDs ...types.NodeID) Change {
|
||||
return Change{
|
||||
Reason: "peers removed",
|
||||
PeersRemoved: peerIDs,
|
||||
}
|
||||
}
|
||||
|
||||
func UserAdded(id types.UserID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: UserNewOrUpdate,
|
||||
UserID: id,
|
||||
func PeerPatched(reason string, patches ...*tailcfg.PeerChange) Change {
|
||||
return Change{
|
||||
Reason: reason,
|
||||
PeerPatches: patches,
|
||||
}
|
||||
}
|
||||
|
||||
func UserRemoved(id types.UserID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: UserRemove,
|
||||
UserID: id,
|
||||
func DERPMap() Change {
|
||||
return Change{
|
||||
Reason: "DERP map update",
|
||||
IncludeDERPMap: true,
|
||||
}
|
||||
}
|
||||
|
||||
func PolicyChange() ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: Policy,
|
||||
// PolicyChange creates a response for policy changes.
|
||||
// Policy changes require runtime peer visibility computation.
|
||||
func PolicyChange() Change {
|
||||
return Change{
|
||||
Reason: "policy change",
|
||||
IncludePolicy: true,
|
||||
RequiresRuntimePeerComputation: true,
|
||||
}
|
||||
}
|
||||
|
||||
func DERPChange() ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: DERP,
|
||||
// DNSConfig creates a response for DNS configuration updates.
|
||||
func DNSConfig() Change {
|
||||
return Change{
|
||||
Reason: "DNS config update",
|
||||
IncludeDNS: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NodeOnline creates a patch response for a node coming online.
|
||||
func NodeOnline(nodeID types.NodeID) Change {
|
||||
return Change{
|
||||
Reason: "node online",
|
||||
PeerPatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: nodeID.NodeID(),
|
||||
Online: ptrTo(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NodeOffline creates a patch response for a node going offline.
|
||||
func NodeOffline(nodeID types.NodeID) Change {
|
||||
return Change{
|
||||
Reason: "node offline",
|
||||
PeerPatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: nodeID.NodeID(),
|
||||
Online: ptrTo(false),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// KeyExpiry creates a patch response for a node's key expiry change.
|
||||
func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change {
|
||||
return Change{
|
||||
Reason: "key expiry",
|
||||
PeerPatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: nodeID.NodeID(),
|
||||
KeyExpiry: expiry,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ptrTo returns a pointer to the given value.
|
||||
func ptrTo[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
// High-level change constructors
|
||||
|
||||
// NodeAdded returns a Change for when a node is added or updated.
|
||||
// The OriginNode field enables self-update detection by the mapper.
|
||||
func NodeAdded(id types.NodeID) Change {
|
||||
c := PeersChanged("node added", id)
|
||||
c.OriginNode = id
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// NodeRemoved returns a Change for when a node is removed.
|
||||
func NodeRemoved(id types.NodeID) Change {
|
||||
return PeersRemoved(id)
|
||||
}
|
||||
|
||||
// NodeOnlineFor returns a Change for when a node comes online.
|
||||
// If the node is a subnet router, a full update is sent instead of a patch.
|
||||
func NodeOnlineFor(node types.NodeView) Change {
|
||||
if node.IsSubnetRouter() {
|
||||
c := FullUpdate()
|
||||
c.Reason = "subnet router online"
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
return NodeOnline(node.ID())
|
||||
}
|
||||
|
||||
// NodeOfflineFor returns a Change for when a node goes offline.
|
||||
// If the node is a subnet router, a full update is sent instead of a patch.
|
||||
func NodeOfflineFor(node types.NodeView) Change {
|
||||
if node.IsSubnetRouter() {
|
||||
c := FullUpdate()
|
||||
c.Reason = "subnet router offline"
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
return NodeOffline(node.ID())
|
||||
}
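To illustrate the subnet-router special case without constructing a NodeView, compare the patch an ordinary node produces with the full update a router falls back to (node ID invented):

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
)

func main() {
	id := types.NodeID(5)

	// Ordinary node: a lightweight online patch for its peers.
	patch := change.NodeOnline(id)
	fmt.Println(patch.Type()) // "patch"

	// Subnet router: NodeOnlineFor/NodeOfflineFor fall back to a full update so
	// peers re-learn primary routes; this mirrors what those constructors build.
	full := change.FullUpdate()
	full.Reason = "subnet router online"
	fmt.Println(full.Type()) // "full"
}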
|
||||
|
||||
// KeyExpiryFor returns a Change for when a node's key expiry changes.
|
||||
// The OriginNode field enables self-update detection by the mapper.
|
||||
func KeyExpiryFor(id types.NodeID, expiry time.Time) Change {
|
||||
c := KeyExpiry(id, &expiry)
|
||||
c.OriginNode = id
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// EndpointOrDERPUpdate returns a Change for when a node's endpoints or DERP region changes.
|
||||
// The OriginNode field enables self-update detection by the mapper.
|
||||
func EndpointOrDERPUpdate(id types.NodeID, patch *tailcfg.PeerChange) Change {
|
||||
c := PeerPatched("endpoint/DERP update", patch)
|
||||
c.OriginNode = id
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// UserAdded returns a Change for when a user is added or updated.
|
||||
// A full update is sent to refresh user profiles on all nodes.
|
||||
func UserAdded() Change {
|
||||
c := FullUpdate()
|
||||
c.Reason = "user added"
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// UserRemoved returns a Change for when a user is removed.
|
||||
// A full update is sent to refresh user profiles on all nodes.
|
||||
func UserRemoved() Change {
|
||||
c := FullUpdate()
|
||||
c.Reason = "user removed"
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// ExtraRecords returns a Change for when DNS extra records change.
|
||||
func ExtraRecords() Change {
|
||||
c := DNSConfig()
|
||||
c.Reason = "extra records update"
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
// Code generated by "stringer -type=Change"; DO NOT EDIT.
|
||||
|
||||
package change
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[ChangeUnknown-0]
|
||||
_ = x[Full-9]
|
||||
_ = x[Policy-11]
|
||||
_ = x[DERP-12]
|
||||
_ = x[ExtraRecords-13]
|
||||
_ = x[NodeCameOnline-21]
|
||||
_ = x[NodeWentOffline-22]
|
||||
_ = x[NodeRemove-23]
|
||||
_ = x[NodeKeyExpiry-24]
|
||||
_ = x[NodeNewOrUpdate-25]
|
||||
_ = x[NodeEndpoint-26]
|
||||
_ = x[NodeDERP-27]
|
||||
_ = x[UserNewOrUpdate-51]
|
||||
_ = x[UserRemove-52]
|
||||
}
|
||||
|
||||
const (
|
||||
_Change_name_0 = "ChangeUnknown"
|
||||
_Change_name_1 = "Full"
|
||||
_Change_name_2 = "PolicyDERPExtraRecords"
|
||||
_Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdateNodeEndpointNodeDERP"
|
||||
_Change_name_4 = "UserNewOrUpdateUserRemove"
|
||||
)
|
||||
|
||||
var (
|
||||
_Change_index_2 = [...]uint8{0, 6, 10, 22}
|
||||
_Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67, 79, 87}
|
||||
_Change_index_4 = [...]uint8{0, 15, 25}
|
||||
)
|
||||
|
||||
func (i Change) String() string {
|
||||
switch {
|
||||
case i == 0:
|
||||
return _Change_name_0
|
||||
case i == 9:
|
||||
return _Change_name_1
|
||||
case 11 <= i && i <= 13:
|
||||
i -= 11
|
||||
return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]]
|
||||
case 21 <= i && i <= 27:
|
||||
i -= 21
|
||||
return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]]
|
||||
case 51 <= i && i <= 52:
|
||||
i -= 51
|
||||
return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]]
|
||||
default:
|
||||
return "Change(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
||||
449 hscontrol/types/change/change_test.go (new file)
@@ -0,0 +1,449 @@
|
||||
package change
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestChange_FieldSync(t *testing.T) {
|
||||
r := Change{}
|
||||
fieldNames := r.boolFieldNames()
|
||||
|
||||
typ := reflect.TypeFor[Change]()
|
||||
boolCount := 0
|
||||
|
||||
for i := range typ.NumField() {
|
||||
if typ.Field(i).Type.Kind() == reflect.Bool {
|
||||
boolCount++
|
||||
}
|
||||
}
|
||||
|
||||
if len(fieldNames) != boolCount {
|
||||
t.Fatalf("boolFieldNames() returns %d fields but struct has %d bool fields; "+
|
||||
"update boolFieldNames() when adding new bool fields", len(fieldNames), boolCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChange_IsEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response Change
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "zero value is empty",
|
||||
response: Change{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "only reason is still empty",
|
||||
response: Change{Reason: "test"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "IncludeSelf not empty",
|
||||
response: Change{IncludeSelf: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "IncludeDERPMap not empty",
|
||||
response: Change{IncludeDERPMap: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "IncludeDNS not empty",
|
||||
response: Change{IncludeDNS: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "IncludeDomain not empty",
|
||||
response: Change{IncludeDomain: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "IncludePolicy not empty",
|
||||
response: Change{IncludePolicy: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "SendAllPeers not empty",
|
||||
response: Change{SendAllPeers: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "PeersChanged not empty",
|
||||
response: Change{PeersChanged: []types.NodeID{1}},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "PeersRemoved not empty",
|
||||
response: Change{PeersRemoved: []types.NodeID{1}},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "PeerPatches not empty",
|
||||
response: Change{PeerPatches: []*tailcfg.PeerChange{{}}},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.response.IsEmpty()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChange_IsSelfOnly(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response Change
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty is not self only",
|
||||
response: Change{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "IncludeSelf without TargetNode is not self only",
|
||||
response: Change{IncludeSelf: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "TargetNode without IncludeSelf is not self only",
|
||||
response: Change{TargetNode: 1},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "TargetNode with IncludeSelf is self only",
|
||||
response: Change{TargetNode: 1, IncludeSelf: true},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "self only with SendAllPeers is not self only",
|
||||
response: Change{TargetNode: 1, IncludeSelf: true, SendAllPeers: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "self only with PeersChanged is not self only",
|
||||
response: Change{TargetNode: 1, IncludeSelf: true, PeersChanged: []types.NodeID{2}},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "self only with PeersRemoved is not self only",
|
||||
response: Change{TargetNode: 1, IncludeSelf: true, PeersRemoved: []types.NodeID{2}},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "self only with PeerPatches is not self only",
|
||||
response: Change{TargetNode: 1, IncludeSelf: true, PeerPatches: []*tailcfg.PeerChange{{}}},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "self only with other include flags is still self only",
|
||||
response: Change{
|
||||
TargetNode: 1,
|
||||
IncludeSelf: true,
|
||||
IncludePolicy: true,
|
||||
IncludeDNS: true,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.response.IsSelfOnly()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChange_Merge(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
r1 Change
|
||||
r2 Change
|
||||
want Change
|
||||
}{
|
||||
{
|
||||
name: "empty merge",
|
||||
r1: Change{},
|
||||
r2: Change{},
|
||||
want: Change{},
|
||||
},
|
||||
{
|
||||
name: "bool fields OR together",
|
||||
r1: Change{IncludeSelf: true, IncludePolicy: true},
|
||||
r2: Change{IncludeDERPMap: true, IncludePolicy: true},
|
||||
want: Change{IncludeSelf: true, IncludeDERPMap: true, IncludePolicy: true},
|
||||
},
|
||||
{
|
||||
name: "all bool fields merge",
|
||||
r1: Change{IncludeSelf: true, IncludeDNS: true, IncludePolicy: true},
|
||||
r2: Change{IncludeDERPMap: true, IncludeDomain: true, SendAllPeers: true},
|
||||
want: Change{
|
||||
IncludeSelf: true,
|
||||
IncludeDERPMap: true,
|
||||
IncludeDNS: true,
|
||||
IncludeDomain: true,
|
||||
IncludePolicy: true,
|
||||
SendAllPeers: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "peers deduplicated and sorted",
|
||||
r1: Change{PeersChanged: []types.NodeID{3, 1}},
|
||||
r2: Change{PeersChanged: []types.NodeID{2, 1}},
|
||||
want: Change{PeersChanged: []types.NodeID{1, 2, 3}},
|
||||
},
|
||||
{
|
||||
name: "peers removed deduplicated",
|
||||
r1: Change{PeersRemoved: []types.NodeID{1, 2}},
|
||||
r2: Change{PeersRemoved: []types.NodeID{2, 3}},
|
||||
want: Change{PeersRemoved: []types.NodeID{1, 2, 3}},
|
||||
},
|
||||
{
|
||||
name: "peer patches concatenated",
|
||||
r1: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 1}}},
|
||||
r2: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 2}}},
|
||||
want: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 1}, {NodeID: 2}}},
|
||||
},
|
||||
{
|
||||
name: "reasons combined when different",
|
||||
r1: Change{Reason: "route change"},
|
||||
r2: Change{Reason: "tag change"},
|
||||
want: Change{Reason: "route change; tag change"},
|
||||
},
|
||||
{
|
||||
name: "same reason not duplicated",
|
||||
r1: Change{Reason: "policy"},
|
||||
r2: Change{Reason: "policy"},
|
||||
want: Change{Reason: "policy"},
|
||||
},
|
||||
{
|
||||
name: "empty reason takes other",
|
||||
r1: Change{},
|
||||
r2: Change{Reason: "update"},
|
||||
want: Change{Reason: "update"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.r1.Merge(tt.r2)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChange_Constructors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
constructor func() Change
|
||||
wantReason string
|
||||
want Change
|
||||
}{
|
||||
{
|
||||
name: "FullUpdateResponse",
|
||||
constructor: FullUpdate,
|
||||
wantReason: "full update",
|
||||
want: Change{
|
||||
Reason: "full update",
|
||||
IncludeSelf: true,
|
||||
IncludeDERPMap: true,
|
||||
IncludeDNS: true,
|
||||
IncludeDomain: true,
|
||||
IncludePolicy: true,
|
||||
SendAllPeers: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "PolicyOnlyResponse",
|
||||
constructor: PolicyOnly,
|
||||
wantReason: "policy update",
|
||||
want: Change{
|
||||
Reason: "policy update",
|
||||
IncludePolicy: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DERPMapResponse",
|
||||
constructor: DERPMap,
|
||||
wantReason: "DERP map update",
|
||||
want: Change{
|
||||
Reason: "DERP map update",
|
||||
IncludeDERPMap: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := tt.constructor()
|
||||
assert.Equal(t, tt.wantReason, r.Reason)
|
||||
assert.Equal(t, tt.want, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelfUpdate(t *testing.T) {
|
||||
r := SelfUpdate(42)
|
||||
assert.Equal(t, "self update", r.Reason)
|
||||
assert.Equal(t, types.NodeID(42), r.TargetNode)
|
||||
assert.True(t, r.IncludeSelf)
|
||||
assert.True(t, r.IsSelfOnly())
|
||||
}
|
||||
|
||||
func TestPolicyAndPeers(t *testing.T) {
|
||||
r := PolicyAndPeers(1, 2, 3)
|
||||
assert.Equal(t, "policy and peers update", r.Reason)
|
||||
assert.True(t, r.IncludePolicy)
|
||||
assert.Equal(t, []types.NodeID{1, 2, 3}, r.PeersChanged)
|
||||
}
|
||||
|
||||
func TestVisibilityChange(t *testing.T) {
|
||||
r := VisibilityChange("tag change", []types.NodeID{1}, []types.NodeID{2, 3})
|
||||
assert.Equal(t, "tag change", r.Reason)
|
||||
assert.True(t, r.IncludePolicy)
|
||||
assert.Equal(t, []types.NodeID{1}, r.PeersChanged)
|
||||
assert.Equal(t, []types.NodeID{2, 3}, r.PeersRemoved)
|
||||
}
|
||||
|
||||
func TestPeersChanged(t *testing.T) {
|
||||
r := PeersChanged("routes approved", 1, 2)
|
||||
assert.Equal(t, "routes approved", r.Reason)
|
||||
assert.Equal(t, []types.NodeID{1, 2}, r.PeersChanged)
|
||||
assert.False(t, r.IncludePolicy)
|
||||
}
|
||||
|
||||
func TestPeersRemoved(t *testing.T) {
|
||||
r := PeersRemoved(1, 2, 3)
|
||||
assert.Equal(t, "peers removed", r.Reason)
|
||||
assert.Equal(t, []types.NodeID{1, 2, 3}, r.PeersRemoved)
|
||||
}
|
||||
|
||||
func TestPeerPatched(t *testing.T) {
|
||||
patch := &tailcfg.PeerChange{NodeID: 1}
|
||||
r := PeerPatched("endpoint change", patch)
|
||||
assert.Equal(t, "endpoint change", r.Reason)
|
||||
assert.Equal(t, []*tailcfg.PeerChange{patch}, r.PeerPatches)
|
||||
}
|
||||
|
||||
func TestChange_Type(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response Change
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "full update",
|
||||
response: FullUpdate(),
|
||||
want: "full",
|
||||
},
|
||||
{
|
||||
name: "self only",
|
||||
response: SelfUpdate(1),
|
||||
want: "self",
|
||||
},
|
||||
{
|
||||
name: "policy with runtime computation",
|
||||
response: PolicyChange(),
|
||||
want: "policy",
|
||||
},
|
||||
{
|
||||
name: "patch only",
|
||||
response: PeerPatched("test", &tailcfg.PeerChange{NodeID: 1}),
|
||||
want: "patch",
|
||||
},
|
||||
{
|
||||
name: "peers changed",
|
||||
response: PeersChanged("test", 1, 2),
|
||||
want: "peers",
|
||||
},
|
||||
{
|
||||
name: "peers removed",
|
||||
response: PeersRemoved(1, 2),
|
||||
want: "peers",
|
||||
},
|
||||
{
|
||||
name: "config - DERP map",
|
||||
response: DERPMap(),
|
||||
want: "config",
|
||||
},
|
||||
{
|
||||
name: "config - DNS",
|
||||
response: DNSConfig(),
|
||||
want: "config",
|
||||
},
|
||||
{
|
||||
name: "config - policy only (no runtime)",
|
||||
response: PolicyOnly(),
|
||||
want: "config",
|
||||
},
|
||||
{
|
||||
name: "empty is unknown",
|
||||
response: Change{},
|
||||
want: "unknown",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.response.Type()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniqueNodeIDs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []types.NodeID
|
||||
want []types.NodeID
|
||||
}{
|
||||
{
|
||||
name: "nil input",
|
||||
input: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: []types.NodeID{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single element",
|
||||
input: []types.NodeID{1},
|
||||
want: []types.NodeID{1},
|
||||
},
|
||||
{
|
||||
name: "no duplicates",
|
||||
input: []types.NodeID{1, 2, 3},
|
||||
want: []types.NodeID{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "with duplicates",
|
||||
input: []types.NodeID{3, 1, 2, 1, 3},
|
||||
want: []types.NodeID{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "all same",
|
||||
input: []types.NodeID{5, 5, 5, 5},
|
||||
want: []types.NodeID{5},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := uniqueNodeIDs(tt.input)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}