wgengine/magicsock: refactor maybeRebindOnError

Remove the platform specificity, it is unnecessary complexity.
Deduplicate repeated code as a result of reduced complexity.
Split out error identification code.
Update call-sites and tests.

Updates #14551
Updates tailscale/corp#25648

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker 2025-01-06 13:10:56 -08:00 committed by James Tucker
parent 6db220b478
commit 2c07f5dfcd
4 changed files with 95 additions and 65 deletions

View File

@ -364,9 +364,9 @@ type Conn struct {
// wireguard state by its public key. If nil, it's not used. // wireguard state by its public key. If nil, it's not used.
getPeerByKey func(key.NodePublic) (_ wgint.Peer, ok bool) getPeerByKey func(key.NodePublic) (_ wgint.Peer, ok bool)
// lastEPERMRebind tracks the last time a rebind was performed // lastErrRebind tracks the last time a rebind was performed after
// after experiencing a syscall.EPERM. // experiencing a write error, and is used to throttle the rate of rebinds.
lastEPERMRebind syncs.AtomicValue[time.Time] lastErrRebind syncs.AtomicValue[time.Time]
// staticEndpoints are user set endpoints that this node should // staticEndpoints are user set endpoints that this node should
// advertise amongst its wireguard endpoints. It is user's // advertise amongst its wireguard endpoints. It is user's
@ -1258,7 +1258,7 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err
c.logf("magicsock: %s", errGSO.Error()) c.logf("magicsock: %s", errGSO.Error())
err = errGSO.RetryErr err = errGSO.RetryErr
} else { } else {
_ = c.maybeRebindOnError(runtime.GOOS, err) c.maybeRebindOnError(err)
} }
} }
return err == nil, err return err == nil, err
@ -1273,7 +1273,7 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte, isDisco bool) (sent bool, e
sent, err = c.sendUDPStd(ipp, b) sent, err = c.sendUDPStd(ipp, b)
if err != nil { if err != nil {
metricSendUDPError.Add(1) metricSendUDPError.Add(1)
_ = c.maybeRebindOnError(runtime.GOOS, err) c.maybeRebindOnError(err)
} else { } else {
if sent && !isDisco { if sent && !isDisco {
switch { switch {
@ -1289,6 +1289,23 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte, isDisco bool) (sent bool, e
return return
} }
// maybeRebindOnError performs a rebind and restun if the error is one that is
// known to be healed by a rebind, and the rebind is not throttled.
//
// Errors that shouldRebind does not recognise are ignored. For recognised
// errors, at most one rebind is started per 5 second window: the time of each
// rebind is recorded in c.lastErrRebind, and errors arriving inside the window
// are logged and otherwise dropped.
func (c *Conn) maybeRebindOnError(err error) {
	ok, reason := shouldRebind(err)
	if !ok {
		return
	}

	now := time.Now()
	if !c.lastErrRebind.Load().Before(now.Add(-5 * time.Second)) {
		c.logf("magicsock: not performing %q rebind due to throttle", reason)
		return
	}

	// Record when this rebind happened. Without the store the timestamp never
	// advances, the check above always passes, and a persistent send error
	// would trigger a rebind on every failed write.
	// NOTE(review): Load followed by Store is not a single atomic step, so two
	// concurrent senders may both pass the check once — confirm that is
	// acceptable for callers.
	c.lastErrRebind.Store(now)

	c.logf("magicsock: performing rebind due to %q", reason)
	c.Rebind()
	// ReSTUN is started on its own goroutine so this call does not wait for it.
	go c.ReSTUN(reason)
}
// sendUDPNetcheck sends b via UDP to addr. It is used exclusively by netcheck. // sendUDPNetcheck sends b via UDP to addr. It is used exclusively by netcheck.
// It returns the number of bytes sent along with any error encountered. It // It returns the number of bytes sent along with any error encountered. It
// returns errors.ErrUnsupported if the client is explicitly configured to only // returns errors.ErrUnsupported if the client is explicitly configured to only

View File

@ -8,42 +8,24 @@ package magicsock
import ( import (
"errors" "errors"
"syscall" "syscall"
"time"
) )
// maybeRebindOnError performs a rebind and restun if the error is defined and // shouldRebind returns if the error is one that is known to be healed by a
// any conditionals are met. // rebind, and if so also returns a reason string for the rebind.
func (c *Conn) maybeRebindOnError(os string, err error) bool { func shouldRebind(err error) (ok bool, reason string) {
switch { switch {
case errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ENOTCONN):
// EPIPE/ENOTCONN are common errors when a send fails due to a closed // EPIPE/ENOTCONN are common errors when a send fails due to a closed
// socket. There is some platform and version inconsistency in which // socket. There is some platform and version inconsistency in which
// error is returned, but the meaning is the same. // error is returned, but the meaning is the same.
why := "broken-pipe-rebind" case errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ENOTCONN):
c.logf("magicsock: performing %q", why) return true, "broken-pipe"
c.Rebind()
go c.ReSTUN(why) // EPERM is typically caused by EDR software, and has been observed to be
return true // transient, it seems that some versions of some EDR lose track of sockets
// at times, and return EPERM, but reconnects will establish appropriate
// rights associated with a new socket.
case errors.Is(err, syscall.EPERM): case errors.Is(err, syscall.EPERM):
why := "operation-not-permitted-rebind" return true, "operation-not-permitted"
switch os {
// We currently will only rebind and restun on a syscall.EPERM if it is experienced
// on a client running darwin.
// TODO(charlotte, raggi): expand os options if required.
case "darwin":
// TODO(charlotte): implement a backoff, so we don't end up in a rebind loop for persistent
// EPERMs.
if c.lastEPERMRebind.Load().Before(time.Now().Add(-5 * time.Second)) {
c.logf("magicsock: performing %q", why)
c.lastEPERMRebind.Store(time.Now())
c.Rebind()
go c.ReSTUN(why)
return true
} }
default: return false, ""
c.logf("magicsock: not performing %q", why)
return false
}
}
return false
} }

View File

@ -5,8 +5,8 @@
package magicsock package magicsock
// maybeRebindOnError performs a rebind and restun if the error is defined and // shouldRebind returns if the error is one that is known to be healed by a
// any conditionals are met. // rebind, and if so also returns a reason string for the rebind.
func (c *Conn) maybeRebindOnError(os string, err error) bool { func shouldRebind(err error) (ok bool, reason string) {
return false return false, ""
} }

View File

@ -3050,37 +3050,68 @@ func TestMaybeSetNearestDERP(t *testing.T) {
} }
} }
// TestShouldRebind checks that shouldRebind reports a rebind, with the
// expected reason string, for EPERM, EPIPE and ENOTCONN wrapped in a
// *net.OpError, and reports no rebind for nil and unrelated io errors.
//
// NOTE(review): another test in this file skips its rebind cases when
// runtime.GOOS is "plan9"; confirm whether the positive cases here need the
// same guard on platforms that use the stub shouldRebind.
func TestShouldRebind(t *testing.T) {
	tests := []struct {
		err    error
		ok     bool
		reason string
	}{
		{nil, false, ""},
		{io.EOF, false, ""},
		{io.ErrUnexpectedEOF, false, ""},
		{io.ErrShortBuffer, false, ""},
		{&net.OpError{Err: syscall.EPERM}, true, "operation-not-permitted"},
		{&net.OpError{Err: syscall.EPIPE}, true, "broken-pipe"},
		{&net.OpError{Err: syscall.ENOTCONN}, true, "broken-pipe"},
	}
	for _, tt := range tests {
		// %v rather than %s: the nil case would otherwise produce the
		// subtest name "%!s(<nil>)-false".
		t.Run(fmt.Sprintf("%v-%v", tt.err, tt.ok), func(t *testing.T) {
			if got, reason := shouldRebind(tt.err); got != tt.ok || reason != tt.reason {
				t.Errorf("shouldRebind(%v) = %v, %q; want %v, %q", tt.err, got, reason, tt.ok, tt.reason)
			}
		})
	}
}
func TestMaybeRebindOnError(t *testing.T) { func TestMaybeRebindOnError(t *testing.T) {
tstest.PanicOnLog() tstest.PanicOnLog()
tstest.ResourceCheck(t) tstest.ResourceCheck(t)
err := fmt.Errorf("outer err: %w", syscall.EPERM) var rebindErrs []error
if runtime.GOOS != "plan9" {
rebindErrs = append(rebindErrs,
&net.OpError{Err: syscall.EPERM},
&net.OpError{Err: syscall.EPIPE},
&net.OpError{Err: syscall.ENOTCONN},
)
}
t.Run("darwin-rebind", func(t *testing.T) { for _, rebindErr := range rebindErrs {
t.Run(fmt.Sprintf("rebind-%s", rebindErr), func(t *testing.T) {
conn := newTestConn(t) conn := newTestConn(t)
defer conn.Close() defer conn.Close()
rebound := conn.maybeRebindOnError("darwin", err)
if !rebound {
t.Errorf("darwin should rebind on syscall.EPERM")
}
})
t.Run("linux-not-rebind", func(t *testing.T) { before := metricRebindCalls.Value()
conn := newTestConn(t) conn.maybeRebindOnError(rebindErr)
defer conn.Close() after := metricRebindCalls.Value()
rebound := conn.maybeRebindOnError("linux", err) if before+1 != after {
if rebound { t.Errorf("should rebind on %#v", rebindErr)
t.Errorf("linux should not rebind on syscall.EPERM")
} }
}) })
}
t.Run("no-frequent-rebind", func(t *testing.T) { t.Run("no-frequent-rebind", func(t *testing.T) {
if runtime.GOOS != "plan9" {
err := fmt.Errorf("outer err: %w", syscall.EPERM)
conn := newTestConn(t) conn := newTestConn(t)
defer conn.Close() defer conn.Close()
conn.lastEPERMRebind.Store(time.Now().Add(-1 * time.Second)) conn.lastErrRebind.Store(time.Now().Add(-1 * time.Second))
rebound := conn.maybeRebindOnError("darwin", err) before := metricRebindCalls.Value()
if rebound { conn.maybeRebindOnError(err)
t.Errorf("darwin should not rebind on syscall.EPERM within 5 seconds of last") after := metricRebindCalls.Value()
if before != after {
t.Errorf("should not rebind within 5 seconds of last")
}
} }
}) })
} }