ipn/ipnlocal: fix captive portal loop shutdown
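
Replace the captiveStopGoroutine stop channel with a cancelable context
(captiveCtx/captiveCancel) stored on LocalBackend and passed to
checkCaptivePortalLoop. The context is created when the backend enters
the Running state and canceled when it leaves that state or is shut
down, so the loop can be stopped and restarted across state
transitions. onHealthChange now selects on that context when signaling
needsCaptiveDetection, so it cannot block forever once the loop has
exited. Also remove leftover debugging println calls and replace the
hard-coded 2-second timer with captivePortalDetectionInterval.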

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I7cafdbce68463a16260091bcec1741501a070c95
Author: Andrew Dunham <andrew@du.nham.ca>
Date:   2024-07-23 13:58:21 -04:00
Commit: d1ba34d7cd
Parent: 5e77172feb


@@ -346,8 +346,20 @@ type LocalBackend struct {
 	// refreshAutoExitNode indicates if the exit node should be recomputed when the next netcheck report is available.
 	refreshAutoExitNode bool
 
+	// captiveCtx and captiveCancel are used to control captive portal
+	// detection. They are protected by 'mu' and can be changed during the
+	// lifetime of a LocalBackend.
+	//
+	// captiveCtx will always be non-nil, though it might be a canceled
+	// context. captiveCancel is non-nil if checkCaptivePortalLoop is
+	// running, and is set to nil after being canceled.
+	captiveCtx    context.Context
+	captiveCancel context.CancelFunc
+
+	// needsCaptiveDetection is a channel that is used to signal either
+	// that captive portal detection is required (sending true) or that the
+	// backend is healthy and captive portal detection is not required
+	// (sending false).
 	needsCaptiveDetection chan bool
-	captiveStopGoroutine  chan struct{}
 }
 
 // HealthTracker returns the health tracker for the backend.
@@ -402,6 +414,11 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 	ctx, cancel := context.WithCancel(context.Background())
 	clock := tstime.StdClock{}
 
+	// Until we transition to a Running state, use a canceled context for
+	// our captive portal detection.
+	captiveCtx, captiveCancel := context.WithCancel(ctx)
+	captiveCancel()
+
 	b := &LocalBackend{
 		ctx:       ctx,
 		ctxCancel: cancel,
@@ -423,8 +440,9 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
 		clock:                 clock,
 		selfUpdateProgress:    make([]ipnstate.UpdateProgress, 0),
 		lastSelfUpdateState:   ipnstate.UpdateFinished,
+		captiveCtx:            captiveCtx,
+		captiveCancel:         nil, // so that we start checkCaptivePortalLoop when Running
 		needsCaptiveDetection: make(chan bool),
-		captiveStopGoroutine:  make(chan struct{}),
 	}
 
 	mConn.SetNetInfoCallback(b.setNetInfo)
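
Aside (illustration only, not part of the commit): creating a context and
canceling it on the spot, as NewLocalBackend does above, yields a non-nil
context whose Done channel is already closed, so any receive on it completes
immediately. A minimal standalone sketch of the idiom:

    package main

    import (
        "context"
        "fmt"
    )

    func main() {
        // Cancel at construction time: ctx is non-nil, but its Done
        // channel is already closed, so waiting on it never blocks.
        ctx, cancel := context.WithCancel(context.Background())
        cancel()

        <-ctx.Done()
        fmt.Println("already canceled:", ctx.Err()) // context.Canceled
    }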
@@ -738,11 +756,28 @@ func (b *LocalBackend) onHealthChange(w *health.Warnable, us *health.UnhealthySt
 		}
 	}
 
+	// captiveCtx can be changed, and is protected with 'mu'; grab that
+	// before we start our select, below.
+	//
+	// It is guaranteed to be non-nil.
+	b.mu.Lock()
+	ctx := b.captiveCtx
+	b.mu.Unlock()
+
 	if isConnectivityImpacted {
 		b.logf("health: connectivity impacted; triggering captive portal detection")
-		b.needsCaptiveDetection <- true
+
+		// Select on captiveCtx as well, so that we don't block here
+		// trying to trigger captive portal detection if the backend
+		// is shut down.
+		select {
+		case b.needsCaptiveDetection <- true:
+		case <-ctx.Done():
+		}
 	} else {
-		b.needsCaptiveDetection <- false
+		select {
+		case b.needsCaptiveDetection <- false:
+		case <-ctx.Done():
+		}
 	}
 }
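
Aside (illustration only; trySignal is a hypothetical helper, not part of
this change): the select above is the usual pattern for a channel send that
must not outlive its receiver — if checkCaptivePortalLoop has exited and its
context is canceled, the send falls through instead of blocking the health
callback forever. A sketch of the pattern in isolation:

    package main

    import (
        "context"
        "fmt"
        "time"
    )

    // trySignal sends v on ch unless ctx is canceled first, and reports
    // whether the send happened. (Hypothetical helper, for illustration.)
    func trySignal(ctx context.Context, ch chan<- bool, v bool) bool {
        select {
        case ch <- v:
            return true
        case <-ctx.Done():
            return false
        }
    }

    func main() {
        ctx, cancel := context.WithCancel(context.Background())
        ch := make(chan bool) // unbuffered, like needsCaptiveDetection

        // No receiver is running; cancel so the send gives up instead
        // of blocking forever.
        go func() {
            time.Sleep(100 * time.Millisecond)
            cancel()
        }()
        fmt.Println("sent:", trySignal(ctx, ch, true)) // prints "sent: false"
    }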
@@ -756,6 +791,11 @@ func (b *LocalBackend) Shutdown() {
 	}
 	b.shutdownCalled = true
 
+	if b.captiveCancel != nil {
+		b.logf("canceling captive portal context")
+		b.captiveCancel()
+	}
+
 	if b.loginFlags&controlclient.LoginEphemeral != 0 {
 		b.mu.Unlock()
 		ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second)
@@ -2132,7 +2172,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P
 	ImpactsConnectivity: true,
 })
 
-func (b *LocalBackend) checkCaptivePortalLoop() {
+func (b *LocalBackend) checkCaptivePortalLoop(ctx context.Context) {
 	var tmr *time.Timer
 
 	for {
@@ -2140,7 +2180,6 @@ func (b *LocalBackend) checkCaptivePortalLoop() {
 		// First, see if we have a signal on our "healthy" channel, which
 		select {
 		case needsCaptiveDetection := <-b.needsCaptiveDetection:
 			if !needsCaptiveDetection && tmr != nil {
-				println("checkCaptivePortalLoop: canceling existing timer (early)")
 				if !tmr.Stop() {
 					<-tmr.C
 				}
@@ -2154,16 +2193,14 @@
 			timerChan = tmr.C
 		}
 		select {
-		case <-b.captiveStopGoroutine:
+		case <-ctx.Done():
 			// All done; stop the timer and then exit.
 			if tmr != nil && !tmr.Stop() {
 				<-tmr.C
 			}
-			println("checkCaptivePortalLoop: shutting down")
 			return
 		case <-timerChan:
 			// Kick off captive portal check
-			println("checkCaptivePortalLoop: will do captive portal check")
 			b.performCaptiveDetection()
 			// nil out timer to force recreation
 			tmr = nil
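
Aside (illustration only): the Stop-then-drain sequence above is the standard
way to retire a time.Timer whose channel may already hold a tick; draining is
safe here because the loop nils out tmr after every receive, so no other
receive can race for the same value. The idiom in isolation:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        tmr := time.NewTimer(10 * time.Millisecond)

        // Let the timer fire, so that Stop reports false and the tick
        // is left sitting in tmr.C.
        time.Sleep(50 * time.Millisecond)

        // Stop returns false if the timer already fired; drain tmr.C so
        // a stale tick cannot be mistaken for a fresh one later.
        if !tmr.Stop() {
            <-tmr.C
        }
        fmt.Println("timer stopped and drained")
    }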
@@ -2173,17 +2210,13 @@
 			// continue waiting for it to expire. Otherwise, create
 			// a new timer.
 			if tmr == nil {
-				tmr = time.NewTimer(2 * time.Second)
-				println("checkCaptivePortalLoop: started new timer")
+				tmr = time.NewTimer(captivePortalDetectionInterval)
 			}
 		} else {
 			// Healthy; cancel any existing timer
 			if tmr != nil && !tmr.Stop() {
 				<-tmr.C
 			}
-			if tmr != nil {
-				println("checkCaptivePortalLoop: canceling existing timer")
-			}
 			tmr = nil
 		}
 	}
@@ -4622,11 +4655,27 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock
 	if newState == ipn.Running {
 		b.authURL = ""
 		b.authURLTime = time.Time{}
-		go b.checkCaptivePortalLoop()
+
+		// Start a captive portal detection loop if none has been
+		// started. Create a new context if none is present, since it
+		// can be shut down if we transition away from Running.
+		if b.captiveCancel == nil {
+			b.captiveCtx, b.captiveCancel = context.WithCancel(b.ctx)
+			go b.checkCaptivePortalLoop(b.captiveCtx)
+		}
 	} else if oldState == ipn.Running {
 		// Transitioning away from running.
 		b.closePeerAPIListenersLocked()
-		close(b.captiveStopGoroutine)
+
+		// Stop any existing captive portal detection loop.
+		if b.captiveCancel != nil {
+			b.captiveCancel()
+			b.captiveCancel = nil
+
+			// NOTE: don't set captiveCtx to nil here, to ensure
+			// that we always have a (canceled) context to wait on
+			// in onHealthChange.
+		}
 	}
 
 	b.pauseOrResumeControlClientLocked()
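
Aside (illustration only; loopController is a hypothetical type, not part of
the change): the start/stop discipline introduced above, reduced to its
essentials — a nil cancel func is the "not running" sentinel, and the
canceled context is deliberately kept non-nil so concurrent readers always
have something to wait on:

    package main

    import (
        "context"
        "fmt"
        "sync"
        "time"
    )

    // loopController mirrors the captiveCtx/captiveCancel discipline:
    // cancel != nil means the loop is running; ctx is never nil, though
    // it may already be canceled. (Hypothetical type, for illustration.)
    type loopController struct {
        mu     sync.Mutex
        ctx    context.Context
        cancel context.CancelFunc
    }

    func newLoopController() *loopController {
        ctx, cancel := context.WithCancel(context.Background())
        cancel() // start with a canceled, but non-nil, context
        return &loopController{ctx: ctx}
    }

    func (c *loopController) start(loop func(context.Context)) {
        c.mu.Lock()
        defer c.mu.Unlock()
        if c.cancel != nil {
            return // already running
        }
        c.ctx, c.cancel = context.WithCancel(context.Background())
        go loop(c.ctx)
    }

    func (c *loopController) stop() {
        c.mu.Lock()
        defer c.mu.Unlock()
        if c.cancel != nil {
            c.cancel()
            c.cancel = nil
            // c.ctx is left as the canceled context on purpose, so
            // readers never see a nil context.
        }
    }

    func main() {
        c := newLoopController()
        c.start(func(ctx context.Context) {
            <-ctx.Done()
            fmt.Println("loop: shutting down")
        })
        c.stop()
        time.Sleep(50 * time.Millisecond)
    }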