mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-10 01:53:49 +00:00
188 lines
5.2 KiB
Go
188 lines
5.2 KiB
Go
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
package magicsock
|
||
|
|
||
|
import (
|
||
|
"net/netip"
|
||
|
"reflect"
|
||
|
"slices"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"tailscale.com/tailcfg"
|
||
|
)
|
||
|
|
||
|
func TestEndpointTracker(t *testing.T) {
|
||
|
local := tailcfg.Endpoint{
|
||
|
Addr: netip.MustParseAddrPort("192.168.1.1:12345"),
|
||
|
Type: tailcfg.EndpointLocal,
|
||
|
}
|
||
|
|
||
|
stun4_1 := tailcfg.Endpoint{
|
||
|
Addr: netip.MustParseAddrPort("1.2.3.4:12345"),
|
||
|
Type: tailcfg.EndpointSTUN,
|
||
|
}
|
||
|
stun4_2 := tailcfg.Endpoint{
|
||
|
Addr: netip.MustParseAddrPort("5.6.7.8:12345"),
|
||
|
Type: tailcfg.EndpointSTUN,
|
||
|
}
|
||
|
|
||
|
stun6_1 := tailcfg.Endpoint{
|
||
|
Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"),
|
||
|
Type: tailcfg.EndpointSTUN,
|
||
|
}
|
||
|
stun6_2 := tailcfg.Endpoint{
|
||
|
Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"),
|
||
|
Type: tailcfg.EndpointSTUN,
|
||
|
}
|
||
|
|
||
|
start := time.Unix(1681503440, 0)
|
||
|
|
||
|
steps := []struct {
|
||
|
name string
|
||
|
now time.Time
|
||
|
eps []tailcfg.Endpoint
|
||
|
want []tailcfg.Endpoint
|
||
|
}{
|
||
|
{
|
||
|
name: "initial endpoints",
|
||
|
now: start,
|
||
|
eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
|
||
|
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
|
||
|
},
|
||
|
{
|
||
|
name: "no change",
|
||
|
now: start.Add(1 * time.Minute),
|
||
|
eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
|
||
|
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
|
||
|
},
|
||
|
{
|
||
|
name: "missing stun4",
|
||
|
now: start.Add(2 * time.Minute),
|
||
|
eps: []tailcfg.Endpoint{local, stun6_1},
|
||
|
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
|
||
|
},
|
||
|
{
|
||
|
name: "missing stun6",
|
||
|
now: start.Add(3 * time.Minute),
|
||
|
eps: []tailcfg.Endpoint{local, stun4_1},
|
||
|
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
|
||
|
},
|
||
|
{
|
||
|
name: "multiple STUN addresses within timeout",
|
||
|
now: start.Add(4 * time.Minute),
|
||
|
eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
|
||
|
want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2},
|
||
|
},
|
||
|
{
|
||
|
name: "endpoint extended",
|
||
|
now: start.Add(3*time.Minute + endpointTrackerLifetime - 1),
|
||
|
eps: []tailcfg.Endpoint{local},
|
||
|
want: []tailcfg.Endpoint{
|
||
|
local, stun4_2, stun6_2,
|
||
|
// stun4_1 had its lifetime extended by the
|
||
|
// "missing stun6" test above to that start
|
||
|
// time plus the lifetime, while stun6 should
|
||
|
// have expired a minute sooner. It should thus
|
||
|
// be in this returned list.
|
||
|
stun4_1,
|
||
|
},
|
||
|
},
|
||
|
{
|
||
|
name: "after timeout",
|
||
|
now: start.Add(4*time.Minute + endpointTrackerLifetime + 1),
|
||
|
eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
|
||
|
want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
|
||
|
},
|
||
|
{
|
||
|
name: "after timeout still caches",
|
||
|
now: start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute),
|
||
|
eps: []tailcfg.Endpoint{local},
|
||
|
want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
var et endpointTracker
|
||
|
for _, tt := range steps {
|
||
|
t.Logf("STEP: %s", tt.name)
|
||
|
|
||
|
got := et.update(tt.now, tt.eps)
|
||
|
|
||
|
// Sort both arrays for comparison
|
||
|
slices.SortFunc(got, func(a, b tailcfg.Endpoint) int {
|
||
|
return strings.Compare(a.Addr.String(), b.Addr.String())
|
||
|
})
|
||
|
slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) int {
|
||
|
return strings.Compare(a.Addr.String(), b.Addr.String())
|
||
|
})
|
||
|
|
||
|
if !reflect.DeepEqual(got, tt.want) {
|
||
|
t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestEndpointTrackerMaxNum(t *testing.T) {
|
||
|
start := time.Unix(1681503440, 0)
|
||
|
|
||
|
var allEndpoints []tailcfg.Endpoint // all created endpoints
|
||
|
mkEp := func(i int) tailcfg.Endpoint {
|
||
|
ep := tailcfg.Endpoint{
|
||
|
Addr: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), uint16(i)),
|
||
|
Type: tailcfg.EndpointSTUN,
|
||
|
}
|
||
|
allEndpoints = append(allEndpoints, ep)
|
||
|
return ep
|
||
|
}
|
||
|
|
||
|
var et endpointTracker
|
||
|
|
||
|
// Add more endpoints to the list than our limit
|
||
|
for i := 0; i <= endpointTrackerMaxPerAddr; i++ {
|
||
|
et.update(start.Add(time.Duration(i)*time.Second), []tailcfg.Endpoint{mkEp(10000 + i)})
|
||
|
}
|
||
|
|
||
|
// Now add two more, slightly later
|
||
|
got := et.update(start.Add(1*time.Minute), []tailcfg.Endpoint{
|
||
|
mkEp(10100),
|
||
|
mkEp(10101),
|
||
|
})
|
||
|
|
||
|
// We expect to get the last N endpoints per our per-Addr limit, since
|
||
|
// all of the endpoints have the same netip.Addr. The first endpoint(s)
|
||
|
// that we added were dropped because we had more than the limit for
|
||
|
// this Addr.
|
||
|
want := allEndpoints[len(allEndpoints)-endpointTrackerMaxPerAddr:]
|
||
|
|
||
|
compareEndpoints := func(got, want []tailcfg.Endpoint) {
|
||
|
t.Helper()
|
||
|
slices.SortFunc(want, func(a, b tailcfg.Endpoint) int {
|
||
|
return strings.Compare(a.Addr.String(), b.Addr.String())
|
||
|
})
|
||
|
slices.SortFunc(got, func(a, b tailcfg.Endpoint) int {
|
||
|
return strings.Compare(a.Addr.String(), b.Addr.String())
|
||
|
})
|
||
|
if !reflect.DeepEqual(got, want) {
|
||
|
t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, want)
|
||
|
}
|
||
|
}
|
||
|
compareEndpoints(got, want)
|
||
|
|
||
|
// However, if we have more than our limit of endpoints passed in to
|
||
|
// the endpointTracker, we will return all of them (even if they're for
|
||
|
// the same address).
|
||
|
var inputEps []tailcfg.Endpoint
|
||
|
for i := 0; i < endpointTrackerMaxPerAddr+5; i++ {
|
||
|
inputEps = append(inputEps, tailcfg.Endpoint{
|
||
|
Addr: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 10200+uint16(i)),
|
||
|
Type: tailcfg.EndpointSTUN,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
want = inputEps
|
||
|
got = et.update(start.Add(2*time.Minute), inputEps)
|
||
|
compareEndpoints(got, want)
|
||
|
}
|