wgengine/filter: use NewContainsIPFunc for Srcs matches

NewContainsIPFunc returns a contains matcher optimized for its
input. Use that instead of what this did before, always doing a test
over each of a list of netip.Prefixes.

    goos: darwin
    goarch: arm64
    pkg: tailscale.com/wgengine/filter
                        │   before    │                after                │
                        │   sec/op    │   sec/op     vs base                │
    FilterMatch/file1-8   32.60n ± 1%   18.87n ± 1%  -42.12% (p=0.000 n=10)

Updates #12486

Change-Id: I8f902bc064effb431e5b46751115942104ff6531
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2024-06-15 18:20:17 -07:00 committed by Brad Fitzpatrick
parent e2c0d69c9c
commit 21ed31e33a
5 changed files with 55 additions and 42 deletions

View File

@ -16,10 +16,12 @@ import (
"tailscale.com/net/flowtrack" "tailscale.com/net/flowtrack"
"tailscale.com/net/netaddr" "tailscale.com/net/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstime/rate" "tailscale.com/tstime/rate"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/views"
"tailscale.com/util/mak" "tailscale.com/util/mak"
) )
@ -30,12 +32,12 @@ type Filter struct {
// this node. All packets coming in over tailscale must have a // this node. All packets coming in over tailscale must have a
// destination within local, regardless of the policy filter // destination within local, regardless of the policy filter
// below. // below.
local *netipx.IPSet local func(netip.Addr) bool
// logIPs is the set of IPs that are allowed to appear in flow // logIPs is the set of IPs that are allowed to appear in flow
// logs. If a packet is to or from an IP not in logIPs, it will // logs. If a packet is to or from an IP not in logIPs, it will
// never be logged. // never be logged.
logIPs *netipx.IPSet logIPs func(netip.Addr) bool
// matches4 and matches6 are lists of match->action rules // matches4 and matches6 are lists of match->action rules
// applied to all packets arriving over tailscale // applied to all packets arriving over tailscale
@ -172,7 +174,7 @@ func NewShieldsUpFilter(localNets *netipx.IPSet, logIPs *netipx.IPSet, shareStat
// by matches. If shareStateWith is non-nil, the returned filter // by matches. If shareStateWith is non-nil, the returned filter
// shares state with the previous one, to enable changing rules at // shares state with the previous one, to enable changing rules at
// runtime without breaking existing stateful flows. // runtime without breaking existing stateful flows.
func New(matches []Match, localNets *netipx.IPSet, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter { func New(matches []Match, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter {
var state *filterState var state *filterState
if shareStateWith != nil { if shareStateWith != nil {
state = shareStateWith.state state = shareStateWith.state
@ -181,14 +183,22 @@ func New(matches []Match, localNets *netipx.IPSet, logIPs *netipx.IPSet, shareSt
lru: &flowtrack.Cache[struct{}]{MaxEntries: lruMax}, lru: &flowtrack.Cache[struct{}]{MaxEntries: lruMax},
} }
} }
containsFunc := func(s *netipx.IPSet) func(netip.Addr) bool {
if s == nil {
return tsaddr.FalseContainsIPFunc()
}
return tsaddr.NewContainsIPFunc(views.SliceOf(s.Prefixes()))
}
f := &Filter{ f := &Filter{
logf: logf, logf: logf,
matches4: matchesFamily(matches, netip.Addr.Is4), matches4: matchesFamily(matches, netip.Addr.Is4),
matches6: matchesFamily(matches, netip.Addr.Is6), matches6: matchesFamily(matches, netip.Addr.Is6),
cap4: capMatchesFunc(matches, netip.Addr.Is4), cap4: capMatchesFunc(matches, netip.Addr.Is4),
cap6: capMatchesFunc(matches, netip.Addr.Is6), cap6: capMatchesFunc(matches, netip.Addr.Is6),
local: localNets, local: containsFunc(localNets),
logIPs: logIPs, logIPs: containsFunc(logIPs),
state: state, state: state,
} }
return f return f
@ -206,12 +216,14 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches {
retm.Srcs = append(retm.Srcs, src) retm.Srcs = append(retm.Srcs, src)
} }
} }
for _, dst := range m.Dsts { for _, dst := range m.Dsts {
if keep(dst.Net.Addr()) { if keep(dst.Net.Addr()) {
retm.Dsts = append(retm.Dsts, dst) retm.Dsts = append(retm.Dsts, dst)
} }
} }
if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 { if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 {
retm.SrcsContains = tsaddr.NewContainsIPFunc(views.SliceOf(retm.Srcs))
ret = append(ret, retm) ret = append(ret, retm)
} }
} }
@ -233,6 +245,7 @@ func capMatchesFunc(ms matches, keep func(netip.Addr) bool) matches {
} }
} }
if len(retm.Srcs) > 0 { if len(retm.Srcs) > 0 {
retm.SrcsContains = tsaddr.NewContainsIPFunc(views.SliceOf(retm.Srcs))
ret = append(ret, retm) ret = append(ret, retm)
} }
} }
@ -268,7 +281,7 @@ func init() {
} }
func (f *Filter) logRateLimit(runflags RunFlags, q *packet.Parsed, dir direction, r Response, why string) { func (f *Filter) logRateLimit(runflags RunFlags, q *packet.Parsed, dir direction, r Response, why string) {
if !f.loggingAllowed(q) { if runflags == 0 || !f.loggingAllowed(q) {
return return
} }
@ -345,7 +358,7 @@ func (f *Filter) CapsWithValues(srcIP, dstIP netip.Addr) tailcfg.PeerCapMap {
} }
var out tailcfg.PeerCapMap var out tailcfg.PeerCapMap
for _, m := range mm { for _, m := range mm {
if !ipInList(srcIP, m.Srcs) { if !m.SrcsContains(srcIP) {
continue continue
} }
for _, cm := range m.Caps { for _, cm := range m.Caps {
@ -418,7 +431,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
// A compromised peer could try to send us packets for // A compromised peer could try to send us packets for
// destinations we didn't explicitly advertise. This check is to // destinations we didn't explicitly advertise. This check is to
// prevent that. // prevent that.
if !f.local.Contains(q.Dst.Addr()) { if !f.local(q.Dst.Addr()) {
return Drop, "destination not allowed" return Drop, "destination not allowed"
} }
@ -478,7 +491,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
// A compromised peer could try to send us packets for // A compromised peer could try to send us packets for
// destinations we didn't explicitly advertise. This check is to // destinations we didn't explicitly advertise. This check is to
// prevent that. // prevent that.
if !f.local.Contains(q.Dst.Addr()) { if !f.local(q.Dst.Addr()) {
return Drop, "destination not allowed" return Drop, "destination not allowed"
} }
@ -604,7 +617,7 @@ func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response {
// loggingAllowed reports whether p can appear in logs at all. // loggingAllowed reports whether p can appear in logs at all.
func (f *Filter) loggingAllowed(p *packet.Parsed) bool { func (f *Filter) loggingAllowed(p *packet.Parsed) bool {
return f.logIPs.Contains(p.Src.Addr()) && f.logIPs.Contains(p.Dst.Addr()) return f.logIPs(p.Src.Addr()) && f.logIPs(p.Dst.Addr())
} }
// omitDropLogging reports whether packet p, which has already been // omitDropLogging reports whether packet p, which has already been

View File

@ -34,10 +34,11 @@ func (src *Match) Clone() *Match {
// A compilation failure here means this code must be regenerated, with the command at the top of this file. // A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _MatchCloneNeedsRegeneration = Match(struct { var _MatchCloneNeedsRegeneration = Match(struct {
IPProto []ipproto.Proto IPProto []ipproto.Proto
Srcs []netip.Prefix Srcs []netip.Prefix
Dsts []NetPortRange SrcsContains func(netip.Addr) bool
Caps []CapMatch Dsts []NetPortRange
Caps []CapMatch
}{}) }{})
// Clone makes a deep copy of CapMatch. // Clone makes a deep copy of CapMatch.

View File

@ -16,6 +16,7 @@ import (
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go4.org/netipx" "go4.org/netipx"
xmaps "golang.org/x/exp/maps" xmaps "golang.org/x/exp/maps"
"tailscale.com/net/packet" "tailscale.com/net/packet"
@ -25,6 +26,7 @@ import (
"tailscale.com/tstime/rate" "tailscale.com/tstime/rate"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/views"
"tailscale.com/util/must" "tailscale.com/util/must"
) )
@ -40,9 +42,10 @@ func m(srcs []netip.Prefix, dsts []NetPortRange, protos ...ipproto.Proto) Match
protos = defaultProtos protos = defaultProtos
} }
return Match{ return Match{
IPProto: protos, IPProto: protos,
Srcs: srcs, Srcs: srcs,
Dsts: dsts, SrcsContains: tsaddr.NewContainsIPFunc(views.SliceOf(srcs)),
Dsts: dsts,
} }
} }
@ -436,11 +439,11 @@ func TestLoggingPrivacy(t *testing.T) {
logged = true logged = true
} }
var logB netipx.IPSetBuilder
logB.AddPrefix(netip.MustParsePrefix("100.64.0.0/10"))
logB.AddPrefix(tsaddr.TailscaleULARange())
f := newFilter(logf) f := newFilter(logf)
f.logIPs, _ = logB.IPSet() f.logIPs = tsaddr.NewContainsIPFunc(views.SliceOf([]netip.Prefix{
tsaddr.CGNATRange(),
tsaddr.TailscaleULARange(),
}))
var ( var (
ts4 = netip.AddrPortFrom(tsaddr.CGNATRange().Addr().Next(), 1234) ts4 = netip.AddrPortFrom(tsaddr.CGNATRange().Addr().Next(), 1234)
@ -820,11 +823,12 @@ func TestMatchesFromFilterRules(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cmpOpts := []cmp.Option{
compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b }) cmp.Comparer(func(a, b netip.Addr) bool { return a == b }),
compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }) cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }),
cmpopts.IgnoreFields(Match{}, ".SrcsContains"),
if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" { }
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
t.Errorf("wrong (-got+want)\n%s", diff) t.Errorf("wrong (-got+want)\n%s", diff)
} }
}) })

