wgengine/magicsock: improve endpoint selection for WireGuard peers with rx time

If we don't have the ICMP hint available, such as on Android, we can use
the signal of rx traffic to bias toward a particular endpoint.

We don't want to stick to a particular endpoint for a very long time
without any signals, so the sticky time is reduced to 1 second. That is
long enough to avoid excessive packet reordering in the common case, but
short enough that either rx traffic provides a strong signal or we
rotate to another endpoint on a timescale that feels interactive to the
user, improving the feel of failover to other endpoints.

Updates #8999

Co-authored-by: Charlotte Brandhorst-Satzkorn <charlotte@tailscale.com>

Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Charlotte Brandhorst-Satzkorn <charlotte@tailscale.com>
This commit is contained in:
James Tucker 2023-08-21 17:09:35 -07:00 committed by James Tucker
parent 5edb39d032
commit e1c7e9b736
4 changed files with 177 additions and 86 deletions

View File

@ -664,7 +664,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
return 0, nil return 0, nil
} }
ep.noteRecvActivity() ep.noteRecvActivity(ipp)
if stats := c.stats.Load(); stats != nil { if stats := c.stats.Load(); stats != nil {
stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n) stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n)
} }

View File

@ -223,14 +223,26 @@ func (de *endpoint) initFakeUDPAddr() {
// noteRecvActivity records receive activity on de, and invokes // noteRecvActivity records receive activity on de, and invokes
// Conn.noteRecvActivity no more than once every 10s. // Conn.noteRecvActivity no more than once every 10s.
func (de *endpoint) noteRecvActivity() { func (de *endpoint) noteRecvActivity(ipp netip.AddrPort) {
if de.c.noteRecvActivity == nil {
return
}
now := mono.Now() now := mono.Now()
// TODO(raggi): this probably applies relatively equally well to disco
// managed endpoints, but that would be a less conservative change.
if de.isWireguardOnly {
de.mu.Lock()
de.bestAddr.AddrPort = ipp
de.bestAddrAt = now
de.trustBestAddrUntil = now.Add(5 * time.Second)
de.mu.Unlock()
}
elapsed := now.Sub(de.lastRecv.LoadAtomic()) elapsed := now.Sub(de.lastRecv.LoadAtomic())
if elapsed > 10*time.Second { if elapsed > 10*time.Second {
de.lastRecv.StoreAtomic(now) de.lastRecv.StoreAtomic(now)
if de.c.noteRecvActivity == nil {
return
}
de.c.noteRecvActivity(de.publicKey) de.c.noteRecvActivity(de.publicKey)
} }
} }
@ -292,11 +304,23 @@ func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.Ad
// //
// de.mu must be held. // de.mu must be held.
func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) { func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) {
if len(de.endpointState) == 0 {
de.c.logf("magicsock: addrForSendWireguardLocked: [unexpected] no candidates available for endpoint")
return udpAddr, false
}
// lowestLatency is a high duration initially, so we // lowestLatency is a high duration initially, so we
// can be sure we're going to have a duration lower than this // can be sure we're going to have a duration lower than this
// for the first latency retrieved. // for the first latency retrieved.
lowestLatency := time.Hour lowestLatency := time.Hour
var oldestPing mono.Time
for ipp, state := range de.endpointState { for ipp, state := range de.endpointState {
if oldestPing.IsZero() {
oldestPing = state.lastPing
} else if state.lastPing.Before(oldestPing) {
oldestPing = state.lastPing
}
if latency, ok := state.latencyLocked(); ok { if latency, ok := state.latencyLocked(); ok {
if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() { if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() {
// If we have the same latency, IPv6 is prioritized. // If we have the same latency, IPv6 is prioritized.
@ -307,35 +331,25 @@ func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.Add
} }
} }
} }
needPing := len(de.endpointState) > 1 && now.Sub(oldestPing) > wireguardPingInterval
if udpAddr.IsValid() { if !udpAddr.IsValid() {
// Set trustBestAddrUntil to an hour, so we will candidates := xmaps.Keys(de.endpointState)
// continue to use this address for a long period of time.
de.bestAddr.AddrPort = udpAddr // Randomly select an address to use until we retrieve latency information
de.trustBestAddrUntil = now.Add(1 * time.Hour) // and give it a short trustBestAddrUntil time so we avoid flapping between
return udpAddr, false // addresses while waiting on latency information to be populated.
udpAddr = candidates[rand.Intn(len(candidates))]
} }
candidates := xmaps.Keys(de.endpointState)
if len(candidates) == 0 {
de.c.logf("magicsock: addrForSendWireguardLocked: [unexpected] no candidates available for endpoint")
return udpAddr, false
}
// Randomly select an address to use until we retrieve latency information
// and give it a short trustBestAddrUntil time so we avoid flapping between
// addresses while waiting on latency information to be populated.
udpAddr = candidates[rand.Intn(len(candidates))]
de.bestAddr.AddrPort = udpAddr de.bestAddr.AddrPort = udpAddr
if len(candidates) == 1 { // Only extend trustBestAddrUntil by one second to avoid packet
// if we only have one address that we can send data too, // reordering and/or CPU usage from random selection during the first
// we should trust it for a longer period of time. // second. We should receive a response due to a WireGuard handshake in
de.trustBestAddrUntil = now.Add(1 * time.Hour) // less than one second in good cases, in which case this will be then
} else { // extended to 5 seconds.
de.trustBestAddrUntil = now.Add(15 * time.Second) de.trustBestAddrUntil = now.Add(time.Second)
} return udpAddr, needPing
return udpAddr, len(candidates) > 1
} }
// heartbeat is called every heartbeatInterval to keep the best UDP path alive, // heartbeat is called every heartbeatInterval to keep the best UDP path alive,

