diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index b3754f2b5..236b875ae 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -24,6 +24,7 @@ "tailscale.com/types/tkatype" "tailscale.com/util/cmpx" "tailscale.com/util/dnsname" + "tailscale.com/util/slicesx" ) // CapabilityVersion represents the client's capability level. That @@ -1939,10 +1940,10 @@ func (n *Node) Equal(n2 *Node) bool { n.Machine == n2.Machine && n.DiscoKey == n2.DiscoKey && eqPtr(n.Online, n2.Online) && - eqCIDRs(n.Addresses, n2.Addresses) && - eqCIDRs(n.AllowedIPs, n2.AllowedIPs) && - eqCIDRs(n.PrimaryRoutes, n2.PrimaryRoutes) && - eqStrings(n.Endpoints, n2.Endpoints) && + slicesx.EqualSameNil(n.Addresses, n2.Addresses) && + slicesx.EqualSameNil(n.AllowedIPs, n2.AllowedIPs) && + slicesx.EqualSameNil(n.PrimaryRoutes, n2.PrimaryRoutes) && + slicesx.EqualSameNil(n.Endpoints, n2.Endpoints) && n.DERP == n2.DERP && n.Cap == n2.Cap && n.Hostinfo.Equal(n2.Hostinfo) && @@ -1954,7 +1955,7 @@ func (n *Node) Equal(n2 *Node) bool { n.ComputedName == n2.ComputedName && n.computedHostIfDifferent == n2.computedHostIfDifferent && n.ComputedNameWithHost == n2.ComputedNameWithHost && - eqStrings(n.Tags, n2.Tags) && + slicesx.EqualSameNil(n.Tags, n2.Tags) && n.Expired == n2.Expired && eqPtr(n.SelfNodeV4MasqAddrForThisPeer, n2.SelfNodeV4MasqAddrForThisPeer) && eqPtr(n.SelfNodeV6MasqAddrForThisPeer, n2.SelfNodeV6MasqAddrForThisPeer) && @@ -1971,30 +1972,6 @@ func eqPtr[T comparable](a, b *T) bool { return *a == *b } -func eqStrings(a, b []string) bool { - if len(a) != len(b) || ((a == nil) != (b == nil)) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} - -func eqCIDRs(a, b []netip.Prefix) bool { - if len(a) != len(b) || ((a == nil) != (b == nil)) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} - func eqTimePtr(a, b *time.Time) bool { return ((a == nil) == (b == nil)) && (a == nil || a.Equal(*b)) } diff --git a/util/slicesx/slicesx.go b/util/slicesx/slicesx.go index ba5be7271..2e4ac9567 100644 --- a/util/slicesx/slicesx.go +++ b/util/slicesx/slicesx.go @@ -57,3 +57,23 @@ func Partition[S ~[]T, T any](s S, cb func(T) bool) (trues, falses S) { } return } + +// EqualSameNil reports whether two slices are equal: the same length, same +// nilness (notably when length zero), and all elements equal. If the lengths +// are different or their nilness differs, Equal returns false. Otherwise, the +// elements are compared in increasing index order, and the comparison stops at +// the first unequal pair. Floating point NaNs are not considered equal. +// +// It is identical to the standard library's slices.Equal but adds the matching +// nilness check. +func EqualSameNil[S ~[]E, E comparable](s1, s2 S) bool { + if len(s1) != len(s2) || (s1 == nil) != (s2 == nil) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true +} diff --git a/util/slicesx/slicesx_test.go b/util/slicesx/slicesx_test.go index 48efae4fb..0d206b364 100644 --- a/util/slicesx/slicesx_test.go +++ b/util/slicesx/slicesx_test.go @@ -7,6 +7,8 @@ "reflect" "slices" "testing" + + qt "github.com/frankban/quicktest" ) func TestInterleave(t *testing.T) { @@ -84,3 +86,14 @@ func TestPartition(t *testing.T) { t.Errorf("odds: got %v, want %v", odds, wantOdds) } } + +func TestEqualSameNil(t *testing.T) { + c := qt.New(t) + c.Check(EqualSameNil([]string{"a"}, []string{"a"}), qt.Equals, true) + c.Check(EqualSameNil([]string{"a"}, []string{"b"}), qt.Equals, false) + c.Check(EqualSameNil([]string{"a"}, []string{}), qt.Equals, false) + c.Check(EqualSameNil([]string{}, []string{}), qt.Equals, true) + c.Check(EqualSameNil(nil, []string{}), qt.Equals, false) + c.Check(EqualSameNil([]string{}, nil), qt.Equals, false) + c.Check(EqualSameNil[[]string](nil, nil), qt.Equals, true) +} diff --git a/wgengine/router/router_windows.go b/wgengine/router/router_windows.go index 155c29b46..d51f7a7c6 100644 --- a/wgengine/router/router_windows.go +++ b/wgengine/router/router_windows.go @@ -12,6 +12,7 @@ "net/netip" "os" "os/exec" + "slices" "strings" "sync" "syscall" @@ -196,7 +197,7 @@ func (ft *firewallTweaker) doAsyncSet() { ft.mu.Lock() for { // invariant: ft.mu must be locked when beginning this block val := ft.wantLocal - if ft.known && strsEqual(ft.lastLocal, val) && ft.wantKillswitch == ft.lastKillswitch && routesEqual(ft.localRoutes, ft.lastLocalRoutes) { + if ft.known && slices.Equal(ft.lastLocal, val) && ft.wantKillswitch == ft.lastKillswitch && slices.Equal(ft.localRoutes, ft.lastLocalRoutes) { ft.running = false ft.logf("ending netsh goroutine") ft.mu.Unlock() @@ -341,28 +342,3 @@ func (ft *firewallTweaker) doSet(local []string, killswitch bool, clear bool, pr // in via stdin encoded in json. return ft.fwProcEncoder.Encode(allowedRoutes) } - -func routesEqual(a, b []netip.Prefix) bool { - if len(a) != len(b) { - return false - } - // Routes are pre-sorted. - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} - -func strsEqual(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -}