From ee14ff1e328d5c56fd57e4dcde13a33eb314e81c Mon Sep 17 00:00:00 2001
From: Andrew Dunham
Date: Tue, 23 Jul 2024 16:43:29 -0400
Subject: [PATCH] ipn/ipnlocal: use atomic instead of mutex for captive context

Signed-off-by: Andrew Dunham
Change-Id: If1d593e2f3c300790a69e7a74bf6b4b2f4bfe5a4
---
 ipn/ipnlocal/local.go | 135 +++++++++++++++++++++---------------------
 1 file changed, 68 insertions(+), 67 deletions(-)

diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go
index 48cb59c2f..0a8bace40 100644
--- a/ipn/ipnlocal/local.go
+++ b/ipn/ipnlocal/local.go
@@ -346,15 +346,11 @@ type LocalBackend struct {
 	// refreshAutoExitNode indicates if the exit node should be recomputed when the next netcheck report is available.
 	refreshAutoExitNode bool
 
-	// captiveCtx and captiveCancel are used to control captive portal
-	// detection. They are protected by 'mu' and can be changed during the
-	// lifetime of a LocalBackend.
-	//
-	// captiveCtx will always be non-nil, though it might be a canceled
-	// context. captiveCancel is non-nil if checkCaptivePortalLoop is
-	// running, and is set to nil after being canceled.
-	captiveCtx    context.Context
-	captiveCancel context.CancelFunc
+	// captiveCtx contains a [context.Context] used to allow cancelation
+	// when sending to the needsCaptiveDetection channel, along with the
+	// [context.CancelFunc] for that context. It can be changed during the
+	// lifetime of the backend, and will always be non-nil.
+	captiveCtx syncs.AtomicValue[contextAndCancel]
 	// needsCaptiveDetection is a channel that is used to signal either
 	// that captive portal detection is required (sending true) or that the
 	// backend is healthy and captive portal detection is not required
@@ -362,6 +358,11 @@ type LocalBackend struct {
 	needsCaptiveDetection chan bool
 }
 
+type contextAndCancel struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+}
+
 // HealthTracker returns the health tracker for the backend.
 func (b *LocalBackend) HealthTracker() *health.Tracker {
 	return b.health
@@ -415,7 +416,8 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 	clock := tstime.StdClock{}
 
 	// Until we transition to a Running state, use a canceled context for
-	// our captive portal detection.
+	// the context used when sending to the needsCaptiveDetection
+	// channel.
 	captiveCtx, captiveCancel := context.WithCancel(ctx)
 	captiveCancel()
 
@@ -440,12 +442,15 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 		clock:                 clock,
 		selfUpdateProgress:    make([]ipnstate.UpdateProgress, 0),
 		lastSelfUpdateState:   ipnstate.UpdateFinished,
-		captiveCtx:            captiveCtx,
-		captiveCancel:         nil, // so that we start checkCaptivePortalLoop when Running
 		needsCaptiveDetection: make(chan bool),
 	}
+	b.captiveCtx.Store(contextAndCancel{captiveCtx, captiveCancel})
 	mConn.SetNetInfoCallback(b.setNetInfo)
 
+	// Start our captive portal detection loop; this does nothing until
+	// triggered by needsCaptiveDetection.
+	go b.checkCaptivePortalLoop(ctx)
+
 	if sys.InitialConfig != nil {
 		if err := b.setConfigLocked(sys.InitialConfig); err != nil {
 			return nil, err
@@ -756,28 +761,27 @@ func (b *LocalBackend) onHealthChange(w *health.Warnable, us *health.UnhealthySt
 		}
 	}
 
-	// captiveCtx can be changed, and is protected with 'mu'; grab that
-	// before we start our select, below.
-	//
-	// It is guaranteed to be non-nil.
-	b.mu.Lock()
-	ctx := b.captiveCtx
-	b.mu.Unlock()
-
-	if isConnectivityImpacted {
-		b.logf("health: connectivity impacted; triggering captive portal detection")
+	// captiveCtx can be changed; grab that before we start our select,
+	// below.
+	ctx := b.captiveCtx.Load().ctx
 
+	sendCaptiveDetectionNeeded := func(val bool) {
+		if ctx.Err() != nil {
+			return
+		}
 		// Ensure that we select on captiveCtx so that we can time out
 		// triggering captive portal detection if the backend is shutdown.
 		select {
-		case b.needsCaptiveDetection <- true:
+		case b.needsCaptiveDetection <- val:
 		case <-ctx.Done():
 		}
+	}
+
+	if isConnectivityImpacted {
+		b.logf("health: connectivity impacted; triggering captive portal detection")
+		sendCaptiveDetectionNeeded(true)
 	} else {
-		select {
-		case b.needsCaptiveDetection <- false:
-		case <-ctx.Done():
-		}
+		sendCaptiveDetectionNeeded(false)
 	}
 }
 
@@ -791,10 +795,8 @@ func (b *LocalBackend) Shutdown() {
 	}
 	b.shutdownCalled = true
 
-	if b.captiveCancel != nil {
-		b.logf("canceling captive portal context")
-		b.captiveCancel()
-	}
+	b.logf("canceling captive portal context")
+	b.captiveCtx.Load().cancel()
 
 	if b.loginFlags&controlclient.LoginEphemeral != 0 {
 		b.mu.Unlock()
@@ -2174,16 +2176,33 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P
 
 func (b *LocalBackend) checkCaptivePortalLoop(ctx context.Context) {
 	var tmr *time.Timer
+
+	maybeStartTimer := func() {
+		// If there's an existing timer, nothing to do; just continue
+		// waiting for it to expire. Otherwise, create a new timer.
+		if tmr == nil {
+			tmr = time.NewTimer(captivePortalDetectionInterval)
+		}
+	}
+	maybeStopTimer := func() {
+		if tmr == nil {
+			return
+		}
+		if !tmr.Stop() {
+			<-tmr.C
+		}
+		tmr = nil
+	}
+
 	for {
-		// First, see if we have a signal on our "healthy" channel, which
-		// takes priority over an existing timer.
+		// First, see if we have a signal on our start/stop channel,
+		// which takes priority over an existing timer.
 		select {
 		case needsCaptiveDetection := <-b.needsCaptiveDetection:
-			if !needsCaptiveDetection && tmr != nil {
-				if !tmr.Stop() {
-					<-tmr.C
-				}
-				tmr = nil
+			if needsCaptiveDetection {
+				maybeStartTimer()
+			} else {
+				maybeStopTimer()
 			}
 		default:
 		}
@@ -2195,9 +2214,7 @@ func (b *LocalBackend) checkCaptivePortalLoop(ctx context.Context) {
 		select {
 		case <-ctx.Done():
 			// All done; stop the timer and then exit.
-			if tmr != nil && !tmr.Stop() {
-				<-tmr.C
-			}
+			maybeStopTimer()
 			return
 		case <-timerChan:
 			// Kick off captive portal check
@@ -2206,18 +2223,10 @@ func (b *LocalBackend) checkCaptivePortalLoop(ctx context.Context) {
 			tmr = nil
 		case needsCaptiveDetection := <-b.needsCaptiveDetection:
 			if needsCaptiveDetection {
-				// If there's an existing timer, nothing to do; just
-				// continue waiting for it to expire. Otherwise, create
-				// a new timer.
-				if tmr == nil {
-					tmr = time.NewTimer(captivePortalDetectionInterval)
-				}
+				maybeStartTimer()
 			} else {
 				// Healthy; cancel any existing timer
-				if tmr != nil && !tmr.Stop() {
-					<-tmr.C
-				}
-				tmr = nil
+				maybeStopTimer()
 			}
 		}
 	}
@@ -4656,26 +4665,18 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock
 		b.authURL = ""
 		b.authURLTime = time.Time{}
 
-		// Start a captive portal detection loop if none has been
-		// started. Create a new context if none is present, since it
-		// can be shut down if we transition away from Running.
-		if b.captiveCancel == nil {
-			b.captiveCtx, b.captiveCancel = context.WithCancel(b.ctx)
-			go b.checkCaptivePortalLoop(b.captiveCtx)
-		}
+		// Create a new context for sending to the captive portal
+		// detection loop, and cancel the old one.
+		captiveCtx, captiveCancel := context.WithCancel(b.ctx)
+		oldCC := b.captiveCtx.Swap(contextAndCancel{captiveCtx, captiveCancel})
+		oldCC.cancel()
 	} else if oldState == ipn.Running {
 		// Transitioning away from running.
 		b.closePeerAPIListenersLocked()
 
-		// Stop any existing captive portal detection loop.
-		if b.captiveCancel != nil {
-			b.captiveCancel()
-			b.captiveCancel = nil
-
-			// NOTE: don't set captiveCtx to nil here, to ensure
-			// that we always have a (canceled) context to wait on
-			// in onHealthChange.
-		}
+		// Cancel our captiveCtx to unblock anything trying to send to
+		// the captive portal detection loop.
+		b.captiveCtx.Load().cancel()
 	}
 
 	b.pauseOrResumeControlClientLocked()
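
For reviewers reading the diff in isolation, the core of the change is a swap-and-cancel pattern: senders snapshot a (context, cancel) pair from an atomic cell instead of taking b.mu, and installing a fresh pair cancels the previous one so that no sender can stay blocked on needsCaptiveDetection. The sketch below is illustrative only; it uses the standard library's atomic.Pointer as a stand-in for syncs.AtomicValue[contextAndCancel], and the captiveState, reset, loaded, and done names are invented for the example rather than taken from the patch.

package main

import (
	"context"
	"fmt"
	"sync/atomic"
)

// contextAndCancel mirrors the pair the patch stores: a context used only
// to abort sends, plus the CancelFunc that cancels it.
type contextAndCancel struct {
	ctx    context.Context
	cancel context.CancelFunc
}

// captiveState stands in for the LocalBackend field; atomic.Pointer is
// used here in place of syncs.AtomicValue[contextAndCancel].
type captiveState struct {
	cur atomic.Pointer[contextAndCancel]
}

// reset installs a fresh context and cancels the previous one, which
// unblocks any sender still selecting on the old context's Done channel.
func (s *captiveState) reset(parent context.Context) {
	ctx, cancel := context.WithCancel(parent)
	if old := s.cur.Swap(&contextAndCancel{ctx, cancel}); old != nil {
		old.cancel()
	}
}

func main() {
	var s captiveState
	s.reset(context.Background())

	needsDetection := make(chan bool) // unbuffered and unread, like a loop that never receives
	loaded := make(chan struct{})
	done := make(chan struct{})

	// A sender snapshots the current pair without holding a mutex and
	// selects on its Done channel so it cannot block forever.
	go func() {
		cc := s.cur.Load()
		close(loaded)
		select {
		case needsDetection <- true:
		case <-cc.ctx.Done():
			fmt.Println("send abandoned: context canceled")
		}
		close(done)
	}()

	<-loaded
	s.reset(context.Background()) // cancels the snapshotted context, releasing the sender
	<-done
}

This is also why Shutdown and the transition away from Running only need Load().cancel(): it is cancelation, not replacement, that unblocks a pending send, and a fresh pair is swapped in only when the backend enters the Running state again.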