diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go
index 7b7893bc1..daedb1e19 100644
--- a/ipn/ipnlocal/local.go
+++ b/ipn/ipnlocal/local.go
@@ -168,6 +168,17 @@ type watchSession struct {
 
 var metricCaptivePortalDetected = clientmetric.NewCounter("captiveportal_detected")
 
+var (
+	// errShutdown indicates that [LocalBackend.Shutdown] was called.
+	errShutdown = errors.New("shutting down")
+
+	// errNodeContextChanged indicates that [LocalBackend] has switched
+	// to a different [nodeBackend], usually due to a profile change.
+	// It is used as the cancellation cause for the old node's context
+	// and can be returned when an operation is performed on it.
+	errNodeContextChanged = errors.New("profile changed")
+)
+
 // LocalBackend is the glue between the major pieces of the Tailscale
 // network software: the cloud control plane (via controlclient), the
 // network data plane (via wgengine), and the user-facing UIs and CLIs
@@ -180,11 +191,11 @@ var metricCaptivePortalDetected = clientmetric.NewCounter("captiveportal_detecte
 // state machine generates events back out to zero or more components.
 type LocalBackend struct {
 	// Elements that are thread-safe or constant after construction.
-	ctx       context.Context    // canceled by [LocalBackend.Shutdown]
-	ctxCancel context.CancelFunc // cancels ctx
-	logf      logger.Logf        // general logging
-	keyLogf   logger.Logf        // for printing list of peers on change
-	statsLogf logger.Logf        // for printing peers stats on change
+	ctx       context.Context         // canceled by [LocalBackend.Shutdown]
+	ctxCancel context.CancelCauseFunc // cancels ctx
+	logf      logger.Logf             // general logging
+	keyLogf   logger.Logf             // for printing list of peers on change
+	statsLogf logger.Logf             // for printing peers stats on change
 	sys       *tsd.System
 	health    *health.Tracker // always non-nil
 	metrics   metrics
@@ -463,7 +474,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 
 	envknob.LogCurrent(logf)
 
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := context.WithCancelCause(context.Background())
 	clock := tstime.StdClock{}
 
 	// Until we transition to a Running state, use a canceled context for
@@ -503,7 +514,10 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 		captiveCancel:         nil, // so that we start checkCaptivePortalLoop when Running
 		needsCaptiveDetection: make(chan bool),
 	}
-	b.currentNodeAtomic.Store(newNodeBackend())
+	nb := newNodeBackend(ctx)
+	b.currentNodeAtomic.Store(nb)
+	nb.ready()
+
 	mConn.SetNetInfoCallback(b.setNetInfo)
 
 	if sys.InitialConfig != nil {
@@ -586,8 +600,10 @@ func (b *LocalBackend) currentNode() *nodeBackend {
 		return v
 	}
 	// Auto-init one in tests for LocalBackend created without the NewLocalBackend constructor...
-	v := newNodeBackend()
-	b.currentNodeAtomic.CompareAndSwap(nil, v)
+	v := newNodeBackend(cmp.Or(b.ctx, context.Background()))
+	if b.currentNodeAtomic.CompareAndSwap(nil, v) {
+		v.ready()
+	}
 	return b.currentNodeAtomic.Load()
 }
 
@@ -1089,8 +1105,9 @@ func (b *LocalBackend) Shutdown() {
 	if cc != nil {
 		cc.Shutdown()
 	}
+	b.ctxCancel(errShutdown)
+	b.currentNode().shutdown(errShutdown)
 	extHost.Shutdown()
-	b.ctxCancel()
 	b.e.Close()
 	<-b.e.Done()
 	b.awaitNoGoroutinesInTest()
@@ -6992,7 +7009,11 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err
 		// down, so no need to do any work.
 		return nil
 	}
-	b.currentNodeAtomic.Store(newNodeBackend())
+	newNode := newNodeBackend(b.ctx)
+	if oldNode := b.currentNodeAtomic.Swap(newNode); oldNode != nil {
+		oldNode.shutdown(errNodeContextChanged)
+	}
+	defer newNode.ready()
 	b.setNetMapLocked(nil) // Reset netmap.
 	b.updateFilterLocked(ipn.PrefsView{})
 	// Reset the NetworkMap in the engine
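
A note on the two cancellation causes introduced above: because ctxCancel is now a context.CancelCauseFunc, code holding a node-scoped context can tell a full daemon shutdown apart from a profile switch via context.Cause. The sketch below is illustrative only and not part of the patch; watchNodeCtx is a hypothetical helper, and the node-scoped context it inspects comes from the Context accessor added in node_backend.go below.

// watchNodeCtx is a hypothetical helper (not part of this change) showing how
// a consumer of a node-scoped context could react to its cancellation cause.
func watchNodeCtx(nodeCtx context.Context, logf logger.Logf) {
	<-nodeCtx.Done()
	switch cause := context.Cause(nodeCtx); {
	case errors.Is(cause, errShutdown):
		logf("node context canceled: tailscaled is shutting down")
	case errors.Is(cause, errNodeContextChanged):
		logf("node context canceled: switched to a different profile")
	default:
		logf("node context canceled: %v", cause)
	}
}
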
diff --git a/ipn/ipnlocal/node_backend.go b/ipn/ipnlocal/node_backend.go
index fb77f38eb..361d10bb6 100644
--- a/ipn/ipnlocal/node_backend.go
+++ b/ipn/ipnlocal/node_backend.go
@@ -5,6 +5,7 @@ package ipnlocal
 
 import (
 	"cmp"
+	"context"
 	"net/netip"
 	"slices"
 	"sync"
@@ -39,7 +40,7 @@ import (
 // Two pointers to different [nodeBackend] instances represent different local nodes.
 // However, there's currently a bug where a new [nodeBackend] might not be created
 // during an implicit node switch (see tailscale/corp#28014).
-
+//
 // In the future, we might want to include at least the following in this struct (in addition to the current fields).
 // However, not everything should be exported or otherwise made available to the outside world (e.g. [ipnext] extensions,
 // peer API handlers, etc.).
@@ -61,6 +62,9 @@ import (
 // Even if they're tied to the local node, instead of moving them here, we should extract the entire feature
 // into a separate package and have it install proper hooks.
 type nodeBackend struct {
+	ctx       context.Context         // canceled by [nodeBackend.shutdown]
+	ctxCancel context.CancelCauseFunc // cancels ctx
+
 	// filterAtomic is a stateful packet filter. Immutable once created, but can be
 	// replaced with a new one.
 	filterAtomic atomic.Pointer[filter.Filter]
@@ -68,6 +72,9 @@ type nodeBackend struct {
 	// TODO(nickkhyl): maybe use sync.RWMutex?
 	mu sync.Mutex // protects the following fields
 
+	shutdownOnce sync.Once     // guards calling [nodeBackend.shutdown]
+	readyCh      chan struct{} // closed by [nodeBackend.ready]; nil after shutdown
+
 	// NetMap is the most recently set full netmap from the controlclient.
 	// It can't be mutated in place once set. Because it can't be mutated in place,
 	// delta updates from the control server don't apply to it. Instead, use
@@ -88,12 +95,24 @@ type nodeBackend struct {
 	nodeByAddr map[netip.Addr]tailcfg.NodeID
 }
 
-func newNodeBackend() *nodeBackend {
-	cn := &nodeBackend{}
+func newNodeBackend(ctx context.Context) *nodeBackend {
+	ctx, ctxCancel := context.WithCancelCause(ctx)
+	nb := &nodeBackend{
+		ctx:       ctx,
+		ctxCancel: ctxCancel,
+		readyCh:   make(chan struct{}),
+	}
 	// Default filter blocks everything and logs nothing.
 	noneFilter := filter.NewAllowNone(logger.Discard, &netipx.IPSet{})
-	cn.filterAtomic.Store(noneFilter)
-	return cn
+	nb.filterAtomic.Store(noneFilter)
+	return nb
+}
+
+// Context returns a context that is canceled when the [nodeBackend] shuts down,
+// either because [LocalBackend] is switching to a different [nodeBackend]
+// or because [LocalBackend] itself is shutting down.
+func (nb *nodeBackend) Context() context.Context {
+	return nb.ctx
 }
 
 func (nb *nodeBackend) Self() tailcfg.NodeView {
@@ -475,6 +494,59 @@ func (nb *nodeBackend) exitNodeCanProxyDNS(exitNodeID tailcfg.StableNodeID) (doh
 	return exitNodeCanProxyDNS(nb.netMap, nb.peers, exitNodeID)
 }
 
+// ready signals that [LocalBackend] has completed the switch to this [nodeBackend]
+// and unblocks any pending [nodeBackend.Wait] calls.
+func (nb *nodeBackend) ready() {
+	nb.mu.Lock()
+	defer nb.mu.Unlock()
+	if nb.readyCh != nil {
+		close(nb.readyCh)
+	}
+}
+
+// Wait blocks until [LocalBackend] completes the switch to this [nodeBackend]
+// and calls [nodeBackend.ready]. It returns an error if the provided context
+// is canceled or if the [nodeBackend] shuts down or is already shut down.
+//
+// It must not be called with [LocalBackend]'s internal mutex held, as [LocalBackend]
+// may need to acquire it to complete the switch.
+//
+// TODO(nickkhyl): Relax this restriction once [LocalBackend]'s state machine
+// runs in its own goroutine, or if we decide that waiting for the state machine
+// restart to finish isn't necessary for [LocalBackend] to consider the switch complete.
+// We mostly need this because [LocalBackend.Start] acquires b.mu, and
+// methods like [LocalBackend.SwitchProfile] must report any errors it returns.
+// Perhaps we could report those errors asynchronously as [health.Warnable]s?
+func (nb *nodeBackend) Wait(ctx context.Context) error {
+	nb.mu.Lock()
+	readyCh := nb.readyCh
+	nb.mu.Unlock()
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-nb.ctx.Done():
+		return context.Cause(nb.ctx)
+	case <-readyCh:
+		return nil
+	}
+}
+
+// shutdown shuts down the [nodeBackend] and cancels its context
+// with the provided cause.
+func (nb *nodeBackend) shutdown(cause error) {
+	nb.shutdownOnce.Do(func() {
+		nb.doShutdown(cause)
+	})
+}
+
+func (nb *nodeBackend) doShutdown(cause error) {
+	nb.mu.Lock()
+	defer nb.mu.Unlock()
+	nb.ctxCancel(cause)
+	nb.readyCh = nil
+}
+
 // dnsConfigForNetmap returns a *dns.Config for the given netmap,
 // prefs, client OS version, and cloud hosting environment.
 //
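
The readiness API above is what lets profile-switching code paths block until the new per-node state is fully wired up. A minimal caller-side sketch, assuming the existing LocalBackend.SwitchProfile entry point; switchAndWait itself is hypothetical and not part of this patch:

// switchAndWait is a hypothetical wrapper showing one possible call pattern:
// switch profiles, then wait for LocalBackend to finish moving to the new
// nodeBackend. It must not be called while holding b.mu (see Wait's docs).
func switchAndWait(ctx context.Context, b *LocalBackend, id ipn.ProfileID) error {
	if err := b.SwitchProfile(id); err != nil {
		return err
	}
	// Wait returns nil once ready() has been called on the new nodeBackend,
	// ctx.Err() if ctx expires first, or the shutdown cause if the backend
	// is replaced again or shut down before becoming ready.
	return b.currentNode().Wait(ctx)
}
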
diff --git a/ipn/ipnlocal/node_backend_test.go b/ipn/ipnlocal/node_backend_test.go
new file mode 100644
index 000000000..a82b60a9a
--- /dev/null
+++ b/ipn/ipnlocal/node_backend_test.go
@@ -0,0 +1,121 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ipnlocal
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+)
+
+func TestNodeBackendReadiness(t *testing.T) {
+	nb := newNodeBackend(t.Context())
+
+	// The node backend is not ready until [nodeBackend.ready] is called,
+	// and [nodeBackend.Wait] should fail with [context.DeadlineExceeded].
+	ctx, cancelCtx := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancelCtx()
+	if err := nb.Wait(ctx); err != ctx.Err() {
+		t.Fatalf("Wait: got %v; want %v", err, ctx.Err())
+	}
+
+	// Start a goroutine to wait for the node backend to become ready.
+	waitDone := make(chan struct{})
+	go func() {
+		if err := nb.Wait(context.Background()); err != nil {
+			t.Errorf("Wait: got %v; want nil", err)
+		}
+		close(waitDone)
+	}()
+
+	// Call [nodeBackend.ready] to indicate that the node backend is now ready.
+	go nb.ready()
+
+	// Once the backend is ready, [nodeBackend.Wait] should return immediately without error.
+	if err := nb.Wait(context.Background()); err != nil {
+		t.Fatalf("Wait: got %v; want nil", err)
+	}
+	// And any pending waiters should also be unblocked.
+	<-waitDone
+}
+
+func TestNodeBackendShutdown(t *testing.T) {
+	nb := newNodeBackend(t.Context())
+
+	shutdownCause := errors.New("test shutdown")
+
+	// Start a goroutine to wait for the node backend to become ready.
+	// This test expects it to block until the node backend shuts down
+	// and then return the specified shutdown cause.
+	waitDone := make(chan struct{})
+	go func() {
+		if err := nb.Wait(context.Background()); err != shutdownCause {
+			t.Errorf("Wait: got %v; want %v", err, shutdownCause)
+		}
+		close(waitDone)
+	}()
+
+	// Call [nodeBackend.shutdown] to indicate that the node backend is shutting down.
+	nb.shutdown(shutdownCause)
+
+	// Calling it again is fine, but should not change the shutdown cause.
+	nb.shutdown(errors.New("test shutdown again"))
+
+	// After shutdown, [nodeBackend.Wait] should return with the specified shutdown cause.
+	if err := nb.Wait(context.Background()); err != shutdownCause {
+		t.Fatalf("Wait: got %v; want %v", err, shutdownCause)
+	}
+	// The context associated with the node backend should also be canceled,
+	// and its cancellation cause should match the shutdown cause.
+	if err := nb.Context().Err(); !errors.Is(err, context.Canceled) {
+		t.Fatalf("Context.Err: got %v; want %v", err, context.Canceled)
+	}
+	if cause := context.Cause(nb.Context()); cause != shutdownCause {
+		t.Fatalf("Cause: got %v; want %v", cause, shutdownCause)
+	}
+	// And any pending waiters should also be unblocked.
+	<-waitDone
+}
+
+func TestNodeBackendReadyAfterShutdown(t *testing.T) {
+	nb := newNodeBackend(t.Context())
+
+	shutdownCause := errors.New("test shutdown")
+	nb.shutdown(shutdownCause)
+	nb.ready() // Calling ready after shutdown is a no-op, but should not panic, etc.
+	if err := nb.Wait(context.Background()); err != shutdownCause {
+		t.Fatalf("Wait: got %v; want %v", err, shutdownCause)
+	}
+}
+
+func TestNodeBackendParentContextCancellation(t *testing.T) {
+	ctx, cancelCtx := context.WithCancel(context.Background())
+	nb := newNodeBackend(ctx)
+
+	cancelCtx()
+
+	// Canceling the parent context should cause [nodeBackend.Wait]
+	// to return with [context.Canceled].
+	if err := nb.Wait(context.Background()); !errors.Is(err, context.Canceled) {
+		t.Fatalf("Wait: got %v; want %v", err, context.Canceled)
+	}
+
+	// And the node backend's context should also be canceled.
+	if err := nb.Context().Err(); !errors.Is(err, context.Canceled) {
+		t.Fatalf("Context.Err: got %v; want %v", err, context.Canceled)
+	}
+}
+
+func TestNodeBackendConcurrentReadyAndShutdown(t *testing.T) {
+	nb := newNodeBackend(t.Context())
+
+	// Calling [nodeBackend.ready] and [nodeBackend.shutdown] concurrently
+	// should not cause issues, and [nodeBackend.Wait] should unblock,
+	// but the result of [nodeBackend.Wait] is intentionally undefined.
+	go nb.ready()
+	go nb.shutdown(errors.New("test shutdown"))
+
+	nb.Wait(context.Background())
+}
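
One nuance the last test leaves implicit: when ready and shutdown race, Wait's result is deliberately unspecified, so a caller that needs a definitive liveness answer should re-check the node context after Wait returns. A hedged sketch; nodeIsUsable is hypothetical and not part of this patch:

// nodeIsUsable is a hypothetical helper: it reports whether the nodeBackend
// became ready and has not been shut down since. Wait may return nil even if
// a concurrent shutdown lands immediately afterwards, so the node context is
// consulted as the source of truth for liveness.
func nodeIsUsable(ctx context.Context, nb *nodeBackend) bool {
	if err := nb.Wait(ctx); err != nil {
		return false // never became ready, ctx expired, or already shut down
	}
	return nb.Context().Err() == nil // still alive after becoming ready
}
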