From bd93c3067e4adf73fa9eabbb2b9d016618128120 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 18 Jun 2024 12:05:34 -0700 Subject: [PATCH] wgengine/filter/filtertype: make Match.IPProto a view I noticed we were allocating these every time when they could just share the same memory. Rather than document ownership, just lock it down with a view. I was considering doing all of the fields but decided to just do this one first as test to see how infectious it became. Conclusion: not very. Updates #cleanup (while working towards tailscale/corp#20514) Change-Id: I8ce08519de0c9a53f20292adfbecd970fe362de0 Signed-off-by: Brad Fitzpatrick --- net/tstun/wrap_test.go | 5 +++-- util/deephash/deephash_test.go | 3 ++- wgengine/filter/filter.go | 4 ++-- wgengine/filter/filter_test.go | 14 +++++--------- wgengine/filter/filtertype/filtertype.go | 3 ++- wgengine/filter/filtertype/filtertype_clone.go | 5 +++-- wgengine/filter/match.go | 7 +++---- wgengine/filter/tailcfg.go | 9 ++++++--- 8 files changed, 26 insertions(+), 24 deletions(-) diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 5e3685c62..d6287c652 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -36,6 +36,7 @@ "tailscale.com/types/logger" "tailscale.com/types/netlogtype" "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/must" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" @@ -156,10 +157,10 @@ func netports(netPorts ...string) (ret []filter.NetPortRange) { } func setfilter(logf logger.Logf, tun *Wrapper) { - protos := []ipproto.Proto{ + protos := views.SliceOf([]ipproto.Proto{ ipproto.TCP, ipproto.UDP, - } + }) matches := []filter.Match{ {IPProto: protos, Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")}, {IPProto: protos, Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")}, diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index 51b0bfa10..7908f9a60 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -27,6 +27,7 @@ "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/deephash/testtype" "tailscale.com/util/dnsname" "tailscale.com/util/hashx" @@ -353,7 +354,7 @@ func getVal() *tailscaleTypes { }, }, filter.Match{ - IPProto: []ipproto.Proto{1, 2, 3}, + IPProto: views.SliceOf([]ipproto.Proto{1, 2, 3}), }, } } diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 01182d4f8..42d8ebc95 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -126,7 +126,7 @@ func NewAllowAllForTest(logf logger.Logf) *Filter { any6 := netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0) ms := []Match{ { - IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4}, + IPProto: views.SliceOf([]ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4}), Srcs: []netip.Prefix{any4}, Dsts: []NetPortRange{ { @@ -139,7 +139,7 @@ func NewAllowAllForTest(logf logger.Logf) *Filter { }, }, { - IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv6}, + IPProto: views.SliceOf([]ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv6}), Srcs: []netip.Prefix{any6}, Dsts: []NetPortRange{ { diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 8632613b8..63c34d8dd 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -45,7 +45,7 @@ func m(srcs []netip.Prefix, dsts []NetPortRange, protos ...ipproto.Proto) Match protos = defaultProtos } return Match{ - IPProto: protos, + IPProto: views.SliceOf(protos), Srcs: srcs, SrcsContains: ipset.NewContainsIPFunc(views.SliceOf(srcs)), Dsts: dsts, @@ -767,12 +767,7 @@ func TestMatchesFromFilterRules(t *testing.T) { }, want: []Match{ { - IPProto: []ipproto.Proto{ - ipproto.TCP, - ipproto.UDP, - ipproto.ICMPv4, - ipproto.ICMPv6, - }, + IPProto: defaultProtosView, Dsts: []NetPortRange{ { Net: netip.MustParsePrefix("0.0.0.0/0"), @@ -804,9 +799,9 @@ func TestMatchesFromFilterRules(t *testing.T) { }, want: []Match{ { - IPProto: []ipproto.Proto{ + IPProto: views.SliceOf([]ipproto.Proto{ ipproto.TCP, - }, + }), Dsts: []NetPortRange{ { Net: netip.MustParsePrefix("1.2.0.0/16"), @@ -830,6 +825,7 @@ func TestMatchesFromFilterRules(t *testing.T) { cmpOpts := []cmp.Option{ cmp.Comparer(func(a, b netip.Addr) bool { return a == b }), cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }), + cmp.Comparer(func(a, b views.Slice[ipproto.Proto]) bool { return views.SliceEqual(a, b) }), cmpopts.IgnoreFields(Match{}, ".SrcsContains"), } if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { diff --git a/wgengine/filter/filtertype/filtertype.go b/wgengine/filter/filtertype/filtertype.go index 1090ac718..689a45e7c 100644 --- a/wgengine/filter/filtertype/filtertype.go +++ b/wgengine/filter/filtertype/filtertype.go @@ -11,6 +11,7 @@ "tailscale.com/tailcfg" "tailscale.com/types/ipproto" + "tailscale.com/types/views" ) //go:generate go run tailscale.com/cmd/cloner --type=Match,CapMatch @@ -65,7 +66,7 @@ type CapMatch struct { // Match matches packets from any IP address in Srcs to any ip:port in // Dsts. type Match struct { - IPProto []ipproto.Proto // required set (no default value at this layer) + IPProto views.Slice[ipproto.Proto] // required set (no default value at this layer) Srcs []netip.Prefix SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs Dsts []NetPortRange // optional, if Srcs match diff --git a/wgengine/filter/filtertype/filtertype_clone.go b/wgengine/filter/filtertype/filtertype_clone.go index 056e2ee09..122f1bbe7 100644 --- a/wgengine/filter/filtertype/filtertype_clone.go +++ b/wgengine/filter/filtertype/filtertype_clone.go @@ -10,6 +10,7 @@ "tailscale.com/tailcfg" "tailscale.com/types/ipproto" + "tailscale.com/types/views" ) // Clone makes a deep copy of Match. @@ -20,7 +21,7 @@ func (src *Match) Clone() *Match { } dst := new(Match) *dst = *src - dst.IPProto = append(src.IPProto[:0:0], src.IPProto...) + dst.IPProto = src.IPProto dst.Srcs = append(src.Srcs[:0:0], src.Srcs...) dst.Dsts = append(src.Dsts[:0:0], src.Dsts...) if src.Caps != nil { @@ -34,7 +35,7 @@ func (src *Match) Clone() *Match { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _MatchCloneNeedsRegeneration = Match(struct { - IPProto []ipproto.Proto + IPProto views.Slice[ipproto.Proto] Srcs []netip.Prefix SrcsContains func(netip.Addr) bool Dsts []NetPortRange diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 70c4a6d02..4d93979ea 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -4,9 +4,8 @@ package filter import ( - "slices" - "tailscale.com/net/packet" + "tailscale.com/types/views" "tailscale.com/wgengine/filter/filtertype" ) @@ -14,7 +13,7 @@ func (ms matches) match(q *packet.Parsed) bool { for _, m := range ms { - if !slices.Contains(m.IPProto, q.IPProto) { + if !views.SliceContains(m.IPProto, q.IPProto) { continue } if !m.SrcsContains(q.Src.Addr()) { @@ -52,7 +51,7 @@ func (ms matches) matchIPsOnly(q *packet.Parsed) bool { // ignored, as long as the match is for the entire uint16 port range. func (ms matches) matchProtoAndIPsOnlyIfAllPorts(q *packet.Parsed) bool { for _, m := range ms { - if !slices.Contains(m.IPProto, q.IPProto) { + if !views.SliceContains(m.IPProto, q.IPProto) { continue } if !m.SrcsContains(q.Src.Addr()) { diff --git a/wgengine/filter/tailcfg.go b/wgengine/filter/tailcfg.go index e52e7c6e1..be9c6f13b 100644 --- a/wgengine/filter/tailcfg.go +++ b/wgengine/filter/tailcfg.go @@ -23,6 +23,8 @@ ipproto.ICMPv6, } +var defaultProtosView = views.SliceOf(defaultProtos) + // MatchesFromFilterRules converts tailcfg FilterRules into Matches. // If an error is returned, the Matches result is still valid, // containing the rules that were successfully converted. @@ -41,14 +43,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { } if len(r.IPProto) == 0 { - m.IPProto = append([]ipproto.Proto(nil), defaultProtos...) + m.IPProto = defaultProtosView } else { - m.IPProto = make([]ipproto.Proto, 0, len(r.IPProto)) + filtered := make([]ipproto.Proto, 0, len(r.IPProto)) for _, n := range r.IPProto { if n >= 0 && n <= 0xff { - m.IPProto = append(m.IPProto, ipproto.Proto(n)) + filtered = append(filtered, ipproto.Proto(n)) } } + m.IPProto = views.SliceOf(filtered) } for i, s := range r.SrcIPs {