ipn/ipnlocal: use atomic instead of mutex for captive context

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: If1d593e2f3c300790a69e7a74bf6b4b2f4bfe5a4
Author: Andrew Dunham <andrew@du.nham.ca>
Date:   2024-07-23 16:43:29 -04:00
Commit: ee14ff1e32
Parent: d1ba34d7cd


@@ -346,15 +346,11 @@ type LocalBackend struct {
 	// refreshAutoExitNode indicates if the exit node should be recomputed when the next netcheck report is available.
 	refreshAutoExitNode bool
-	// captiveCtx and captiveCancel are used to control captive portal
-	// detection. They are protected by 'mu' and can be changed during the
-	// lifetime of a LocalBackend.
-	//
-	// captiveCtx will always be non-nil, though it might a canceled
-	// context. captiveCancel is non-nil if checkCaptivePortalLoop is
-	// running, and is set to nil after being canceled.
-	captiveCtx    context.Context
-	captiveCancel context.CancelFunc
+	// captiveCtx contains a [context.Context] used to allow cancelation
+	// when sending to the needsCaptiveDetection channel, along with the
+	// [context.CancelFunc] for that context. It can be changed during the
+	// lifetime of the backend, and will always be non-nil.
+	captiveCtx syncs.AtomicValue[contextAndCancel]
 	// needsCaptiveDetection is a channel that is used to signal either
 	// that captive portal detection is required (sending true) or that the
 	// backend is healthy and captive portal detection is not required
 	// (sending false).
@@ -362,6 +358,11 @@ type LocalBackend struct {
 	needsCaptiveDetection chan bool
 }
 
+type contextAndCancel struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+}
+
 // HealthTracker returns the health tracker for the backend.
 func (b *LocalBackend) HealthTracker() *health.Tracker {
 	return b.health
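
The new field pairs a [context.Context] with its [context.CancelFunc] in a single value so both can be read and replaced atomically. syncs.AtomicValue (from tailscale.com/syncs) is a generic, typed wrapper with Load/Store/Swap semantics; below is a minimal sketch of the same pattern using only the standard library's sync/atomic. It is illustrative, not the actual syncs implementation:

    package main

    import (
        "context"
        "fmt"
        "sync/atomic"
    )

    type contextAndCancel struct {
        ctx    context.Context
        cancel context.CancelFunc
    }

    func main() {
        var cc atomic.Pointer[contextAndCancel]

        // Seed with an already-canceled context, as NewLocalBackend does,
        // so every Load observes a usable, non-nil pair.
        ctx, cancel := context.WithCancel(context.Background())
        cancel()
        cc.Store(&contextAndCancel{ctx, cancel})

        // Atomically install a fresh pair and cancel the retired one,
        // mirroring the Swap in enterStateLockedOnEntry below.
        newCtx, newCancel := context.WithCancel(context.Background())
        defer newCancel()
        old := cc.Swap(&contextAndCancel{newCtx, newCancel})
        old.cancel()

        fmt.Println(cc.Load().ctx.Err()) // <nil>: the new context is live
    }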
@@ -415,7 +416,8 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 	clock := tstime.StdClock{}
 
 	// Until we transition to a Running state, use a canceled context for
-	// our captive portal detection.
+	// the context that we use when sending to the needsCaptiveDetection
+	// channel.
 	captiveCtx, captiveCancel := context.WithCancel(ctx)
 	captiveCancel()
 
@@ -440,12 +442,15 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 		clock:                 clock,
 		selfUpdateProgress:    make([]ipnstate.UpdateProgress, 0),
 		lastSelfUpdateState:   ipnstate.UpdateFinished,
-		captiveCtx:            captiveCtx,
-		captiveCancel:         nil, // so that we start checkCaptivePortalLoop when Running
 		needsCaptiveDetection: make(chan bool),
 	}
+	b.captiveCtx.Store(contextAndCancel{captiveCtx, captiveCancel})
 	mConn.SetNetInfoCallback(b.setNetInfo)
 
+	// Start our captive portal detection loop; this does nothing until
+	// triggered by needsCaptiveDetection.
+	go b.checkCaptivePortalLoop(ctx)
+
 	if sys.InitialConfig != nil {
 		if err := b.setConfigLocked(sys.InitialConfig); err != nil {
 			return nil, err
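
Note the lifecycle change in the constructor: checkCaptivePortalLoop is now started once, unconditionally, for the lifetime of the backend, rather than being spawned in enterStateLockedOnEntry each time the backend enters the Running state. Because the stored context starts out canceled, the goroutine idles until something sends on needsCaptiveDetection.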
@@ -756,28 +761,27 @@ func (b *LocalBackend) onHealthChange(w *health.Warnable, us *health.UnhealthySt
 		}
 	}
 
-	// captiveCtx can be changed, and is protected with 'mu'; grab that
-	// before we start our select, below.
-	//
-	// It is guaranteed to be non-nil.
-	b.mu.Lock()
-	ctx := b.captiveCtx
-	b.mu.Unlock()
+	// captiveCtx can be changed; grab that before we start our select,
+	// below.
+	ctx := b.captiveCtx.Load().ctx
 
-	if isConnectivityImpacted {
-		b.logf("health: connectivity impacted; triggering captive portal detection")
+	sendCaptiveDetectionNeeded := func(val bool) {
+		if ctx.Err() != nil {
+			return
+		}
 
 		// Ensure that we select on captiveCtx so that we can time out
 		// triggering captive portal detection if the backend is shutdown.
 		select {
-		case b.needsCaptiveDetection <- true:
+		case b.needsCaptiveDetection <- val:
 		case <-ctx.Done():
 		}
+	}
+
+	if isConnectivityImpacted {
+		b.logf("health: connectivity impacted; triggering captive portal detection")
+		sendCaptiveDetectionNeeded(true)
 	} else {
-		select {
-		case b.needsCaptiveDetection <- false:
-		case <-ctx.Done():
-		}
+		sendCaptiveDetectionNeeded(false)
 	}
 }
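
Extracting sendCaptiveDetectionNeeded removes the duplicated select and adds a ctx.Err() fast path that skips the send entirely once the context is canceled. The shape of that bounded send, as a self-contained sketch (trySend is a hypothetical name, not part of the patch):

    package main

    import (
        "context"
        "fmt"
    )

    // trySend reports whether val was delivered on ch; it gives up
    // rather than blocking forever once ctx is canceled.
    func trySend(ctx context.Context, ch chan<- bool, val bool) bool {
        if ctx.Err() != nil {
            return false // fast path: already canceled
        }
        select {
        case ch <- val: // a receiver took the value
            return true
        case <-ctx.Done(): // shutdown unblocked us
            return false
        }
    }

    func main() {
        ch := make(chan bool) // unbuffered, like needsCaptiveDetection
        ctx, cancel := context.WithCancel(context.Background())
        cancel() // nobody is receiving; cancellation keeps us from hanging
        fmt.Println(trySend(ctx, ch, true)) // false
    }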
@@ -791,10 +795,8 @@ func (b *LocalBackend) Shutdown() {
 	}
 	b.shutdownCalled = true
 
-	if b.captiveCancel != nil {
-		b.logf("canceling captive portal context")
-		b.captiveCancel()
-	}
+	b.logf("canceling captive portal context")
+	b.captiveCtx.Load().cancel()
 
 	if b.loginFlags&controlclient.LoginEphemeral != 0 {
 		b.mu.Unlock()
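
Shutdown can drop its nil check because the invariant changed: the atomic value is seeded in NewLocalBackend and only ever replaced wholesale, so Load().cancel() is always safe, and canceling an already-canceled context is a harmless no-op.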
@@ -2174,16 +2176,33 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P
 func (b *LocalBackend) checkCaptivePortalLoop(ctx context.Context) {
 	var tmr *time.Timer
 
+	maybeStartTimer := func() {
+		// If there's an existing timer, nothing to do; just continue
+		// waiting for it to expire. Otherwise, create a new timer.
+		if tmr == nil {
+			tmr = time.NewTimer(captivePortalDetectionInterval)
+		}
+	}
+
+	maybeStopTimer := func() {
+		if tmr == nil {
+			return
+		}
+		if !tmr.Stop() {
+			<-tmr.C
+		}
+		tmr = nil
+	}
+
 	for {
-		// First, see if we have a signal on our "healthy" channel, which
-		// takes priority over an existing timer.
+		// First, see if we have a signal on our start/stop channel,
+		// which takes priority over an existing timer.
 		select {
 		case needsCaptiveDetection := <-b.needsCaptiveDetection:
-			if !needsCaptiveDetection && tmr != nil {
-				if !tmr.Stop() {
-					<-tmr.C
-				}
-				tmr = nil
+			if needsCaptiveDetection {
+				maybeStartTimer()
+			} else {
+				maybeStopTimer()
 			}
 		default:
 		}
@@ -2195,9 +2214,7 @@ func (b *LocalBackend) checkCaptivePortalLoop(ctx context.Context) {
 		select {
 		case <-ctx.Done():
 			// All done; stop the timer and then exit.
-			if tmr != nil && !tmr.Stop() {
-				<-tmr.C
-			}
+			maybeStopTimer()
 			return
 		case <-timerChan:
 			// Kick off captive portal check
@@ -2206,18 +2223,10 @@ func (b *LocalBackend) checkCaptivePortalLoop(ctx context.Context) {
 			tmr = nil
 		case needsCaptiveDetection := <-b.needsCaptiveDetection:
 			if needsCaptiveDetection {
-				// If there's an existing timer, nothing to do; just
-				// continue waiting for it to expire. Otherwise, create
-				// a new timer.
-				if tmr == nil {
-					tmr = time.NewTimer(captivePortalDetectionInterval)
-				}
+				maybeStartTimer()
 			} else {
 				// Healthy; cancel any existing timer
-				if tmr != nil && !tmr.Stop() {
-					<-tmr.C
-				}
-				tmr = nil
+				maybeStopTimer()
 			}
 		}
 	}
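
maybeStopTimer encodes the classic stop-and-drain idiom: time.Timer.Stop reports false when the timer already fired, in which case (under the pre-Go-1.23 timer semantics this code targets) a tick is left buffered in tmr.C and must be received so a later select doesn't mistake it for a fresh expiry. A runnable sketch; the non-blocking drain here is a concession to Go 1.23's synchronous timer channels, whereas the patch's unconditional <-tmr.C is the pre-1.23 form:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        tmr := time.NewTimer(10 * time.Millisecond)
        time.Sleep(20 * time.Millisecond) // let the timer fire unobserved

        if !tmr.Stop() {
            // The timer already fired. Before Go 1.23 the tick sits
            // buffered in tmr.C and must be drained; since Go 1.23 the
            // channel is synchronous and nothing is pending, so drain
            // non-blockingly to stay correct on both.
            select {
            case <-tmr.C:
            default:
            }
        }
        fmt.Println("timer stopped; no stale tick remains")
    }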
@@ -4656,26 +4665,18 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock
 		b.authURL = ""
 		b.authURLTime = time.Time{}
 
-		// Start a captive portal detection loop if none has been
-		// started. Create a new context if none is present, since it
-		// can be shut down if we transition away from Running.
-		if b.captiveCancel == nil {
-			b.captiveCtx, b.captiveCancel = context.WithCancel(b.ctx)
-			go b.checkCaptivePortalLoop(b.captiveCtx)
-		}
+		// Create a new context for sending to the captive portal
+		// detection loop, and cancel the old one.
+		captiveCtx, captiveCancel := context.WithCancel(b.ctx)
+		oldCC := b.captiveCtx.Swap(contextAndCancel{captiveCtx, captiveCancel})
+		oldCC.cancel()
 	} else if oldState == ipn.Running {
 		// Transitioning away from running.
 		b.closePeerAPIListenersLocked()
 
-		// Stop any existing captive portal detection loop.
-		if b.captiveCancel != nil {
-			b.captiveCancel()
-			b.captiveCancel = nil
-
-			// NOTE: don't set captiveCtx to nil here, to ensure
-			// that we always have a (canceled) context to wait on
-			// in onHealthChange.
-		}
+		// Cancel our captiveCtx to unblock anything trying to send to
+		// the captive portal detection loop.
+		b.captiveCtx.Load().cancel()
 	}
 
 	b.pauseOrResumeControlClientLocked()
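
The Swap on entering Running is what lets the mutex go: it installs the fresh pair and returns the old one in a single atomic step, so exactly one caller cancels each retired context, and onHealthChange can Load a coherent ctx/cancel pair without holding b.mu. Transitioning away from Running now merely cancels the current context instead of tearing down the loop; the stored value stays non-nil, preserving the invariant documented on the struct field.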