diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index c4f68e929..4efb4895a 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -96,6 +96,7 @@ "tailscale.com/types/views" "tailscale.com/util/deephash" "tailscale.com/util/dnsname" + "tailscale.com/util/goroutines" "tailscale.com/util/httpm" "tailscale.com/util/mak" "tailscale.com/util/multierr" @@ -178,7 +179,7 @@ type watchSession struct { // state machine generates events back out to zero or more components. type LocalBackend struct { // Elements that are thread-safe or constant after construction. - ctx context.Context // canceled by Close + ctx context.Context // canceled by [LocalBackend.Shutdown] ctxCancel context.CancelFunc // cancels ctx logf logger.Logf // general logging keyLogf logger.Logf // for printing list of peers on change @@ -231,6 +232,10 @@ type LocalBackend struct { shouldInterceptTCPPortAtomic syncs.AtomicValue[func(uint16) bool] numClientStatusCalls atomic.Uint32 + // goTracker accounts for all goroutines started by LocalBackend, primarily + // for testing and graceful shutdown purposes. + goTracker goroutines.Tracker + // The mutex protects the following elements. mu sync.Mutex conf *conffile.Config // latest parsed config, or nil if not in declarative mode @@ -866,7 +871,7 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) { // TODO(raggi,tailscale/corp#22574): authReconfig should be refactored such that we can call the // necessary operations here and avoid the need for asynchronous behavior that is racy and hard // to test here, and do less extra work in these conditions. 
- go b.authReconfig() + b.goTracker.Go(b.authReconfig) } } @@ -879,7 +884,7 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) { want := b.netMap.GetAddresses().Len() if len(b.peerAPIListeners) < want { b.logf("linkChange: peerAPIListeners too low; trying again") - go b.initPeerAPIListener() + b.goTracker.Go(b.initPeerAPIListener) } } } @@ -1004,6 +1009,33 @@ func (b *LocalBackend) Shutdown() { b.ctxCancel() b.e.Close() <-b.e.Done() + b.awaitNoGoroutinesInTest() +} + +func (b *LocalBackend) awaitNoGoroutinesInTest() { + if !testenv.InTest() { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + + ch := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { ch <- true })() + + for { + n := b.goTracker.RunningGoroutines() + if n == 0 { + return + } + select { + case <-ctx.Done(): + // TODO(bradfitz): pass down some TB-like failer interface from + // tests, without depending on testing from here? + // But this is fine in tests too: + panic(fmt.Sprintf("timeout waiting for %d goroutines to stop", n)) + case <-ch: + } + } } func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView { @@ -2154,7 +2186,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if b.portpoll != nil { b.portpollOnce.Do(func() { - go b.readPoller() + b.goTracker.Go(b.readPoller) }) } @@ -2368,7 +2400,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P b.e.SetJailedFilter(filter.NewShieldsUpFilter(localNets, logNets, oldJailedFilter, b.logf)) if b.sshServer != nil { - go b.sshServer.OnPolicyChange() + b.goTracker.Go(b.sshServer.OnPolicyChange) } } @@ -2845,7 +2877,7 @@ func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.A // request every 2 seconds. // TODO(bradfitz): plumb this further and only send a Notify on change. 
if mask&ipn.NotifyWatchEngineUpdates != 0 { - go b.pollRequestEngineStatus(ctx) + b.goTracker.Go(func() { b.pollRequestEngineStatus(ctx) }) } // TODO(marwan-at-work): streaming background logs? @@ -3852,7 +3884,7 @@ func (b *LocalBackend) editPrefsLockedOnEntry(mp *ipn.MaskedPrefs, unlock unlock if mp.EggSet { mp.EggSet = false b.egg = true - go b.doSetHostinfoFilterServices() + b.goTracker.Go(b.doSetHostinfoFilterServices) } p0 := b.pm.CurrentPrefs() p1 := b.pm.CurrentPrefs().AsStruct() @@ -3945,7 +3977,7 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) if oldp.ShouldSSHBeRunning() && !newp.ShouldSSHBeRunning() { if b.sshServer != nil { - go b.sshServer.Shutdown() + b.goTracker.Go(b.sshServer.Shutdown) b.sshServer = nil } } @@ -4287,8 +4319,14 @@ func (b *LocalBackend) authReconfig() { dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, b.logf, version.OS()) // If the current node is an app connector, ensure the app connector machine is started b.reconfigAppConnectorLocked(nm, prefs) + closing := b.shutdownCalled b.mu.Unlock() + if closing { + b.logf("[v1] authReconfig: skipping because in shutdown") + return + } + if blocked { b.logf("[v1] authReconfig: blocked, skipping.") return @@ -4753,7 +4791,7 @@ func (b *LocalBackend) initPeerAPIListener() { b.peerAPIListeners = append(b.peerAPIListeners, pln) } - go b.doSetHostinfoFilterServices() + b.goTracker.Go(b.doSetHostinfoFilterServices) } // magicDNSRootDomains returns the subset of nm.DNS.Domains that are the search domains for MagicDNS. @@ -5022,7 +5060,7 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock // can be shut down if we transition away from Running. if b.captiveCancel == nil { b.captiveCtx, b.captiveCancel = context.WithCancel(b.ctx) - go b.checkCaptivePortalLoop(b.captiveCtx) + b.goTracker.Go(func() { b.checkCaptivePortalLoop(b.captiveCtx) }) } } else if oldState == ipn.Running { // Transitioning away from running. 
@@ -5274,7 +5312,7 @@ func (b *LocalBackend) requestEngineStatusAndWait() { b.statusLock.Lock() defer b.statusLock.Unlock() - go b.e.RequestStatus() + b.goTracker.Go(b.e.RequestStatus) b.logf("requestEngineStatusAndWait: waiting...") b.statusChanged.Wait() // temporarily releases lock while waiting b.logf("requestEngineStatusAndWait: got status update.") @@ -5385,7 +5423,7 @@ func (b *LocalBackend) setWebClientAtomicBoolLocked(nm *netmap.NetworkMap) { shouldRun := !nm.HasCap(tailcfg.NodeAttrDisableWebClient) wasRunning := b.webClientAtomicBool.Swap(shouldRun) if wasRunning && !shouldRun { - go b.webClientShutdown() // stop web client + b.goTracker.Go(b.webClientShutdown) // stop web client } } @@ -5903,7 +5941,7 @@ func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn. if wire := b.wantIngressLocked(); b.hostinfo != nil && b.hostinfo.WireIngress != wire { b.logf("Hostinfo.WireIngress changed to %v", wire) b.hostinfo.WireIngress = wire - go b.doSetHostinfoFilterServices() + b.goTracker.Go(b.doSetHostinfoFilterServices) } b.setTCPPortsIntercepted(handlePorts) diff --git a/util/goroutines/goroutines.go b/util/goroutines/goroutines.go index 9758b0758..d40cbecb1 100644 --- a/util/goroutines/goroutines.go +++ b/util/goroutines/goroutines.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// The goroutines package contains utilities for getting active goroutines. +// The goroutines package contains utilities for tracking and getting active goroutines. package goroutines import ( diff --git a/util/goroutines/tracker.go b/util/goroutines/tracker.go new file mode 100644 index 000000000..044843d33 --- /dev/null +++ b/util/goroutines/tracker.go @@ -0,0 +1,66 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package goroutines + +import ( + "sync" + "sync/atomic" + + "tailscale.com/util/set" +) + +// Tracker tracks a set of goroutines. 
+type Tracker struct { + started atomic.Int64 // counter + running atomic.Int64 // gauge + + mu sync.Mutex + onDone set.HandleSet[func()] +} + +func (t *Tracker) Go(f func()) { + t.started.Add(1) + t.running.Add(1) + go t.goAndDecr(f) +} + +func (t *Tracker) goAndDecr(f func()) { + defer t.decr() + f() +} + +func (t *Tracker) decr() { + t.running.Add(-1) + + t.mu.Lock() + defer t.mu.Unlock() + for _, f := range t.onDone { + go f() + } +} + +// AddDoneCallback adds a callback to be called in a new goroutine +// whenever a goroutine managed by t (excluding ones from this method) +// finishes. It returns a function to remove the callback. +func (t *Tracker) AddDoneCallback(f func()) (remove func()) { + t.mu.Lock() + defer t.mu.Unlock() + if t.onDone == nil { + t.onDone = set.HandleSet[func()]{} + } + h := t.onDone.Add(f) + return func() { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.onDone, h) + } +} + +func (t *Tracker) RunningGoroutines() int64 { + return t.running.Load() +} + +func (t *Tracker) StartedGoroutines() int64 { + return t.started.Load() +}