View File

@ -66,10 +66,11 @@ type CapMatch struct {
// Match matches packets from any IP address in Srcs to any ip:port in // Match matches packets from any IP address in Srcs to any ip:port in
// Dsts. // Dsts.
type Match struct { type Match struct {
IPProto []ipproto.Proto // required set (no default value at this layer) IPProto []ipproto.Proto // required set (no default value at this layer)
Srcs []netip.Prefix Srcs []netip.Prefix
Dsts []NetPortRange // optional, if Srcs match SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs
Caps []CapMatch // optional, if Srcs match Dsts []NetPortRange // optional, if Srcs match
Caps []CapMatch // optional, if Srcs match
} }
func (m Match) String() string { func (m Match) String() string {
@ -104,7 +105,7 @@ func (ms matches) match(q *packet.Parsed) bool {
if !slices.Contains(m.IPProto, q.IPProto) { if !slices.Contains(m.IPProto, q.IPProto) {
continue continue
} }
if !ipInList(q.Src.Addr(), m.Srcs) { if !m.SrcsContains(q.Src.Addr()) {
continue continue
} }
for _, dst := range m.Dsts { for _, dst := range m.Dsts {
@ -122,7 +123,7 @@ func (ms matches) match(q *packet.Parsed) bool {
func (ms matches) matchIPsOnly(q *packet.Parsed) bool { func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
for _, m := range ms { for _, m := range ms {
if !ipInList(q.Src.Addr(), m.Srcs) { if !m.SrcsContains(q.Src.Addr()) {
continue continue
} }
for _, dst := range m.Dsts { for _, dst := range m.Dsts {
@ -142,7 +143,7 @@ func (ms matches) matchProtoAndIPsOnlyIfAllPorts(q *packet.Parsed) bool {
if !slices.Contains(m.IPProto, q.IPProto) { if !slices.Contains(m.IPProto, q.IPProto) {
continue continue
} }
if !ipInList(q.Src.Addr(), m.Srcs) { if !m.SrcsContains(q.Src.Addr()) {
continue continue
} }
for _, dst := range m.Dsts { for _, dst := range m.Dsts {
@ -156,12 +157,3 @@ func (ms matches) matchProtoAndIPsOnlyIfAllPorts(q *packet.Parsed) bool {
} }
return false return false
} }
func ipInList(ip netip.Addr, netlist []netip.Prefix) bool {
for _, net := range netlist {
if net.Contains(ip) {
return true
}
}
return false
}

View File

@ -10,8 +10,10 @@ import (
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/net/netaddr" "tailscale.com/net/netaddr"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/views"
) )
var defaultProtos = []ipproto.Proto{ var defaultProtos = []ipproto.Proto{
@ -61,6 +63,7 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
} }
m.Srcs = append(m.Srcs, nets...) m.Srcs = append(m.Srcs, nets...)
} }
m.SrcsContains = tsaddr.NewContainsIPFunc(views.SliceOf(m.Srcs))
for _, d := range r.DstPorts { for _, d := range r.DstPorts {
nets, err := parseIPSet(d.IP, d.Bits) nets, err := parseIPSet(d.IP, d.Bits)