wgengine/magicsock: keep advertising endpoints after we stop discovering them

Previously, when updating endpoints we would immediately stop
advertising any endpoint that wasn't discovered during
determineEndpoints. This could result in, for example, a case where we
performed an incremental netcheck, didn't get any of our three STUN
packets back, and then dropped our STUN endpoint from the set of
advertised endpoints... which would result in clients falling back to a
DERP connection until the next call to determineEndpoints.

Instead, let's cache endpoints that we've discovered and continue
reporting them to clients until a timeout expires. In the above case
where we temporarily don't have a discovered STUN endpoint, we would
continue reporting the old value, then re-discover the STUN endpoint
again and continue reporting it as normal, so clients never see a
withdrawal.

Updates tailscale/coral#108

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I42de72e7418ab328a6c732bdefc74549708cf8b9
This commit is contained in:
Andrew Dunham 2023-04-14 16:37:10 -04:00
parent 4b49ca4a12
commit 80b138f0df
2 changed files with 212 additions and 0 deletions

View File

@ -64,6 +64,7 @@
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/util/ringbuffer" "tailscale.com/util/ringbuffer"
"tailscale.com/util/set"
"tailscale.com/util/sysresources" "tailscale.com/util/sysresources"
"tailscale.com/util/uniq" "tailscale.com/util/uniq"
"tailscale.com/version" "tailscale.com/version"
@ -419,6 +420,10 @@ type Conn struct {
// when endpoints are refreshed. // when endpoints are refreshed.
onEndpointRefreshed map[*endpoint]func() onEndpointRefreshed map[*endpoint]func()
// endpointTracker tracks the set of cached endpoints that we advertise
// for a period of time before withdrawing them.
endpointTracker endpointTracker
// peerSet is the set of peers that are currently configured in // peerSet is the set of peers that are currently configured in
// WireGuard. These are not used to filter inbound or outbound // WireGuard. These are not used to filter inbound or outbound
// traffic at all, but only to track what state can be cleaned up // traffic at all, but only to track what state can be cleaned up
@ -1196,6 +1201,22 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
c.ignoreSTUNPackets() c.ignoreSTUNPackets()
// Update our set of endpoints by adding any endpoints that we
// previously found but haven't expired yet. This also updates the
// cache with the set of endpoints discovered in this function.
//
// NOTE: we do this here and not below so that we don't cache local
// endpoints; we know that the local endpoints we discover are all
// possible local endpoints since we determine them by looking at the
// set of addresses on our local interfaces.
//
// TODO(andrew): If we pull in any cached endpoints, we should probably
// do something to ensure we're propagating the removal of those cached
// endpoints if they do actually time out without being rediscovered.
// For now, though, rely on a minor LinkChange event causing this to
// re-run.
eps = c.endpointTracker.update(time.Now(), eps)
if localAddr := c.pconn4.LocalAddr(); localAddr.IP.IsUnspecified() { if localAddr := c.pconn4.LocalAddr(); localAddr.IP.IsUnspecified() {
ips, loopback, err := interfaces.LocalAddresses() ips, loopback, err := interfaces.LocalAddresses()
if err != nil { if err != nil {
@ -4148,6 +4169,11 @@ type pendingCLIPing struct {
// STUN-derived endpoint valid for. UDP NAT mappings typically // STUN-derived endpoint valid for. UDP NAT mappings typically
// expire at 30 seconds, so this is a few seconds shy of that. // expire at 30 seconds, so this is a few seconds shy of that.
endpointsFreshEnoughDuration = 27 * time.Second endpointsFreshEnoughDuration = 27 * time.Second
// endpointTrackerLifetime is how long we continue advertising an
// endpoint after we last see it. This is intentionally chosen to be
// slightly longer than a full netcheck period.
endpointTrackerLifetime = 5*time.Minute + 10*time.Second
) )
// Constants that are variable for testing. // Constants that are variable for testing.
@ -5105,6 +5131,79 @@ func (s derpAddrFamSelector) PreferIPv6() bool {
return false return false
} }
type endpointTrackerEntry struct {
endpoint tailcfg.Endpoint
until time.Time
}
type endpointTracker struct {
mu sync.Mutex
cache map[netip.AddrPort]endpointTrackerEntry
}
func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
epsPlusCached = eps
var inputEps set.Slice[netip.AddrPort]
for _, ep := range eps {
inputEps.Add(ep.Addr)
}
et.mu.Lock()
defer et.mu.Unlock()
// Add entries to the return array that aren't already there.
for k, ep := range et.cache {
// If the endpoint was in the input list, or has expired, skip it.
if inputEps.Contains(k) {
continue
} else if now.After(ep.until) {
continue
}
// We haven't seen this endpoint; add to the return array
epsPlusCached = append(epsPlusCached, ep.endpoint)
}
// Add entries from the original input array into the cache, and/or
// extend the lifetime of entries that are already in the cache.
until := now.Add(endpointTrackerLifetime)
for _, ep := range eps {
et.addLocked(now, ep, until)
}
// Remove everything that has now expired.
et.removeExpiredLocked(now)
return epsPlusCached
}
// add will store the provided endpoint(s) in the cache for a fixed period of
// time, and remove any entries in the cache that have expired.
//
// et.mu must be held.
func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
// If we already have an entry for this endpoint, update the timeout on
// it; otherwise, add it.
entry, found := et.cache[ep.Addr]
if found {
entry.until = until
} else {
entry = endpointTrackerEntry{ep, until}
}
mak.Set(&et.cache, ep.Addr, entry)
}
// removeExpired will remove all expired entries from the cache
//
// et.mu must be held
func (et *endpointTracker) removeExpiredLocked(now time.Time) {
for k, ep := range et.cache {
if now.After(ep.until) {
delete(et.cache, k)
}
}
}
var ( var (
metricNumPeers = clientmetric.NewGauge("magicsock_netmap_num_peers") metricNumPeers = clientmetric.NewGauge("magicsock_netmap_num_peers")
metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns") metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns")

View File

@ -18,6 +18,7 @@
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
"os" "os"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -31,6 +32,7 @@
"github.com/tailscale/wireguard-go/tun/tuntest" "github.com/tailscale/wireguard-go/tun/tuntest"
"go4.org/mem" "go4.org/mem"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/derp" "tailscale.com/derp"
@ -390,6 +392,7 @@ func TestNewConn(t *testing.T) {
for { for {
select { select {
case ep := <-epCh: case ep := <-epCh:
t.Logf("TestNewConn: got endpoint: %v", ep)
endpoints = append(endpoints, ep) endpoints = append(endpoints, ep)
if strings.HasSuffix(ep, suffix) { if strings.HasSuffix(ep, suffix) {
break collectEndpoints break collectEndpoints
@ -2280,3 +2283,113 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) {
t.Fatal("no packet after 1s") t.Fatal("no packet after 1s")
} }
} }
func TestEndpointTracker(t *testing.T) {
local := tailcfg.Endpoint{
Addr: netip.MustParseAddrPort("192.168.1.1:12345"),
Type: tailcfg.EndpointLocal,
}
stun4_1 := tailcfg.Endpoint{
Addr: netip.MustParseAddrPort("1.2.3.4:12345"),
Type: tailcfg.EndpointSTUN,
}
stun4_2 := tailcfg.Endpoint{
Addr: netip.MustParseAddrPort("5.6.7.8:12345"),
Type: tailcfg.EndpointSTUN,
}
stun6_1 := tailcfg.Endpoint{
Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"),
Type: tailcfg.EndpointSTUN,
}
stun6_2 := tailcfg.Endpoint{
Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"),
Type: tailcfg.EndpointSTUN,
}
start := time.Unix(1681503440, 0)
steps := []struct {
name string
now time.Time
eps []tailcfg.Endpoint
want []tailcfg.Endpoint
}{
{
name: "initial endpoints",
now: start,
eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
},
{
name: "no change",
now: start.Add(1 * time.Minute),
eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
},
{
name: "missing stun4",
now: start.Add(2 * time.Minute),
eps: []tailcfg.Endpoint{local, stun6_1},
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
},
{
name: "missing stun6",
now: start.Add(3 * time.Minute),
eps: []tailcfg.Endpoint{local, stun4_1},
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
},
{
name: "multiple STUN addresses within timeout",
now: start.Add(4 * time.Minute),
eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2},
},
{
name: "endpoint extended",
now: start.Add(3*time.Minute + endpointTrackerLifetime - 1),
eps: []tailcfg.Endpoint{local},
want: []tailcfg.Endpoint{
local, stun4_2, stun6_2,
// stun4_1 had its lifetime extended by the
// "missing stun6" test above to that start
// time plus the lifetime, while stun6 should
// have expired a minute sooner. It should thus
// be in this returned list.
stun4_1,
},
},
{
name: "after timeout",
now: start.Add(4*time.Minute + endpointTrackerLifetime + 1),
eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
},
{
name: "after timeout still caches",
now: start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute),
eps: []tailcfg.Endpoint{local},
want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
},
}
var et endpointTracker
for _, tt := range steps {
t.Logf("STEP: %s", tt.name)
got := et.update(tt.now, tt.eps)
// Sort both arrays for comparison
slices.SortFunc(got, func(a, b tailcfg.Endpoint) bool {
return a.Addr.String() < b.Addr.String()
})
slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) bool {
return a.Addr.String() < b.Addr.String()
})
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want)
}
}
}