// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package controlclient import ( "context" "fmt" "sync" "time" "tailscale.com/health" "tailscale.com/logtail/backoff" "tailscale.com/tailcfg" "tailscale.com/types/empty" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/persist" "tailscale.com/types/structs" "tailscale.com/types/wgkey" ) type LoginGoal struct { _ structs.Incomparable wantLoggedIn bool // true if we *want* to be logged in token *tailcfg.Oauth2Token // oauth token to use when logging in flags LoginFlags // flags to use when logging in url string // auth url that needs to be visited loggedOutResult chan<- error } func (g *LoginGoal) sendLogoutError(err error) { if g.loggedOutResult == nil { return } select { case g.loggedOutResult <- err: default: } } // Auto connects to a tailcontrol server for a node. // It's a concrete implementation of the Client interface. type Auto struct { direct *Direct // our interface to the server APIs timeNow func() time.Time logf logger.Logf expiry *time.Time closed bool newMapCh chan struct{} // readable when we must restart a map request unregisterHealthWatch func() mu sync.Mutex // mutex guards the following fields statusFunc func(Status) // called to update Client status paused bool // whether we should stop making HTTP requests unpauseWaiters []chan struct{} loggedIn bool // true if currently logged in loginGoal *LoginGoal // non-nil if some login activity is desired synced bool // true if our netmap is up-to-date hostinfo *tailcfg.Hostinfo inPollNetMap bool // true if currently running a PollNetMap inLiteMapUpdate bool // true if a lite (non-streaming) map request is outstanding inSendStatus int // number of sendStatus calls currently in progress state State authCtx context.Context // context used for auth requests mapCtx context.Context // context used for netmap requests authCancel func() // cancel the auth context mapCancel func() // cancel the netmap context quit chan struct{} // when closed, goroutines should all exit authDone chan struct{} // when closed, auth goroutine is done mapDone chan struct{} // when closed, map goroutine is done } // New creates and starts a new Auto. func New(opts Options) (*Auto, error) { c, err := NewNoStart(opts) if c != nil { c.Start() } return c, err } // NewNoStart creates a new Auto, but without calling Start on it. func NewNoStart(opts Options) (*Auto, error) { direct, err := NewDirect(opts) if err != nil { return nil, err } if opts.Logf == nil { opts.Logf = func(fmt string, args ...interface{}) {} } if opts.TimeNow == nil { opts.TimeNow = time.Now } c := &Auto{ direct: direct, timeNow: opts.TimeNow, logf: opts.Logf, newMapCh: make(chan struct{}, 1), quit: make(chan struct{}), authDone: make(chan struct{}), mapDone: make(chan struct{}), } c.authCtx, c.authCancel = context.WithCancel(context.Background()) c.mapCtx, c.mapCancel = context.WithCancel(context.Background()) c.unregisterHealthWatch = health.RegisterWatcher(c.onHealthChange) return c, nil } func (c *Auto) onHealthChange(sys health.Subsystem, err error) { if sys == health.SysOverall { return } c.logf("controlclient: restarting map request for %q health change to new state: %v", sys, err) c.cancelMapSafely() } // SetPaused controls whether HTTP activity should be paused. // // The client can be paused and unpaused repeatedly, unlike Start and Shutdown, which can only be used once. func (c *Auto) SetPaused(paused bool) { c.mu.Lock() defer c.mu.Unlock() if paused == c.paused { return } c.logf("setPaused(%v)", paused) c.paused = paused if paused { // Only cancel the map routine. (The auth routine isn't expensive // so it's fine to keep it running.) c.cancelMapLocked() } else { for _, ch := range c.unpauseWaiters { close(ch) } c.unpauseWaiters = nil } } // Start starts the client's goroutines. // // It should only be called for clients created by NewNoStart. func (c *Auto) Start() { go c.authRoutine() go c.mapRoutine() } // sendNewMapRequest either sends a new OmitPeers, non-streaming map request // (to just send Hostinfo/Netinfo/Endpoints info, while keeping an existing // streaming response open), or start a new streaming one if necessary. // // It should be called whenever there's something new to tell the server. func (c *Auto) sendNewMapRequest() { c.mu.Lock() // If we're not already streaming a netmap, or if we're already stuck // in a lite update, then tear down everything and start a new stream // (which starts by sending a new map request) if !c.inPollNetMap || c.inLiteMapUpdate || !c.loggedIn { c.mu.Unlock() c.cancelMapSafely() return } // Otherwise, send a lite update that doesn't keep a // long-running stream response. defer c.mu.Unlock() c.inLiteMapUpdate = true ctx, cancel := context.WithTimeout(c.mapCtx, 10*time.Second) go func() { defer cancel() t0 := time.Now() err := c.direct.SendLiteMapUpdate(ctx) d := time.Since(t0).Round(time.Millisecond) c.mu.Lock() c.inLiteMapUpdate = false c.mu.Unlock() if err == nil { c.logf("[v1] successful lite map update in %v", d) return } if ctx.Err() == nil { c.logf("lite map update after %v: %v", d, err) } // Fall back to restarting the long-polling map // request (the old heavy way) if the lite update // failed for any reason. c.cancelMapSafely() }() } func (c *Auto) cancelAuth() { c.mu.Lock() if c.authCancel != nil { c.authCancel() } if !c.closed { c.authCtx, c.authCancel = context.WithCancel(context.Background()) } c.mu.Unlock() } func (c *Auto) cancelMapLocked() { if c.mapCancel != nil { c.mapCancel() } if !c.closed { c.mapCtx, c.mapCancel = context.WithCancel(context.Background()) } } func (c *Auto) cancelMapUnsafely() { c.mu.Lock() c.cancelMapLocked() c.mu.Unlock() } func (c *Auto) cancelMapSafely() { c.mu.Lock() defer c.mu.Unlock() c.logf("[v1] cancelMapSafely: synced=%v", c.synced) if c.inPollNetMap { // received at least one netmap since the last // interruption. That means the server has already // fully processed our last request, which might // include UpdateEndpoints(). Interrupt it and try // again. c.cancelMapLocked() } else { // !synced means we either haven't done a netmap // request yet, or it hasn't answered yet. So the // server is in an undefined state. If we send // another netmap request too soon, it might race // with the last one, and if we're very unlucky, // the new request will be applied before the old one, // and the wrong endpoints will get registered. We // have to tell the client to abort politely, only // after it receives a response to its existing netmap // request. select { case c.newMapCh <- struct{}{}: c.logf("[v1] cancelMapSafely: wrote to channel") default: // if channel write failed, then there was already // an outstanding newMapCh request. One is enough, // since it'll always use the latest endpoints. c.logf("[v1] cancelMapSafely: channel was full") } } } func (c *Auto) authRoutine() { defer close(c.authDone) bo := backoff.NewBackoff("authRoutine", c.logf, 30*time.Second) for { c.mu.Lock() goal := c.loginGoal ctx := c.authCtx if goal != nil { c.logf("authRoutine: %s; wantLoggedIn=%v", c.state, goal.wantLoggedIn) } else { c.logf("authRoutine: %s; goal=nil paused=%v", c.state, c.paused) } c.mu.Unlock() select { case <-c.quit: c.logf("[v1] authRoutine: quit") return default: } report := func(err error, msg string) { c.logf("[v1] %s: %v", msg, err) // don't send status updates for context errors, // since context cancelation is always on purpose. if ctx.Err() == nil { c.sendStatus("authRoutine-report", err, "", nil) } } if goal == nil { // Wait for user to Login or Logout. <-ctx.Done() c.logf("[v1] authRoutine: context done.") continue } if !goal.wantLoggedIn { err := c.direct.TryLogout(ctx) goal.sendLogoutError(err) if err != nil { report(err, "TryLogout") bo.BackOff(ctx, err) continue } // success c.mu.Lock() c.loggedIn = false c.loginGoal = nil c.state = StateNotAuthenticated c.synced = false c.mu.Unlock() c.sendStatus("authRoutine-wantout", nil, "", nil) bo.BackOff(ctx, nil) } else { // ie. goal.wantLoggedIn c.mu.Lock() if goal.url != "" { c.state = StateURLVisitRequired } else { c.state = StateAuthenticating } c.mu.Unlock() var url string var err error var f string if goal.url != "" { url, err = c.direct.WaitLoginURL(ctx, goal.url) f = "WaitLoginURL" } else { url, err = c.direct.TryLogin(ctx, goal.token, goal.flags) f = "TryLogin" } if err != nil { report(err, f) bo.BackOff(ctx, err) continue } if url != "" { if goal.url != "" { err = fmt.Errorf("[unexpected] server required a new URL?") report(err, "WaitLoginURL") } c.mu.Lock() c.loginGoal = &LoginGoal{ wantLoggedIn: true, flags: LoginDefault, url: url, } c.state = StateURLVisitRequired c.synced = false c.mu.Unlock() c.sendStatus("authRoutine-url", err, url, nil) bo.BackOff(ctx, err) continue } // success c.mu.Lock() c.loggedIn = true c.loginGoal = nil c.state = StateAuthenticated c.mu.Unlock() c.sendStatus("authRoutine-success", nil, "", nil) c.cancelMapSafely() bo.BackOff(ctx, nil) } } } // Expiry returns the credential expiration time, or the zero time if // the expiration time isn't known. Used in tests only. func (c *Auto) Expiry() *time.Time { c.mu.Lock() defer c.mu.Unlock() return c.expiry } // Direct returns the underlying direct client object. Used in tests // only. func (c *Auto) Direct() *Direct { return c.direct } // unpausedChanLocked returns a new channel that is closed when the // current Auto pause is unpaused. // // c.mu must be held func (c *Auto) unpausedChanLocked() <-chan struct{} { unpaused := make(chan struct{}) c.unpauseWaiters = append(c.unpauseWaiters, unpaused) return unpaused } func (c *Auto) mapRoutine() { defer close(c.mapDone) bo := backoff.NewBackoff("mapRoutine", c.logf, 30*time.Second) for { c.mu.Lock() if c.paused { unpaused := c.unpausedChanLocked() c.mu.Unlock() c.logf("mapRoutine: awaiting unpause") select { case <-unpaused: c.logf("mapRoutine: unpaused") case <-c.quit: c.logf("mapRoutine: quit") return } continue } c.logf("mapRoutine: %s", c.state) loggedIn := c.loggedIn ctx := c.mapCtx c.mu.Unlock() select { case <-c.quit: c.logf("mapRoutine: quit") return default: } report := func(err error, msg string) { c.logf("[v1] %s: %v", msg, err) err = fmt.Errorf("%s: %w", msg, err) // don't send status updates for context errors, // since context cancelation is always on purpose. if ctx.Err() == nil { c.sendStatus("mapRoutine1", err, "", nil) } } if !loggedIn { // Wait for something interesting to happen c.mu.Lock() c.synced = false // c.state is set by authRoutine() c.mu.Unlock() select { case <-ctx.Done(): c.logf("mapRoutine: context done.") case <-c.newMapCh: c.logf("mapRoutine: new map needed while idle.") } } else { // Be sure this is false when we're not inside // PollNetMap, so that cancelMapSafely() can notify // us correctly. c.mu.Lock() c.inPollNetMap = false c.mu.Unlock() health.SetInPollNetMap(false) err := c.direct.PollNetMap(ctx, -1, func(nm *netmap.NetworkMap) { health.SetInPollNetMap(true) c.mu.Lock() select { case <-c.newMapCh: c.logf("[v1] mapRoutine: new map request during PollNetMap. canceling.") c.cancelMapLocked() // Don't emit this netmap; we're // about to request a fresh one. c.mu.Unlock() return default: } c.synced = true c.inPollNetMap = true if c.loggedIn { c.state = StateSynchronized } exp := nm.Expiry c.expiry = &exp stillAuthed := c.loggedIn state := c.state c.mu.Unlock() c.logf("[v1] mapRoutine: netmap received: %s", state) if stillAuthed { c.sendStatus("mapRoutine-got-netmap", nil, "", nm) } }) health.SetInPollNetMap(false) c.mu.Lock() c.synced = false c.inPollNetMap = false if c.state == StateSynchronized { c.state = StateAuthenticated } paused := c.paused c.mu.Unlock() if paused { c.logf("mapRoutine: paused") continue } if err != nil { report(err, "PollNetMap") bo.BackOff(ctx, err) continue } bo.BackOff(ctx, nil) } } } func (c *Auto) AuthCantContinue() bool { if c == nil { return true } c.mu.Lock() defer c.mu.Unlock() return !c.loggedIn && (c.loginGoal == nil || c.loginGoal.url != "") } // SetStatusFunc sets fn as the callback to run on any status change. func (c *Auto) SetStatusFunc(fn func(Status)) { c.mu.Lock() c.statusFunc = fn c.mu.Unlock() } func (c *Auto) SetHostinfo(hi *tailcfg.Hostinfo) { if hi == nil { panic("nil Hostinfo") } if !c.direct.SetHostinfo(hi) { // No changes. Don't log. return } // Send new Hostinfo to server c.sendNewMapRequest() } func (c *Auto) SetNetInfo(ni *tailcfg.NetInfo) { if ni == nil { panic("nil NetInfo") } if !c.direct.SetNetInfo(ni) { return } c.logf("NetInfo: %v", ni) // Send new Hostinfo (which includes NetInfo) to server c.sendNewMapRequest() } func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkMap) { c.mu.Lock() state := c.state loggedIn := c.loggedIn synced := c.synced statusFunc := c.statusFunc hi := c.hostinfo c.inSendStatus++ c.mu.Unlock() c.logf("[v1] sendStatus: %s: %v", who, state) var p *persist.Persist var loginFin, logoutFin *empty.Message if state == StateAuthenticated { loginFin = new(empty.Message) } if state == StateNotAuthenticated { logoutFin = new(empty.Message) } if nm != nil && loggedIn && synced { pp := c.direct.GetPersist() p = &pp } else { // don't send netmap status, as it's misleading when we're // not logged in. nm = nil } new := Status{ LoginFinished: loginFin, LogoutFinished: logoutFin, URL: url, Persist: p, NetMap: nm, Hostinfo: hi, State: state, Err: err, } if statusFunc != nil { statusFunc(new) } c.mu.Lock() c.inSendStatus-- c.mu.Unlock() } func (c *Auto) Login(t *tailcfg.Oauth2Token, flags LoginFlags) { c.logf("client.Login(%v, %v)", t != nil, flags) c.mu.Lock() c.loginGoal = &LoginGoal{ wantLoggedIn: true, token: t, flags: flags, } c.mu.Unlock() c.cancelAuth() } func (c *Auto) StartLogout() { c.logf("client.StartLogout()") c.mu.Lock() c.loginGoal = &LoginGoal{ wantLoggedIn: false, } c.mu.Unlock() c.cancelAuth() } func (c *Auto) Logout(ctx context.Context) error { c.logf("client.Logout()") errc := make(chan error, 1) c.mu.Lock() c.loginGoal = &LoginGoal{ wantLoggedIn: false, loggedOutResult: errc, } c.mu.Unlock() c.cancelAuth() timer := time.NewTimer(10 * time.Second) defer timer.Stop() select { case err := <-errc: return err case <-ctx.Done(): return ctx.Err() case <-timer.C: return context.DeadlineExceeded } } // UpdateEndpoints sets the client's discovered endpoints and sends // them to the control server if they've changed. // // It does not retain the provided slice. // // The localPort field is unused except for integration tests in // another repo. func (c *Auto) UpdateEndpoints(localPort uint16, endpoints []tailcfg.Endpoint) { changed := c.direct.SetEndpoints(localPort, endpoints) if changed { c.sendNewMapRequest() } } func (c *Auto) Shutdown() { c.logf("client.Shutdown()") c.mu.Lock() inSendStatus := c.inSendStatus closed := c.closed if !closed { c.closed = true c.statusFunc = nil } c.mu.Unlock() c.logf("client.Shutdown: inSendStatus=%v", inSendStatus) if !closed { c.unregisterHealthWatch() close(c.quit) c.cancelAuth() <-c.authDone c.cancelMapUnsafely() <-c.mapDone c.logf("Client.Shutdown done.") } } // NodePublicKey returns the node public key currently in use. This is // used exclusively in tests. func (c *Auto) TestOnlyNodePublicKey() wgkey.Key { priv := c.direct.GetPersist() return priv.PrivateNodeKey.Public().AsWGKey() } func (c *Auto) TestOnlySetAuthKey(authkey string) { c.direct.mu.Lock() defer c.direct.mu.Unlock() c.direct.authKey = authkey } func (c *Auto) TestOnlyTimeNow() time.Time { return c.timeNow() } // SetDNS sends the SetDNSRequest request to the control plane server, // requesting a DNS record be created or updated. func (c *Auto) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) error { return c.direct.SetDNS(ctx, req) }