View File

@ -1188,7 +1188,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache)
cache.gen = de.numStopAndReset() cache.gen = de.numStopAndReset()
ep = de ep = de
} }
ep.noteRecvActivity() ep.noteRecvActivity(ipp)
if stats := c.stats.Load(); stats != nil { if stats := c.stats.Load(); stats != nil {
stats.UpdateRxPhysical(ep.nodeAddr, ipp, len(b)) stats.UpdateRxPhysical(ep.nodeAddr, ipp, len(b))
} }
@ -2605,6 +2605,11 @@ var (
// resetting the counter, as the first pings likely didn't get through // resetting the counter, as the first pings likely didn't get through
// the firewall) // the firewall)
discoPingInterval = 5 * time.Second discoPingInterval = 5 * time.Second
// wireguardPingInterval is the minimum time between pings to an endpoint.
// Pings are only sent if we have not observed bidirectional traffic with an
// endpoint in at least this duration.
wireguardPingInterval = 5 * time.Second
) )
// indexSentinelDeleted is the temporary value that endpointState.index takes while // indexSentinelDeleted is the temporary value that endpointState.index takes while

View File

@ -1212,11 +1212,11 @@ func Test32bitAlignment(t *testing.T) {
t.Fatalf("endpoint.lastRecv is not 8-byte aligned") t.Fatalf("endpoint.lastRecv is not 8-byte aligned")
} }
de.noteRecvActivity() // verify this doesn't panic on 32-bit de.noteRecvActivity(netip.AddrPort{}) // verify this doesn't panic on 32-bit
if called != 1 { if called != 1 {
t.Fatal("expected call to noteRecvActivity") t.Fatal("expected call to noteRecvActivity")
} }
de.noteRecvActivity() de.noteRecvActivity(netip.AddrPort{})
if called != 1 { if called != 1 {
t.Error("expected no second call to noteRecvActivity") t.Error("expected no second call to noteRecvActivity")
} }
@ -2678,6 +2678,7 @@ func newPingResponder(t *testing.T) *pingResponder {
func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
testTime := mono.Now() testTime := mono.Now()
secondPingTime := testTime.Add(10 * time.Second)
type endpointDetails struct { type endpointDetails struct {
addrPort netip.AddrPort addrPort netip.AddrPort
@ -2685,16 +2686,79 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
} }
wgTests := []struct { wgTests := []struct {
name string name string
noV4 bool sendInitialPing bool
noV6 bool validAddr bool
sendWGPing bool sendFollowUpPing bool
ep []endpointDetails pingTime mono.Time
want netip.AddrPort ep []endpointDetails
want netip.AddrPort
}{ }{
{ {
name: "choose lowest latency for useable IPv4 and IPv6", name: "no endpoints",
sendWGPing: true, sendInitialPing: false,
validAddr: false,
sendFollowUpPing: false,
pingTime: testTime,
ep: []endpointDetails{},
want: netip.AddrPort{},
},
{
name: "singular endpoint does not request ping",
sendInitialPing: false,
validAddr: true,
sendFollowUpPing: false,
pingTime: testTime,
ep: []endpointDetails{
{
addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
latency: 100 * time.Millisecond,
},
},
want: netip.MustParseAddrPort("1.1.1.1:111"),
},
{
name: "ping sent within wireguardPingInterval should not request ping",
sendInitialPing: true,
validAddr: true,
sendFollowUpPing: false,
pingTime: testTime.Add(7 * time.Second),
ep: []endpointDetails{
{
addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
latency: 100 * time.Millisecond,
},
{
addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"),
latency: 2000 * time.Millisecond,
},
},
want: netip.MustParseAddrPort("1.1.1.1:111"),
},
{
name: "ping sent outside of wireguardPingInterval should request ping",
sendInitialPing: true,
validAddr: true,
sendFollowUpPing: true,
pingTime: testTime.Add(3 * time.Second),
ep: []endpointDetails{
{
addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
latency: 100 * time.Millisecond,
},
{
addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"),
latency: 150 * time.Millisecond,
},
},
want: netip.MustParseAddrPort("1.1.1.1:111"),
},
{
name: "choose lowest latency for useable IPv4 and IPv6",
sendInitialPing: true,
validAddr: true,
sendFollowUpPing: false,
pingTime: secondPingTime,
ep: []endpointDetails{ ep: []endpointDetails{
{ {
addrPort: netip.MustParseAddrPort("1.1.1.1:111"), addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
@ -2708,8 +2772,11 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
want: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), want: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"),
}, },
{ {
name: "choose IPv6 address when latency is the same for v4 and v6", name: "choose IPv6 address when latency is the same for v4 and v6",
sendWGPing: true, sendInitialPing: true,
validAddr: true,
sendFollowUpPing: false,
pingTime: secondPingTime,
ep: []endpointDetails{ ep: []endpointDetails{
{ {
addrPort: netip.MustParseAddrPort("1.1.1.1:111"), addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
@ -2725,52 +2792,57 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
} }
for _, test := range wgTests { for _, test := range wgTests {
endpoint := &endpoint{ t.Run(test.name, func(t *testing.T) {
isWireguardOnly: true, endpoint := &endpoint{
endpointState: map[netip.AddrPort]*endpointState{}, isWireguardOnly: true,
c: &Conn{ endpointState: map[netip.AddrPort]*endpointState{},
noV4: atomic.Bool{}, c: &Conn{
noV6: atomic.Bool{}, logf: t.Logf,
}, noV4: atomic.Bool{},
} noV6: atomic.Bool{},
},
for _, epd := range test.ep {
endpoint.endpointState[epd.addrPort] = &endpointState{}
}
udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime)
if !udpAddr.IsValid() {
t.Error("udpAddr returned is not valid")
}
if shouldPing != test.sendWGPing {
t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendWGPing)
}
for _, epd := range test.ep {
state, ok := endpoint.endpointState[epd.addrPort]
if !ok {
t.Errorf("addr does not exist in endpoint state map")
} }
latency, ok := state.latencyLocked() for _, epd := range test.ep {
if ok { endpoint.endpointState[epd.addrPort] = &endpointState{}
t.Errorf("latency was set for %v: %v", epd.addrPort, latency) }
udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime)
if udpAddr.IsValid() != test.validAddr {
t.Errorf("udpAddr validity is incorrect; got %v, want %v", udpAddr.IsValid(), test.validAddr)
}
if shouldPing != test.sendInitialPing {
t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendInitialPing)
} }
state.recentPongs = append(state.recentPongs, pongReply{
latency: epd.latency,
})
state.recentPong = 0
}
udpAddr, _, shouldPing = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute)) // Update the endpointState to simulate a ping having been
if udpAddr != test.want { // sent and a pong received.
t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want) for _, epd := range test.ep {
} state, ok := endpoint.endpointState[epd.addrPort]
if shouldPing { if !ok {
t.Error("addrForSendLocked should not indicate ping is required") t.Errorf("addr does not exist in endpoint state map")
} }
if endpoint.bestAddr.AddrPort != test.want { state.lastPing = test.pingTime
t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want)
} latency, ok := state.latencyLocked()
if ok {
t.Errorf("latency was set for %v: %v", epd.addrPort, latency)
}
state.recentPongs = append(state.recentPongs, pongReply{
latency: epd.latency,
})
state.recentPong = 0
}
udpAddr, _, shouldPing = endpoint.addrForSendLocked(secondPingTime)
if udpAddr != test.want {
t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want)
}
if shouldPing != test.sendFollowUpPing {
t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendFollowUpPing)
}
if endpoint.bestAddr.AddrPort != test.want {
t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want)
}
})
} }
} }