derp/derphttp: fix race in mesh watcher

The derphttp client automatically reconnects upon failure.

RunWatchConnectionLoop called derphttp.Client.WatchConnectionChanges
once, but that wrapper method called the underlying
derp.Client.WatchConnectionChanges exactly once on derphttp.Client's
currently active connection. If there's a failure, we need to re-subscribe
upon all reconnections.

This removes the derphttp.Client.WatchConnectionChanges method, which
was basically impossible to use correctly, and changes it to be a
boolean field on derphttp.Client alongside MeshKey and IsProber. Then
it moves the call to the underlying derp.Client.WatchConnectionChanges
to derphttp's client connection code, so it's resubscribed on any
reconnect.

Some paranoia is then added to make sure people hold the API right,
not calling derphttp.Client.RunWatchConnectionLoop on an
already-started Client without having set the bool to true. (But still
auto-setting it to true if that's the first method that's been called
on that derphttp.Client, as is commonly the case, and prevents
existing code from breaking)

Fixes tailscale/corp#9916
Supercedes tailscale/tailscale#9719

Co-authored-by: Val <valerie@tailscale.com>
Co-authored-by: Irbe Krumina <irbe@tailscale.com>
Co-authored-by: Anton Tolchanov <anton@tailscale.com>
Signed-off-by: Brad Fitzpatrick <brad@danga.com>
This commit is contained in:
Brad Fitzpatrick 2023-10-25 11:59:06 -07:00 committed by Anton Tolchanov
parent df4b730438
commit 3d7fb6c21d
3 changed files with 43 additions and 32 deletions

View File

@ -41,6 +41,7 @@ func startMeshWithHost(s *derp.Server, host string) error {
return err return err
} }
c.MeshKey = s.MeshKey() c.MeshKey = s.MeshKey()
c.WatchConnectionChanges = true
// For meshed peers within a region, connect via VPC addresses. // For meshed peers within a region, connect via VPC addresses.
c.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) { c.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) {

View File

@ -56,6 +56,12 @@ type Client struct {
MeshKey string // optional; for trusted clients MeshKey string // optional; for trusted clients
IsProber bool // optional; for probers to optional declare themselves as such IsProber bool // optional; for probers to optional declare themselves as such
// WatchConnectionChanges is whether the client wishes to subscribe to
// notifications about clients connecting & disconnecting.
//
// Only trusted connections (using MeshKey) are allowed to use this.
WatchConnectionChanges bool
// BaseContext, if non-nil, returns the base context to use for dialing a // BaseContext, if non-nil, returns the base context to use for dialing a
// new derp server. If nil, context.Background is used. // new derp server. If nil, context.Background is used.
// In either case, additional timeouts may be added to the base context. // In either case, additional timeouts may be added to the base context.
@ -80,6 +86,7 @@ type Client struct {
addrFamSelAtomic syncs.AtomicValue[AddressFamilySelector] addrFamSelAtomic syncs.AtomicValue[AddressFamilySelector]
mu sync.Mutex mu sync.Mutex
started bool // true upon first connect, never transitions to false
preferred bool preferred bool
canAckPings bool canAckPings bool
closed bool closed bool
@ -142,6 +149,15 @@ func NewClient(privateKey key.NodePrivate, serverURL string, logf logger.Logf) (
return c, nil return c, nil
} }
// isStarted reports whether this client has been used yet.
//
// If if reports false, it may still have its exported fields configured.
func (c *Client) isStarted() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.started
}
// Connect connects or reconnects to the server, unless already connected. // Connect connects or reconnects to the server, unless already connected.
// It returns nil if there was already a good connection, or if one was made. // It returns nil if there was already a good connection, or if one was made.
func (c *Client) Connect(ctx context.Context) error { func (c *Client) Connect(ctx context.Context) error {
@ -284,6 +300,7 @@ func useWebsockets() bool {
func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, connGen int, err error) { func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, connGen int, err error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.started = true
if c.closed { if c.closed {
return nil, 0, ErrClientClosed return nil, 0, ErrClientClosed
} }
@ -495,6 +512,13 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
} }
} }
if c.WatchConnectionChanges {
if err := derpClient.WatchConnectionChanges(); err != nil {
go httpConn.Close()
return nil, 0, err
}
}
c.serverPubKey = derpClient.ServerPublicKey() c.serverPubKey = derpClient.ServerPublicKey()
c.client = derpClient c.client = derpClient
c.netConn = tcpConn c.netConn = tcpConn
@ -956,22 +980,6 @@ func (c *Client) NotePreferred(v bool) {
} }
} }
// WatchConnectionChanges sends a request to subscribe to
// notifications about clients connecting & disconnecting.
//
// Only trusted connections (using MeshKey) are allowed to use this.
func (c *Client) WatchConnectionChanges() error {
client, _, err := c.connect(c.newContext(), "derphttp.Client.WatchConnectionChanges")
if err != nil {
return err
}
err = client.WatchConnectionChanges()
if err != nil {
c.closeForReconnect(client)
}
return err
}
// ClosePeer asks the server to close target's TCP connection. // ClosePeer asks the server to close target's TCP connection.
// //
// Only trusted connections (using MeshKey) are allowed to use this. // Only trusted connections (using MeshKey) are allowed to use this.

View File

@ -14,20 +14,30 @@
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
// RunWatchConnectionLoop loops until ctx is done, sending WatchConnectionChanges and subscribing to // RunWatchConnectionLoop loops until ctx is done, sending
// connection changes. // WatchConnectionChanges and subscribing to connection changes.
// //
// If the server's public key is ignoreServerKey, RunWatchConnectionLoop returns. // If the server's public key is ignoreServerKey, RunWatchConnectionLoop
// returns.
// //
// Otherwise, the add and remove funcs are called as clients come & go. // Otherwise, the add and remove funcs are called as clients come & go.
// //
// infoLogf, if non-nil, is the logger to write periodic status // infoLogf, if non-nil, is the logger to write periodic status updates about
// updates about how many peers are on the server. Error log output is // how many peers are on the server. Error log output is set to the c's logger,
// set to the c's logger, regardless of infoLogf's value. // regardless of infoLogf's value.
// //
// To force RunWatchConnectionLoop to return quickly, its ctx needs to // To force RunWatchConnectionLoop to return quickly, its ctx needs to be
// be closed, and c itself needs to be closed. // closed, and c itself needs to be closed.
//
// It is a fatal error to call this on an already-started Client withoutq having
// initialized Client.WatchConnectionChanges to true.
func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add func(key.NodePublic, netip.AddrPort), remove func(key.NodePublic)) { func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add func(key.NodePublic, netip.AddrPort), remove func(key.NodePublic)) {
if !c.WatchConnectionChanges {
if c.isStarted() {
panic("invalid use of RunWatchConnectionLoop on already-started Client without setting Client.RunWatchConnectionLoop")
}
c.WatchConnectionChanges = true
}
if infoLogf == nil { if infoLogf == nil {
infoLogf = logger.Discard infoLogf = logger.Discard
} }
@ -101,14 +111,6 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
} }
for ctx.Err() == nil { for ctx.Err() == nil {
err := c.WatchConnectionChanges()
if err != nil {
clear()
logf("WatchConnectionChanges: %v", err)
sleep(retryInterval)
continue
}
if c.ServerPublicKey() == ignoreServerKey { if c.ServerPublicKey() == ignoreServerKey {
logf("detected self-connect; ignoring host") logf("detected self-connect; ignoring host")
return return