diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index b0f705c31..fd2a13212 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -25,44 +25,28 @@ ) type LoginGoal struct { - _ structs.Incomparable - wantLoggedIn bool // true if we *want* to be logged in - token *tailcfg.Oauth2Token // oauth token to use when logging in - flags LoginFlags // flags to use when logging in - url string // auth url that needs to be visited - loggedOutResult chan<- error -} - -func (g *LoginGoal) sendLogoutError(err error) { - if g.loggedOutResult == nil { - return - } - select { - case g.loggedOutResult <- err: - default: - } + _ structs.Incomparable + token *tailcfg.Oauth2Token // oauth token to use when logging in + flags LoginFlags // flags to use when logging in + url string // auth url that needs to be visited } var _ Client = (*Auto)(nil) -// waitUnpause waits until the client is unpaused then returns. It only -// returns an error if the client is closed. -func (c *Auto) waitUnpause(routineLogName string) error { +// waitUnpause waits until either the client is unpaused or the Auto client is +// shut down. It reports whether the client should keep running (i.e. it's not +// closed). +func (c *Auto) waitUnpause(routineLogName string) (keepRunning bool) { c.mu.Lock() if !c.paused { - c.mu.Unlock() - return nil + defer c.mu.Unlock() + return !c.closed } unpaused := c.unpausedChanLocked() c.mu.Unlock() + c.logf("%s: awaiting unpause", routineLogName) - select { - case <-unpaused: - c.logf("%s: unpaused", routineLogName) - return nil - case <-c.quit: - return errors.New("quit") - } + return <-unpaused } // updateRoutine is responsible for informing the server of worthy changes to @@ -76,7 +60,7 @@ func (c *Auto) updateRoutine() { var lastUpdateGenInformed updateGen for { - if err := c.waitUnpause("updateRoutine"); err != nil { + if !c.waitUnpause("updateRoutine") { c.logf("updateRoutine: exiting") return } @@ -86,19 +70,11 @@ func (c *Auto) updateRoutine() { needUpdate := gen > 0 && gen != lastUpdateGenInformed && c.loggedIn c.mu.Unlock() - if needUpdate { - select { - case <-c.quit: - c.logf("updateRoutine: exiting") - return - default: - } - } else { + if !needUpdate { // Nothing to do, wait for a signal. select { - case <-c.quit: - c.logf("updateRoutine: exiting") - return + case <-ctx.Done(): + continue case <-c.updateCh: continue } @@ -141,7 +117,6 @@ type Auto struct { logf logger.Logf closed bool updateCh chan struct{} // readable when we should inform the server of a change - newMapCh chan struct{} // readable when we must restart a map request observer Observer // called to update Client status; always non-nil observerQueue execQueue @@ -149,24 +124,25 @@ type Auto struct { mu sync.Mutex // mutex guards the following fields - expiry time.Time + wantLoggedIn bool // whether the user wants to be logged in per last method call + urlToVisit string // the last url we were told to visit + expiry time.Time // lastUpdateGen is the gen of last update we had an update worth sending to // the server. lastUpdateGen updateGen - paused bool // whether we should stop making HTTP requests - unpauseWaiters []chan struct{} - loggedIn bool // true if currently logged in - loginGoal *LoginGoal // non-nil if some login activity is desired - synced bool // true if our netmap is up-to-date - state State + paused bool // whether we should stop making HTTP requests + unpauseWaiters []chan bool // chans that gets sent true (once) on wake, or false on Shutdown + loggedIn bool // true if currently logged in + loginGoal *LoginGoal // non-nil if some login activity is desired + inMapPoll bool // true once we get the first MapResponse in a stream; false when HTTP response ends + state State // TODO(bradfitz): delete this, make it computed by method from other state authCtx context.Context // context used for auth requests mapCtx context.Context // context used for netmap and update requests authCancel func() // cancel authCtx mapCancel func() // cancel mapCtx - quit chan struct{} // when closed, goroutines should all exit authDone chan struct{} // when closed, authRoutine is done mapDone chan struct{} // when closed, mapRoutine is done updateDone chan struct{} // when closed, updateRoutine is done @@ -207,8 +183,6 @@ func NewNoStart(opts Options) (_ *Auto, err error) { clock: opts.Clock, logf: opts.Logf, updateCh: make(chan struct{}, 1), - newMapCh: make(chan struct{}, 1), - quit: make(chan struct{}), authDone: make(chan struct{}), mapDone: make(chan struct{}), updateDone: make(chan struct{}), @@ -237,15 +211,14 @@ func (c *Auto) SetPaused(paused bool) { c.logf("setPaused(%v)", paused) c.paused = paused if paused { - // Only cancel the map routine. (The auth routine isn't expensive - // so it's fine to keep it running.) c.cancelMapCtxLocked() - } else { - for _, ch := range c.unpauseWaiters { - close(ch) - } - c.unpauseWaiters = nil + c.cancelAuthCtxLocked() + return } + for _, ch := range c.unpauseWaiters { + ch <- true + } + c.unpauseWaiters = nil } // Start starts the client's goroutines. @@ -322,20 +295,10 @@ func (c *Auto) cancelMapCtxLocked() { func (c *Auto) restartMap() { c.mu.Lock() c.cancelMapCtxLocked() - synced := c.synced + synced := c.inMapPoll c.mu.Unlock() c.logf("[v1] restartMap: synced=%v", synced) - - select { - case c.newMapCh <- struct{}{}: - c.logf("[v1] restartMap: wrote to channel") - default: - // if channel write failed, then there was already - // an outstanding newMapCh request. One is enough, - // since it'll always use the latest endpoints. - c.logf("[v1] restartMap: channel was full") - } c.updateControl() } @@ -344,23 +307,20 @@ func (c *Auto) authRoutine() { bo := backoff.NewBackoff("authRoutine", c.logf, 30*time.Second) for { + if !c.waitUnpause("authRoutine") { + c.logf("authRoutine: exiting") + return + } c.mu.Lock() goal := c.loginGoal ctx := c.authCtx if goal != nil { - c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, goal.wantLoggedIn) + c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, true) } else { c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused) } c.mu.Unlock() - select { - case <-c.quit: - c.logf("[v1] authRoutine: quit") - return - default: - } - report := func(err error, msg string) { c.logf("[v1] %s: %v", msg, err) // don't send status updates for context errors, @@ -378,88 +338,67 @@ func (c *Auto) authRoutine() { continue } - if !goal.wantLoggedIn { - health.SetAuthRoutineInError(nil) - err := c.direct.TryLogout(ctx) - goal.sendLogoutError(err) - if err != nil { - report(err, "TryLogout") - bo.BackOff(ctx, err) - continue - } - - // success - c.mu.Lock() - c.loggedIn = false - c.loginGoal = nil - c.state = StateNotAuthenticated - c.synced = false - c.mu.Unlock() - - c.sendStatus("authRoutine-wantout", nil, "", nil) - bo.BackOff(ctx, nil) - } else { // ie. goal.wantLoggedIn - c.mu.Lock() - if goal.url != "" { - c.state = StateURLVisitRequired - } else { - c.state = StateAuthenticating - } - c.mu.Unlock() - - var url string - var err error - var f string - if goal.url != "" { - url, err = c.direct.WaitLoginURL(ctx, goal.url) - f = "WaitLoginURL" - } else { - url, err = c.direct.TryLogin(ctx, goal.token, goal.flags) - f = "TryLogin" - } - if err != nil { - health.SetAuthRoutineInError(err) - report(err, f) - bo.BackOff(ctx, err) - continue - } - if url != "" { - // goal.url ought to be empty here. - // However, not all control servers get this right, - // and logging about it here just generates noise. - c.mu.Lock() - c.loginGoal = &LoginGoal{ - wantLoggedIn: true, - flags: LoginDefault, - url: url, - } - c.state = StateURLVisitRequired - c.synced = false - c.mu.Unlock() - - c.sendStatus("authRoutine-url", err, url, nil) - if goal.url == url { - // The server sent us the same URL we already tried, - // backoff to avoid a busy loop. - bo.BackOff(ctx, errors.New("login URL not changing")) - } else { - bo.BackOff(ctx, nil) - } - continue - } - - // success - health.SetAuthRoutineInError(nil) - c.mu.Lock() - c.loggedIn = true - c.loginGoal = nil - c.state = StateAuthenticated - c.mu.Unlock() - - c.sendStatus("authRoutine-success", nil, "", nil) - c.restartMap() - bo.BackOff(ctx, nil) + c.mu.Lock() + c.urlToVisit = goal.url + if goal.url != "" { + c.state = StateURLVisitRequired + } else { + c.state = StateAuthenticating } + c.mu.Unlock() + + var url string + var err error + var f string + if goal.url != "" { + url, err = c.direct.WaitLoginURL(ctx, goal.url) + f = "WaitLoginURL" + } else { + url, err = c.direct.TryLogin(ctx, goal.token, goal.flags) + f = "TryLogin" + } + if err != nil { + health.SetAuthRoutineInError(err) + report(err, f) + bo.BackOff(ctx, err) + continue + } + if url != "" { + // goal.url ought to be empty here. + // However, not all control servers get this right, + // and logging about it here just generates noise. + c.mu.Lock() + c.urlToVisit = url + c.loginGoal = &LoginGoal{ + flags: LoginDefault, + url: url, + } + c.state = StateURLVisitRequired + c.mu.Unlock() + + c.sendStatus("authRoutine-url", err, url, nil) + if goal.url == url { + // The server sent us the same URL we already tried, + // backoff to avoid a busy loop. + bo.BackOff(ctx, errors.New("login URL not changing")) + } else { + bo.BackOff(ctx, nil) + } + continue + } + + // success + health.SetAuthRoutineInError(nil) + c.mu.Lock() + c.urlToVisit = "" + c.loggedIn = true + c.loginGoal = nil + c.state = StateAuthenticated + c.mu.Unlock() + + c.sendStatus("authRoutine-success", nil, "", nil) + c.restartMap() + bo.BackOff(ctx, nil) } } @@ -477,12 +416,12 @@ func (c *Auto) DirectForTest() *Direct { return c.direct } -// unpausedChanLocked returns a new channel that is closed when the -// current Auto pause is unpaused. +// unpausedChanLocked returns a new channel that gets sent +// either a true when unpaused or false on Auto.Shutdown. // // c.mu must be held -func (c *Auto) unpausedChanLocked() <-chan struct{} { - unpaused := make(chan struct{}) +func (c *Auto) unpausedChanLocked() <-chan bool { + unpaused := make(chan bool, 1) c.unpauseWaiters = append(c.unpauseWaiters, unpaused) return unpaused } @@ -498,7 +437,7 @@ func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) { c.mu.Lock() ctx := c.mapCtx - c.synced = true + c.inMapPoll = true if c.loggedIn { c.state = StateSynchronized } @@ -524,7 +463,7 @@ func (c *Auto) mapRoutine() { } for { - if err := c.waitUnpause("mapRoutine"); err != nil { + if !c.waitUnpause("mapRoutine") { c.logf("mapRoutine: exiting") return } @@ -535,13 +474,6 @@ func (c *Auto) mapRoutine() { ctx := c.mapCtx c.mu.Unlock() - select { - case <-c.quit: - c.logf("mapRoutine: quit") - return - default: - } - report := func(err error, msg string) { c.logf("[v1] %s: %v", msg, err) err = fmt.Errorf("%s: %w", msg, err) @@ -555,40 +487,33 @@ func (c *Auto) mapRoutine() { if !loggedIn { // Wait for something interesting to happen c.mu.Lock() - c.synced = false - // c.state is set by authRoutine() + c.inMapPoll = false c.mu.Unlock() - select { - case <-ctx.Done(): - c.logf("[v1] mapRoutine: context done.") - case <-c.newMapCh: - c.logf("[v1] mapRoutine: new map needed while idle.") - } - } else { - health.SetOutOfPollNetMap() - - err := c.direct.PollNetMap(ctx, mrs) - - health.SetOutOfPollNetMap() - c.mu.Lock() - c.synced = false - if c.state == StateSynchronized { - c.state = StateAuthenticated - } - paused := c.paused - c.mu.Unlock() - - if paused { - mrs.bo.BackOff(ctx, nil) - c.logf("mapRoutine: paused") - continue - } - - report(err, "PollNetMap") - mrs.bo.BackOff(ctx, err) + <-ctx.Done() + c.logf("[v1] mapRoutine: context done.") continue } + health.SetOutOfPollNetMap() + + err := c.direct.PollNetMap(ctx, mrs) + + health.SetOutOfPollNetMap() + c.mu.Lock() + c.inMapPoll = false + if c.state == StateSynchronized { + c.state = StateAuthenticated + } + paused := c.paused + c.mu.Unlock() + + if paused { + mrs.bo.BackOff(ctx, nil) + c.logf("mapRoutine: paused") + } else { + mrs.bo.BackOff(ctx, err) + report(err, "PollNetMap") + } } } @@ -637,6 +562,7 @@ func (c *Auto) SetTKAHead(headHash string) { c.updateControl() } +// sendStatus can not be called with the c.mu held. func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkMap) { c.mu.Lock() if c.closed { @@ -645,13 +571,13 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM } state := c.state loggedIn := c.loggedIn - synced := c.synced + inMapPoll := c.inMapPoll c.mu.Unlock() c.logf("[v1] sendStatus: %s: %v", who, state) var p persist.PersistView - if nm != nil && loggedIn && synced { + if nm != nil && loggedIn && inMapPoll { p = c.direct.GetPersist() } else { // don't send netmap status, as it's misleading when we're @@ -677,40 +603,45 @@ func (c *Auto) Login(t *tailcfg.Oauth2Token, flags LoginFlags) { c.logf("client.Login(%v, %v)", t != nil, flags) c.mu.Lock() - c.loginGoal = &LoginGoal{ - wantLoggedIn: true, - token: t, - flags: flags, + defer c.mu.Unlock() + if c.closed { + return } - c.mu.Unlock() - - c.cancelAuthCtx() + c.wantLoggedIn = true + c.loginGoal = &LoginGoal{ + token: t, + flags: flags, + } + c.cancelMapCtxLocked() + c.cancelAuthCtxLocked() } +var ErrClientClosed = errors.New("client closed") + func (c *Auto) Logout(ctx context.Context) error { c.logf("client.Logout()") - - errc := make(chan error, 1) - c.mu.Lock() - c.loginGoal = &LoginGoal{ - wantLoggedIn: false, - loggedOutResult: errc, - } + c.wantLoggedIn = false + c.loginGoal = nil + closed := c.closed c.mu.Unlock() - c.cancelAuthCtx() - c.cancelMapCtx() - timer, timerChannel := c.clock.NewTimer(10 * time.Second) - defer timer.Stop() - select { - case err := <-errc: - return err - case <-ctx.Done(): - return ctx.Err() - case <-timerChannel: - return context.DeadlineExceeded + if closed { + return ErrClientClosed } + + if err := c.direct.TryLogout(ctx); err != nil { + return err + } + c.mu.Lock() + c.loggedIn = false + c.state = StateNotAuthenticated + c.cancelAuthCtxLocked() + c.cancelMapCtxLocked() + c.mu.Unlock() + + c.sendStatus("authRoutine-wantout", nil, "", nil) + return nil } func (c *Auto) SetExpirySooner(ctx context.Context, expiry time.Time) error { @@ -738,14 +669,16 @@ func (c *Auto) Shutdown() { c.closed = true c.cancelAuthCtxLocked() c.cancelMapCtxLocked() - go c.observerQueue.shutdown() + for _, w := range c.unpauseWaiters { + w <- false + } + c.unpauseWaiters = nil } c.mu.Unlock() c.logf("client.Shutdown") if !closed { c.unregisterHealthWatch() - close(c.quit) <-c.authDone <-c.mapDone <-c.updateDone diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index e02b784ef..bb0b598e6 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -50,12 +50,7 @@ func TestStatusEqual(t *testing.T) { true, }, { - &Status{state: StateNew}, - &Status{state: StateNew}, - true, - }, - { - &Status{state: StateNew}, + &Status{}, &Status{state: StateAuthenticated}, false, },