diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index f5d1f0410..8c501dc10 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -42,6 +42,7 @@ "tailscale.com/net/tlsdial" "tailscale.com/net/tsdial" "tailscale.com/net/tshttpproxy" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/tstime" @@ -82,6 +83,11 @@ type Direct struct { dialPlan ControlDialPlanner // can be nil + // lastServerAddr is set to the most recent address that we + // successfully connected to. It is used to prioritize this address + // when reconnecting (e.g. when a control server restart happens). + lastServerAddr syncs.AtomicValue[netip.Addr] + mu sync.Mutex // mutex guards the following fields serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now serverNoiseKey key.MachinePublic @@ -1428,6 +1434,8 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, d time.Duration, cl } } +var useLastAddr = envknob.RegisterBool("TS_CONTROLCLIENT_USE_LAST_ADDR") + // getNoiseClient returns the noise client, creating one if one doesn't exist. 
func (c *Direct) getNoiseClient() (*NoiseClient, error) { c.mu.Lock() @@ -1444,6 +1452,12 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) { if c.dialPlan != nil { dp = c.dialPlan.Load } + + var lastAddr *syncs.AtomicValue[netip.Addr] + if useLastAddr() { + lastAddr = &c.lastServerAddr + } + nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) { k, err := c.getMachinePrivKey() if err != nil { @@ -1451,18 +1465,20 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) { } c.logf("[v1] creating new noise client") nc, err := NewNoiseClient(NoiseOpts{ - PrivKey: k, - ServerPubKey: serverNoiseKey, - ServerURL: c.serverURL, - Dialer: c.dialer, - DNSCache: c.dnsCache, - Logf: c.logf, - NetMon: c.netMon, - DialPlan: dp, + PrivKey: k, + ServerPubKey: serverNoiseKey, + ServerURL: c.serverURL, + Dialer: c.dialer, + DNSCache: c.dnsCache, + Logf: c.logf, + NetMon: c.netMon, + DialPlan: dp, + LastServerAddr: lastAddr, }) if err != nil { return nil, err } + c.mu.Lock() defer c.mu.Unlock() c.noiseClient = nc diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index f3e5f1bde..3191312f5 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -12,6 +12,7 @@ "io" "math" "net/http" + "net/netip" "net/url" "sync" "time" @@ -22,6 +23,7 @@ "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" @@ -172,6 +174,8 @@ type NoiseClient struct { // be nil. dialPlan func() *tailcfg.ControlDialPlan + lastServerAddr *syncs.AtomicValue[netip.Addr] // can be nil + logf logger.Logf netMon *netmon.Monitor @@ -207,6 +211,12 @@ type NoiseOpts struct { // DialPlan, if set, is a function that should return an explicit plan // on how to connect to the server. 
DialPlan func() *tailcfg.ControlDialPlan + // LastServerAddr, if non-nil, contains storage for the last address + // used to (successfully) connect to the control server. It will be + // prioritized when making a connection to the server. + // + // If nil, no last address will be stored or used. + LastServerAddr *syncs.AtomicValue[netip.Addr] } // NewNoiseClient returns a new noiseClient for the provided server and machine key. @@ -237,16 +247,17 @@ func NewNoiseClient(opts NoiseOpts) (*NoiseClient, error) { } np := &NoiseClient{ - serverPubKey: opts.ServerPubKey, - privKey: opts.PrivKey, - host: u.Hostname(), - httpPort: httpPort, - httpsPort: httpsPort, - dialer: opts.Dialer, - dnsCache: opts.DNSCache, - dialPlan: opts.DialPlan, - logf: opts.Logf, - netMon: opts.NetMon, + serverPubKey: opts.ServerPubKey, + privKey: opts.PrivKey, + host: u.Hostname(), + httpPort: httpPort, + httpsPort: httpsPort, + dialer: opts.Dialer, + dnsCache: opts.DNSCache, + dialPlan: opts.DialPlan, + lastServerAddr: opts.LastServerAddr, + logf: opts.Logf, + netMon: opts.NetMon, } // Create the HTTP/2 Transport using a net/http.Transport @@ -334,6 +345,14 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) { }) var ce contextErr if err == nil || !errors.As(err, &ce) { + // Record the peer as our last-successful address for reconnects, but + // only on success: this branch also runs for non-context errors. + if err == nil && conn != nil && nc.lastServerAddr != nil { + if ap, perr := netip.ParseAddrPort(conn.RemoteAddr().String()); perr == nil { + nc.lastServerAddr.Store(ap.Addr()) + } + } + return conn, err } if ctx.Err() == nil { @@ -429,6 +448,16 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) { // handshake. timeoutSec += 5 + // If we have a last server address, then give ourselves a bit more + // time to try it first. 
+ var lastAddr netip.Addr + if nc.lastServerAddr != nil { + lastAddr = nc.lastServerAddr.Load() + } + if lastAddr.IsValid() { + timeoutSec += 5 + } + // Be extremely defensive and ensure that the timeout is in the range // [5, 60] seconds (e.g. if we accidentally get a negative number). if timeoutSec > 60 { @@ -451,6 +480,7 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) { Dialer: nc.dialer.SystemDial, DNSCache: nc.dnsCache, DialPlan: dialPlan, + LastServerAddr: lastAddr, Logf: nc.logf, NetMon: nc.netMon, Clock: tstime.StdClock{}, diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index fb220fd0b..b7675f254 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -95,6 +95,17 @@ func (a *Dialer) httpsFallbackDelay() time.Duration { var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { + // If we have a last used address, try that first, but time out fairly + // aggressively in case it's actually down. + if a.LastServerAddr.IsValid() { + lastDialCtx, lastDialCancel := context.WithTimeout(ctx, 5*time.Second) + defer lastDialCancel() + conn, err := a.dialHost(lastDialCtx, a.LastServerAddr) + if err == nil { + return conn, nil + } + } + // If we don't have a dial plan, just fall back to dialing the single // host we know about. useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN") diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index 72161336e..130ad6be2 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -5,6 +5,7 @@ import ( "net/http" + "net/netip" "net/url" "time" @@ -84,6 +85,11 @@ type Dialer struct { // plan before falling back to DNS. DialPlan *tailcfg.ControlDialPlan + // LastServerAddr, if valid, is the address that was last used to + // (successfully) connect to the control server. 
It will be prioritized + // when making a connection to the server. + LastServerAddr netip.Addr + proxyFunc func(*http.Request) (*url.URL, error) // or nil // For tests only