ipn/ipnlocal, util/goroutines: track goroutines for tests, shutdown

Updates #14520
Updates #14517 (in that I pulled this out of there)

Change-Id: Ibc28162816e083fcadf550586c06805c76e378fc
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2025-01-03 14:24:32 -08:00 committed by Brad Fitzpatrick
parent b90707665e
commit 07aae18bca
3 changed files with 118 additions and 14 deletions

View File

@ -96,6 +96,7 @@ import (
"tailscale.com/types/views"
"tailscale.com/util/deephash"
"tailscale.com/util/dnsname"
"tailscale.com/util/goroutines"
"tailscale.com/util/httpm"
"tailscale.com/util/mak"
"tailscale.com/util/multierr"
@ -178,7 +179,7 @@ type watchSession struct {
// state machine generates events back out to zero or more components.
type LocalBackend struct {
// Elements that are thread-safe or constant after construction.
ctx context.Context // canceled by Close
ctx context.Context // canceled by [LocalBackend.Shutdown]
ctxCancel context.CancelFunc // cancels ctx
logf logger.Logf // general logging
keyLogf logger.Logf // for printing list of peers on change
@ -231,6 +232,10 @@ type LocalBackend struct {
shouldInterceptTCPPortAtomic syncs.AtomicValue[func(uint16) bool]
numClientStatusCalls atomic.Uint32
// goTracker accounts for all goroutines started by LocalBacked, primarily
// for testing and graceful shutdown purposes.
goTracker goroutines.Tracker
// The mutex protects the following elements.
mu sync.Mutex
conf *conffile.Config // latest parsed config, or nil if not in declarative mode
@ -866,7 +871,7 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) {
// TODO(raggi,tailscale/corp#22574): authReconfig should be refactored such that we can call the
// necessary operations here and avoid the need for asynchronous behavior that is racy and hard
// to test here, and do less extra work in these conditions.
go b.authReconfig()
b.goTracker.Go(b.authReconfig)
}
}
@ -879,7 +884,7 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) {
want := b.netMap.GetAddresses().Len()
if len(b.peerAPIListeners) < want {
b.logf("linkChange: peerAPIListeners too low; trying again")
go b.initPeerAPIListener()
b.goTracker.Go(b.initPeerAPIListener)
}
}
}
@ -1004,6 +1009,33 @@ func (b *LocalBackend) Shutdown() {
b.ctxCancel()
b.e.Close()
<-b.e.Done()
b.awaitNoGoroutinesInTest()
}
func (b *LocalBackend) awaitNoGoroutinesInTest() {
if !testenv.InTest() {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
defer cancel()
ch := make(chan bool, 1)
defer b.goTracker.AddDoneCallback(func() { ch <- true })()
for {
n := b.goTracker.RunningGoroutines()
if n == 0 {
return
}
select {
case <-ctx.Done():
// TODO(bradfitz): pass down some TB-like failer interface from
// tests, without depending on testing from here?
// But this is fine in tests too:
panic(fmt.Sprintf("timeout waiting for %d goroutines to stop", n))
case <-ch:
}
}
}
func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView {
@ -2152,7 +2184,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
if b.portpoll != nil {
b.portpollOnce.Do(func() {
go b.readPoller()
b.goTracker.Go(b.readPoller)
})
}
@ -2366,7 +2398,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P
b.e.SetJailedFilter(filter.NewShieldsUpFilter(localNets, logNets, oldJailedFilter, b.logf))
if b.sshServer != nil {
go b.sshServer.OnPolicyChange()
b.goTracker.Go(b.sshServer.OnPolicyChange)
}
}
@ -2843,7 +2875,7 @@ func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.A
// request every 2 seconds.
// TODO(bradfitz): plumb this further and only send a Notify on change.
if mask&ipn.NotifyWatchEngineUpdates != 0 {
go b.pollRequestEngineStatus(ctx)
b.goTracker.Go(func() { b.pollRequestEngineStatus(ctx) })
}
// TODO(marwan-at-work): streaming background logs?
@ -3850,7 +3882,7 @@ func (b *LocalBackend) editPrefsLockedOnEntry(mp *ipn.MaskedPrefs, unlock unlock
if mp.EggSet {
mp.EggSet = false
b.egg = true
go b.doSetHostinfoFilterServices()
b.goTracker.Go(b.doSetHostinfoFilterServices)
}
p0 := b.pm.CurrentPrefs()
p1 := b.pm.CurrentPrefs().AsStruct()
@ -3943,7 +3975,7 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce)
if oldp.ShouldSSHBeRunning() && !newp.ShouldSSHBeRunning() {
if b.sshServer != nil {
go b.sshServer.Shutdown()
b.goTracker.Go(b.sshServer.Shutdown)
b.sshServer = nil
}
}
@ -4285,8 +4317,14 @@ func (b *LocalBackend) authReconfig() {
dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, b.logf, version.OS())
// If the current node is an app connector, ensure the app connector machine is started
b.reconfigAppConnectorLocked(nm, prefs)
closing := b.shutdownCalled
b.mu.Unlock()
if closing {
b.logf("[v1] authReconfig: skipping because in shutdown")
return
}
if blocked {
b.logf("[v1] authReconfig: blocked, skipping.")
return
@ -4751,7 +4789,7 @@ func (b *LocalBackend) initPeerAPIListener() {
b.peerAPIListeners = append(b.peerAPIListeners, pln)
}
go b.doSetHostinfoFilterServices()
b.goTracker.Go(b.doSetHostinfoFilterServices)
}
// magicDNSRootDomains returns the subset of nm.DNS.Domains that are the search domains for MagicDNS.
@ -5020,7 +5058,7 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock
// can be shut down if we transition away from Running.
if b.captiveCancel == nil {
b.captiveCtx, b.captiveCancel = context.WithCancel(b.ctx)
go b.checkCaptivePortalLoop(b.captiveCtx)
b.goTracker.Go(func() { b.checkCaptivePortalLoop(b.captiveCtx) })
}
} else if oldState == ipn.Running {
// Transitioning away from running.
@ -5272,7 +5310,7 @@ func (b *LocalBackend) requestEngineStatusAndWait() {
b.statusLock.Lock()
defer b.statusLock.Unlock()
go b.e.RequestStatus()
b.goTracker.Go(b.e.RequestStatus)
b.logf("requestEngineStatusAndWait: waiting...")
b.statusChanged.Wait() // temporarily releases lock while waiting
b.logf("requestEngineStatusAndWait: got status update.")
@ -5383,7 +5421,7 @@ func (b *LocalBackend) setWebClientAtomicBoolLocked(nm *netmap.NetworkMap) {
shouldRun := !nm.HasCap(tailcfg.NodeAttrDisableWebClient)
wasRunning := b.webClientAtomicBool.Swap(shouldRun)
if wasRunning && !shouldRun {
go b.webClientShutdown() // stop web client
b.goTracker.Go(b.webClientShutdown) // stop web client
}
}
@ -5900,7 +5938,7 @@ func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn.
if wire := b.wantIngressLocked(); b.hostinfo != nil && b.hostinfo.WireIngress != wire {
b.logf("Hostinfo.WireIngress changed to %v", wire)
b.hostinfo.WireIngress = wire
go b.doSetHostinfoFilterServices()
b.goTracker.Go(b.doSetHostinfoFilterServices)
}
b.setTCPPortsIntercepted(handlePorts)

View File

@ -1,7 +1,7 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// The goroutines package contains utilities for getting active goroutines.
// The goroutines package contains utilities for tracking and getting active goroutines.
package goroutines
import (

View File

@ -0,0 +1,66 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package goroutines
import (
"sync"
"sync/atomic"
"tailscale.com/util/set"
)
// Tracker tracks a set of goroutines.
type Tracker struct {
started atomic.Int64 // counter
running atomic.Int64 // gauge
mu sync.Mutex
onDone set.HandleSet[func()]
}
func (t *Tracker) Go(f func()) {
t.started.Add(1)
t.running.Add(1)
go t.goAndDecr(f)
}
func (t *Tracker) goAndDecr(f func()) {
defer t.decr()
f()
}
func (t *Tracker) decr() {
t.running.Add(-1)
t.mu.Lock()
defer t.mu.Unlock()
for _, f := range t.onDone {
go f()
}
}
// AddDoneCallback adds a callback to be called in a new goroutine
// whenever a goroutine managed by t (excluding ones from this method)
// finishes. It returns a function to remove the callback.
func (t *Tracker) AddDoneCallback(f func()) (remove func()) {
t.mu.Lock()
defer t.mu.Unlock()
if t.onDone == nil {
t.onDone = set.HandleSet[func()]{}
}
h := t.onDone.Add(f)
return func() {
t.mu.Lock()
defer t.mu.Unlock()
delete(t.onDone, h)
}
}
func (t *Tracker) RunningGoroutines() int64 {
return t.running.Load()
}
func (t *Tracker) StartedGoroutines() int64 {
return t.started.Load()
}