diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go
index 95fe22641..0898d5800 100644
--- a/ipn/ipnlocal/local.go
+++ b/ipn/ipnlocal/local.go
@@ -167,6 +167,17 @@ type watchSession struct {
 	cancel    context.CancelFunc // to shut down the session
 }
 
+var (
+	// errShutdown indicates that [LocalBackend.Shutdown] was called.
+	errShutdown = errors.New("shutting down")
+
+	// errNodeContextChanged indicates that [LocalBackend] has switched
+	// to a different [localNodeContext], usually due to a profile change.
+	// It is used as a context cancellation cause for the old context
+	// and can be returned when an operation is performed on it.
+	errNodeContextChanged = errors.New("profile changed")
+)
+
 // LocalBackend is the glue between the major pieces of the Tailscale
 // network software: the cloud control plane (via controlclient), the
 // network data plane (via wgengine), and the user-facing UIs and CLIs
@@ -179,11 +190,11 @@ type watchSession struct {
 // state machine generates events back out to zero or more components.
 type LocalBackend struct {
 	// Elements that are thread-safe or constant after construction.
-	ctx       context.Context    // canceled by [LocalBackend.Shutdown]
-	ctxCancel context.CancelFunc // cancels ctx
-	logf      logger.Logf        // general logging
-	keyLogf   logger.Logf        // for printing list of peers on change
-	statsLogf logger.Logf        // for printing peers stats on change
+	ctx       context.Context         // canceled by [LocalBackend.Shutdown]
+	ctxCancel context.CancelCauseFunc // cancels ctx
+	logf      logger.Logf             // general logging
+	keyLogf   logger.Logf             // for printing list of peers on change
+	statsLogf logger.Logf             // for printing peers stats on change
 	sys       *tsd.System
 	health    *health.Tracker // always non-nil
 	metrics   metrics
@@ -479,7 +490,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 	envknob.LogCurrent(logf)
 	osshare.SetFileSharingEnabled(false, logf)
 
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := context.WithCancelCause(context.Background())
 	clock := tstime.StdClock{}
 
 	// Until we transition to a Running state, use a canceled context for
@@ -519,7 +530,10 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 		captiveCancel:         nil, // so that we start checkCaptivePortalLoop when Running
 		needsCaptiveDetection: make(chan bool),
 	}
-	b.currentNodeAtomic.Store(newLocalNodeContext())
+	cn := newLocalNodeContext(ctx)
+	b.currentNodeAtomic.Store(cn)
+	cn.ready()
+
 	mConn.SetNetInfoCallback(b.setNetInfo)
 
 	if sys.InitialConfig != nil {
@@ -599,8 +613,10 @@ func (b *LocalBackend) currentNode() *localNodeContext {
 		return v
 	}
 	// Auto-init one in tests for LocalBackend created without the NewLocalBackend constructor...
-	v := newLocalNodeContext()
-	b.currentNodeAtomic.CompareAndSwap(nil, v)
+	v := newLocalNodeContext(cmp.Or(b.ctx, context.Background()))
+	if b.currentNodeAtomic.CompareAndSwap(nil, v) {
+		v.ready()
+	}
 	return b.currentNodeAtomic.Load()
 }
 
@@ -1099,8 +1115,8 @@ func (b *LocalBackend) Shutdown() {
 	if cc != nil {
 		cc.Shutdown()
 	}
+	b.ctxCancel(errShutdown)
 	extHost.Shutdown()
-	b.ctxCancel()
 	b.e.Close()
 	<-b.e.Done()
 	b.awaitNoGoroutinesInTest()
@@ -7396,7 +7412,11 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err
 		// down, so no need to do any work.
 		return nil
 	}
-	b.currentNodeAtomic.Store(newLocalNodeContext())
+	newNodeCtx := newLocalNodeContext(b.ctx)
+	if oldNodeCtx := b.currentNodeAtomic.Swap(newNodeCtx); oldNodeCtx != nil {
+		oldNodeCtx.shutdown(errNodeContextChanged)
+	}
+	defer newNodeCtx.ready()
 	b.setNetMapLocked(nil) // Reset netmap.
 	b.updateFilterLocked(ipn.PrefsView{})
 	// Reset the NetworkMap in the engine
diff --git a/ipn/ipnlocal/local_node_context.go b/ipn/ipnlocal/local_node_context.go
index 871880893..d81400bc6 100644
--- a/ipn/ipnlocal/local_node_context.go
+++ b/ipn/ipnlocal/local_node_context.go
@@ -4,6 +4,7 @@
 package ipnlocal
 
 import (
+	"context"
 	"net/netip"
 	"sync"
 	"sync/atomic"
@@ -30,7 +31,7 @@ import (
 // Two pointers to different [localNodeContext] instances represent different local nodes.
 // However, there's currently a bug where a new [localNodeContext] might not be created
 // during an implicit node switch (see tailscale/corp#28014).
-
+//
 // In the future, we might want to include at least the following in this struct (in addition to the current fields).
 // However, not everything should be exported or otherwise made available to the outside world (e.g. [ipnext] extensions,
 // peer API handlers, etc.).
@@ -52,6 +53,9 @@ import (
 // Even if they're tied to the local node, instead of moving them here, we should extract the entire feature
 // into a separate package and have it install proper hooks.
 type localNodeContext struct {
+	ctx       context.Context         // canceled by [localNodeContext.shutdown]
+	ctxCancel context.CancelCauseFunc // cancels ctx
+
 	// filterAtomic is a stateful packet filter. Immutable once created, but can be
 	// replaced with a new one.
 	filterAtomic atomic.Pointer[filter.Filter]
@@ -59,6 +63,8 @@ type localNodeContext struct {
 	// TODO(nickkhyl): maybe use sync.RWMutex?
 	mu sync.Mutex // protects the following fields
 
+	readyCh chan struct{} // closed by [localNodeContext.ready]; nil after shutdown
+
 	// NetMap is the most recently set full netmap from the controlclient.
 	// It can't be mutated in place once set. Because it can't be mutated in place,
 	// delta updates from the control server don't apply to it. Instead, use
@@ -79,14 +85,26 @@ type localNodeContext struct {
 	nodeByAddr map[netip.Addr]tailcfg.NodeID
 }
 
-func newLocalNodeContext() *localNodeContext {
-	cn := &localNodeContext{}
+func newLocalNodeContext(ctx context.Context) *localNodeContext {
+	ctx, ctxCancel := context.WithCancelCause(ctx)
+	cn := &localNodeContext{
+		ctx:       ctx,
+		ctxCancel: ctxCancel,
+		readyCh:   make(chan struct{}),
+	}
 	// Default filter blocks everything and logs nothing.
 	noneFilter := filter.NewAllowNone(logger.Discard, &netipx.IPSet{})
 	cn.filterAtomic.Store(noneFilter)
 	return cn
 }
 
+// Context returns a context that is canceled when the [localNodeContext] shuts down,
+// either because [LocalBackend] is switching to a different [localNodeContext]
+// or is itself shutting down.
+func (c *localNodeContext) Context() context.Context {
+	return c.ctx
+}
+
 func (c *localNodeContext) Self() tailcfg.NodeView {
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -205,3 +223,50 @@ func (c *localNodeContext) unlockedNodesPermitted(packetFilter []filter.Match) b
 func (c *localNodeContext) filter() *filter.Filter {
 	return c.filterAtomic.Load()
 }
+
+// ready signals that [LocalBackend] has completed the switch to this [localNodeContext]
+// and that any pending calls to [localNodeContext.Wait] must be unblocked.
+func (c *localNodeContext) ready() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.readyCh != nil {
+		close(c.readyCh)
+	}
+}
+
+// Wait blocks until [LocalBackend] completes the switch to this [localNodeContext]
+// and calls [localNodeContext.ready]. It returns an error if the provided context
+// is canceled or if the [localNodeContext] shuts down or is already shut down.
+//
+// It must not be called with [LocalBackend]'s internal mutex held, as [LocalBackend]
+// may need to acquire it to complete the switch.
+//
+// TODO(nickkhyl): Relax this restriction once [LocalBackend]'s state machine
+// runs in its own goroutine, or if we decide that waiting for the state machine
+// restart to finish isn't necessary for [LocalBackend] to consider the switch complete.
+// We mostly need this because of [LocalBackend.Start] acquiring b.mu and the fact that
+// methods like [LocalBackend.SwitchProfile] must report any errors returned by it.
+// Perhaps we could report those errors asynchronously as [health.Warnable]s?
+func (c *localNodeContext) Wait(ctx context.Context) error {
+	c.mu.Lock()
+	readyCh := c.readyCh
+	c.mu.Unlock()
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-c.ctx.Done():
+		return context.Cause(c.ctx)
+	case <-readyCh:
+		return nil
+	}
+}
+
+// shutdown cancels the context with the given cause and shuts down the receiver.
+func (c *localNodeContext) shutdown(cause error) {
+	c.ctxCancel(cause)
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.readyCh = nil
+}
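Reviewer note: below is a minimal sketch (not part of the patch) of how the pieces above are intended to compose. It assumes package-internal access, since currentNode, errShutdown, and errNodeContextChanged are unexported, and the helper name runNodeScoped is hypothetical.

```go
// runNodeScoped is a hypothetical package-internal helper illustrating the
// new API: Wait blocks until the switch to this node context completes, and
// Context lets node-scoped work learn why it was stopped via context.Cause.
func (b *LocalBackend) runNodeScoped(ctx context.Context) error {
	cn := b.currentNode()

	// Must not be called with b.mu held; see Wait's doc comment.
	if err := cn.Wait(ctx); err != nil {
		return err // ctx canceled, backend shut down, or switch superseded
	}

	go func() {
		<-cn.Context().Done() // unblocked by shutdown(cause)
		switch context.Cause(cn.Context()) {
		case errShutdown:
			// LocalBackend.Shutdown was called; stop all node-scoped work.
		case errNodeContextChanged:
			// A profile change replaced cn; re-fetch b.currentNode()
			// if the work should continue under the new node.
		}
	}()
	return nil
}
```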