control/controlclient: try reconnecting to last successful addr

If we lose our connection to the control server (e.g. due to a restart,
or a network blip, etc), try reconnecting to the same address first
before going through the whole control dialplan and/or DNS flow.

This ensures that we're a bit "sticky", which makes load balancing
easier by improving the chances that this client will hit a server with
a warm cache. It also reduces the thundering herd of requests that hit
other servers after we restart a single one.

Updates #TODO

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I6c3ef0b088468a8888c05cf8e3813056118ec835
This commit is contained in:
Andrew Dunham 2024-04-04 16:35:52 -04:00
parent 853e3e29a0
commit b8f89c93ac
4 changed files with 81 additions and 18 deletions

View File

@ -42,6 +42,7 @@ import (
"tailscale.com/net/tlsdial" "tailscale.com/net/tlsdial"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/net/tshttpproxy" "tailscale.com/net/tshttpproxy"
"tailscale.com/syncs"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tka" "tailscale.com/tka"
"tailscale.com/tstime" "tailscale.com/tstime"
@ -82,6 +83,11 @@ type Direct struct {
dialPlan ControlDialPlanner // can be nil dialPlan ControlDialPlanner // can be nil
// lastServerAddr is set to the most recent address that we
// successfully connected to. It is used to prioritize this address
// when reconnecting (e.g. when a control server restart happens).
lastServerAddr syncs.AtomicValue[netip.Addr]
mu sync.Mutex // mutex guards the following fields mu sync.Mutex // mutex guards the following fields
serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now
serverNoiseKey key.MachinePublic serverNoiseKey key.MachinePublic
@ -1428,6 +1434,8 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, d time.Duration, cl
} }
} }
var useLastAddr = envknob.RegisterBool("TS_CONTROLCLIENT_USE_LAST_ADDR")
// getNoiseClient returns the noise client, creating one if one doesn't exist. // getNoiseClient returns the noise client, creating one if one doesn't exist.
func (c *Direct) getNoiseClient() (*NoiseClient, error) { func (c *Direct) getNoiseClient() (*NoiseClient, error) {
c.mu.Lock() c.mu.Lock()
@ -1444,6 +1452,12 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) {
if c.dialPlan != nil { if c.dialPlan != nil {
dp = c.dialPlan.Load dp = c.dialPlan.Load
} }
var lastAddr *syncs.AtomicValue[netip.Addr]
if useLastAddr() {
lastAddr = &c.lastServerAddr
}
nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) { nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) {
k, err := c.getMachinePrivKey() k, err := c.getMachinePrivKey()
if err != nil { if err != nil {
@ -1451,18 +1465,20 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) {
} }
c.logf("[v1] creating new noise client") c.logf("[v1] creating new noise client")
nc, err := NewNoiseClient(NoiseOpts{ nc, err := NewNoiseClient(NoiseOpts{
PrivKey: k, PrivKey: k,
ServerPubKey: serverNoiseKey, ServerPubKey: serverNoiseKey,
ServerURL: c.serverURL, ServerURL: c.serverURL,
Dialer: c.dialer, Dialer: c.dialer,
DNSCache: c.dnsCache, DNSCache: c.dnsCache,
Logf: c.logf, Logf: c.logf,
NetMon: c.netMon, NetMon: c.netMon,
DialPlan: dp, DialPlan: dp,
LastServerAddr: lastAddr,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.noiseClient = nc c.noiseClient = nc

View File

@ -12,6 +12,7 @@ import (
"io" "io"
"math" "math"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"sync" "sync"
"time" "time"
@ -22,6 +23,7 @@ import (
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/syncs"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstime" "tailscale.com/tstime"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -172,6 +174,8 @@ type NoiseClient struct {
// be nil. // be nil.
dialPlan func() *tailcfg.ControlDialPlan dialPlan func() *tailcfg.ControlDialPlan
lastServerAddr *syncs.AtomicValue[netip.Addr] // can be nil
logf logger.Logf logf logger.Logf
netMon *netmon.Monitor netMon *netmon.Monitor
@ -207,6 +211,12 @@ type NoiseOpts struct {
// DialPlan, if set, is a function that should return an explicit plan // DialPlan, if set, is a function that should return an explicit plan
// on how to connect to the server. // on how to connect to the server.
DialPlan func() *tailcfg.ControlDialPlan DialPlan func() *tailcfg.ControlDialPlan
// LastServerAddr, if non-nil, contains storage for the last address
// used to (successfully) connect to the control server. It will be
// prioritized when making a connection to the server.
//
// If nil, no last address will be stored or used.
LastServerAddr *syncs.AtomicValue[netip.Addr]
} }
// NewNoiseClient returns a new noiseClient for the provided server and machine key. // NewNoiseClient returns a new noiseClient for the provided server and machine key.
@ -237,16 +247,17 @@ func NewNoiseClient(opts NoiseOpts) (*NoiseClient, error) {
} }
np := &NoiseClient{ np := &NoiseClient{
serverPubKey: opts.ServerPubKey, serverPubKey: opts.ServerPubKey,
privKey: opts.PrivKey, privKey: opts.PrivKey,
host: u.Hostname(), host: u.Hostname(),
httpPort: httpPort, httpPort: httpPort,
httpsPort: httpsPort, httpsPort: httpsPort,
dialer: opts.Dialer, dialer: opts.Dialer,
dnsCache: opts.DNSCache, dnsCache: opts.DNSCache,
dialPlan: opts.DialPlan, dialPlan: opts.DialPlan,
logf: opts.Logf, lastServerAddr: opts.LastServerAddr,
netMon: opts.NetMon, logf: opts.Logf,
netMon: opts.NetMon,
} }
// Create the HTTP/2 Transport using a net/http.Transport // Create the HTTP/2 Transport using a net/http.Transport
@ -334,6 +345,14 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
}) })
var ce contextErr var ce contextErr
if err == nil || !errors.As(err, &ce) { if err == nil || !errors.As(err, &ce) {
// Store this address as our last-successful address for future
// use if we need to reconnect.
if nc.lastServerAddr != nil {
if addr, err := netip.ParseAddrPort(conn.RemoteAddr().String()); err == nil {
nc.lastServerAddr.Store(addr.Addr())
}
}
return conn, err return conn, err
} }
if ctx.Err() == nil { if ctx.Err() == nil {
@ -429,6 +448,16 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) {
// handshake. // handshake.
timeoutSec += 5 timeoutSec += 5
// If we have a last server address, then give ourselves a bit more
// time to try it first.
var lastAddr netip.Addr
if nc.lastServerAddr != nil {
lastAddr = nc.lastServerAddr.Load()
}
if lastAddr.IsValid() {
timeoutSec += 5
}
// Be extremely defensive and ensure that the timeout is in the range // Be extremely defensive and ensure that the timeout is in the range
// [5, 60] seconds (e.g. if we accidentally get a negative number). // [5, 60] seconds (e.g. if we accidentally get a negative number).
if timeoutSec > 60 { if timeoutSec > 60 {
@ -451,6 +480,7 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) {
Dialer: nc.dialer.SystemDial, Dialer: nc.dialer.SystemDial,
DNSCache: nc.dnsCache, DNSCache: nc.dnsCache,
DialPlan: dialPlan, DialPlan: dialPlan,
LastServerAddr: lastAddr,
Logf: nc.logf, Logf: nc.logf,
NetMon: nc.netMon, NetMon: nc.netMon,
Clock: tstime.StdClock{}, Clock: tstime.StdClock{},

View File

@ -95,6 +95,17 @@ func (a *Dialer) httpsFallbackDelay() time.Duration {
var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
// If we have a last used address, try that first, but time out fairly
// aggressively in case it's actually down.
if a.LastServerAddr.IsValid() {
lastDialCtx, lastDialCancel := context.WithTimeout(ctx, 5*time.Second)
defer lastDialCancel()
conn, err := a.dialHost(lastDialCtx, a.LastServerAddr)
if err == nil {
return conn, nil
}
}
// If we don't have a dial plan, just fall back to dialing the single // If we don't have a dial plan, just fall back to dialing the single
// host we know about. // host we know about.
useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN") useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")

View File

@ -5,6 +5,7 @@ package controlhttp
import ( import (
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"time" "time"
@ -84,6 +85,11 @@ type Dialer struct {
// plan before falling back to DNS. // plan before falling back to DNS.
DialPlan *tailcfg.ControlDialPlan DialPlan *tailcfg.ControlDialPlan
// LastServerAddr, if valid, is the address that was last used to
// (successfully) connect to the control server. It will be prioritized
// when making a connection to the server.
LastServerAddr netip.Addr
proxyFunc func(*http.Request) (*url.URL, error) // or nil proxyFunc func(*http.Request) (*url.URL, error) // or nil
// For tests only // For tests only