wgengine/filter: support FilterRules matching on srcIP node caps [capver 100]

See #12542 for background.

Updates #12542

Change-Id: Ida312f700affc00d17681dc7551ee9672eeb1789
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2024-06-18 13:44:12 -07:00 committed by Maisem Ali
parent 07063bc5c7
commit 5ec01bf3ce
9 changed files with 212 additions and 56 deletions

View File

@ -250,9 +250,9 @@ type LocalBackend struct {
// delta node mutations as they come in (with mu held). The map values can // delta node mutations as they come in (with mu held). The map values can
// be given out to callers, but the map itself must not escape the LocalBackend. // be given out to callers, but the map itself must not escape the LocalBackend.
peers map[tailcfg.NodeID]tailcfg.NodeView peers map[tailcfg.NodeID]tailcfg.NodeView
nodeByAddr map[netip.Addr]tailcfg.NodeID nodeByAddr map[netip.Addr]tailcfg.NodeID // by Node.Addresses only (not subnet routes)
nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil
activeLogin string // last logged LoginName from netMap activeLogin string // last logged LoginName from netMap
engineStatus ipn.EngineStatus engineStatus ipn.EngineStatus
endpoints []tailcfg.Endpoint endpoints []tailcfg.Endpoint
blocked bool blocked bool
@ -2021,7 +2021,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P
b.setFilter(filter.NewShieldsUpFilter(localNets, logNets, oldFilter, b.logf)) b.setFilter(filter.NewShieldsUpFilter(localNets, logNets, oldFilter, b.logf))
} else { } else {
b.logf("[v1] netmap packet filter: %v filters", len(packetFilter)) b.logf("[v1] netmap packet filter: %v filters", len(packetFilter))
b.setFilter(filter.New(packetFilter, localNets, logNets, oldFilter, b.logf)) b.setFilter(filter.New(packetFilter, b.srcIPHasCapForFilter, localNets, logNets, oldFilter, b.logf))
} }
// The filter for a jailed node is the exact same as a ShieldsUp filter. // The filter for a jailed node is the exact same as a ShieldsUp filter.
oldJailedFilter := b.e.GetJailedFilter() oldJailedFilter := b.e.GetJailedFilter()
@ -6839,3 +6839,28 @@ func (b *LocalBackend) startAutoUpdate(logPrefix string) (retErr error) {
}() }()
return nil return nil
} }
// srcIPHasCapForFilter is called by the packet filter when evaluating firewall
// rules that require a source IP to have a certain node capability.
//
// TODO(bradfitz): optimize this later if/when it matters.
func (b *LocalBackend) srcIPHasCapForFilter(srcIP netip.Addr, cap tailcfg.NodeCapability) bool {
if cap == "" {
// Shouldn't happen, but just in case.
// But the empty cap also shouldn't be found in Node.CapMap.
return false
}
b.mu.Lock()
defer b.mu.Unlock()
nodeID, ok := b.nodeByAddr[srcIP]
if !ok {
return false
}
n, ok := b.peers[nodeID]
if !ok {
return false
}
return n.HasCap(cap)
}

View File

@ -168,7 +168,7 @@ func setfilter(logf logger.Logf, tun *Wrapper) {
var sb netipx.IPSetBuilder var sb netipx.IPSetBuilder
sb.AddPrefix(netip.MustParsePrefix("1.2.0.0/16")) sb.AddPrefix(netip.MustParsePrefix("1.2.0.0/16"))
ipSet, _ := sb.IPSet() ipSet, _ := sb.IPSet()
tun.SetFilter(filter.New(matches, ipSet, ipSet, nil, logf)) tun.SetFilter(filter.New(matches, nil, ipSet, ipSet, nil, logf))
} }
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) { func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) {

View File

@ -140,7 +140,8 @@
// - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers // - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers
// - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information // - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information
// - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT // - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT
const CurrentCapabilityVersion CapabilityVersion = 99 // - 100: 2024-06-18: Client supports filtertype.Match.SrcCaps (issue #12542)
const CurrentCapabilityVersion CapabilityVersion = 100
type StableID string type StableID string
@ -1480,6 +1481,7 @@ type FilterRule struct {
// * the string "*" to match everything (both IPv4 & IPv6) // * the string "*" to match everything (both IPv4 & IPv6)
// * a CIDR (e.g. "192.168.0.0/16") // * a CIDR (e.g. "192.168.0.0/16")
// * a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") // * a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
// * a string "cap:<capability>" with NodeCapMap cap name
SrcIPs []string SrcIPs []string
// SrcBits is deprecated; it was the old way to specify a CIDR // SrcBits is deprecated; it was the old way to specify a CIDR

View File

@ -42,6 +42,10 @@ type Filter struct {
logIPs4 func(netip.Addr) bool logIPs4 func(netip.Addr) bool
logIPs6 func(netip.Addr) bool logIPs6 func(netip.Addr) bool
// srcIPHasCap optionally specifies a function that reports
// whether a given source IP address has a given capability.
srcIPHasCap CapTestFunc
// matches4 and matches6 are lists of match->action rules // matches4 and matches6 are lists of match->action rules
// applied to all packets arriving over tailscale // applied to all packets arriving over tailscale
// tunnels. Matches are checked in order, and processing stops // tunnels. Matches are checked in order, and processing stops
@ -157,12 +161,12 @@ func NewAllowAllForTest(logf logger.Logf) *Filter {
sb.AddPrefix(any4) sb.AddPrefix(any4)
sb.AddPrefix(any6) sb.AddPrefix(any6)
ipSet, _ := sb.IPSet() ipSet, _ := sb.IPSet()
return New(ms, ipSet, ipSet, nil, logf) return New(ms, nil, ipSet, ipSet, nil, logf)
} }
// NewAllowNone returns a packet filter that rejects everything. // NewAllowNone returns a packet filter that rejects everything.
func NewAllowNone(logf logger.Logf, logIPs *netipx.IPSet) *Filter { func NewAllowNone(logf logger.Logf, logIPs *netipx.IPSet) *Filter {
return New(nil, &netipx.IPSet{}, logIPs, nil, logf) return New(nil, nil, &netipx.IPSet{}, logIPs, nil, logf)
} }
// NewShieldsUpFilter returns a packet filter that rejects incoming connections. // NewShieldsUpFilter returns a packet filter that rejects incoming connections.
@ -174,17 +178,20 @@ func NewShieldsUpFilter(localNets *netipx.IPSet, logIPs *netipx.IPSet, shareStat
if shareStateWith != nil && !shareStateWith.shieldsUp { if shareStateWith != nil && !shareStateWith.shieldsUp {
shareStateWith = nil shareStateWith = nil
} }
f := New(nil, localNets, logIPs, shareStateWith, logf) f := New(nil, nil, localNets, logIPs, shareStateWith, logf)
f.shieldsUp = true f.shieldsUp = true
return f return f
} }
// New creates a new packet filter. The filter enforces that incoming // New creates a new packet filter. The filter enforces that incoming packets
// packets must be destined to an IP in localNets, and must be allowed // must be destined to an IP in localNets, and must be allowed by matches.
// by matches. If shareStateWith is non-nil, the returned filter // The optional capTest func is used to evaluate a Match that uses capabilities.
// shares state with the previous one, to enable changing rules at // If nil, such matches will always fail.
// runtime without breaking existing stateful flows. //
func New(matches []Match, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter { // If shareStateWith is non-nil, the returned filter shares state with the
// previous one, to enable changing rules at runtime without breaking existing
// stateful flows.
func New(matches []Match, capTest CapTestFunc, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter {
var state *filterState var state *filterState
if shareStateWith != nil { if shareStateWith != nil {
state = shareStateWith.state state = shareStateWith.state
@ -229,6 +236,7 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches {
for _, m := range ms { for _, m := range ms {
var retm Match var retm Match
retm.IPProto = m.IPProto retm.IPProto = m.IPProto
retm.SrcCaps = m.SrcCaps
for _, src := range m.Srcs { for _, src := range m.Srcs {
if keep(src.Addr()) { if keep(src.Addr()) {
retm.Srcs = append(retm.Srcs, src) retm.Srcs = append(retm.Srcs, src)
@ -240,7 +248,7 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches {
retm.Dsts = append(retm.Dsts, dst) retm.Dsts = append(retm.Dsts, dst)
} }
} }
if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 { if (len(retm.Srcs) > 0 || len(retm.SrcCaps) > 0) && len(retm.Dsts) > 0 {
retm.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(retm.Srcs)) retm.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(retm.Srcs))
ret = append(ret, retm) ret = append(ret, retm)
} }
@ -462,7 +470,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
// related to an existing ICMP-Echo, TCP, or UDP // related to an existing ICMP-Echo, TCP, or UDP
// session. // session.
return Accept, "icmp response ok" return Accept, "icmp response ok"
} else if f.matches4.matchIPsOnly(q) { } else if f.matches4.matchIPsOnly(q, f.srcIPHasCap) {
// If any port is open to an IP, allow ICMP to it. // If any port is open to an IP, allow ICMP to it.
return Accept, "icmp ok" return Accept, "icmp ok"
} }
@ -478,7 +486,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
if !q.IsTCPSyn() { if !q.IsTCPSyn() {
return Accept, "tcp non-syn" return Accept, "tcp non-syn"
} }
if f.matches4.match(q) { if f.matches4.match(q, f.srcIPHasCap) {
return Accept, "tcp ok" return Accept, "tcp ok"
} }
case ipproto.UDP, ipproto.SCTP: case ipproto.UDP, ipproto.SCTP:
@ -491,7 +499,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
if ok { if ok {
return Accept, "cached" return Accept, "cached"
} }
if f.matches4.match(q) { if f.matches4.match(q, f.srcIPHasCap) {
return Accept, "ok" return Accept, "ok"
} }
case ipproto.TSMP: case ipproto.TSMP:
@ -522,7 +530,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
// related to an existing ICMP-Echo, TCP, or UDP // related to an existing ICMP-Echo, TCP, or UDP
// session. // session.
return Accept, "icmp response ok" return Accept, "icmp response ok"
} else if f.matches6.matchIPsOnly(q) { } else if f.matches6.matchIPsOnly(q, f.srcIPHasCap) {
// If any port is open to an IP, allow ICMP to it. // If any port is open to an IP, allow ICMP to it.
return Accept, "icmp ok" return Accept, "icmp ok"
} }
@ -538,7 +546,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
if q.IPProto == ipproto.TCP && !q.IsTCPSyn() { if q.IPProto == ipproto.TCP && !q.IsTCPSyn() {
return Accept, "tcp non-syn" return Accept, "tcp non-syn"
} }
if f.matches6.match(q) { if f.matches6.match(q, f.srcIPHasCap) {
return Accept, "tcp ok" return Accept, "tcp ok"
} }
case ipproto.UDP, ipproto.SCTP: case ipproto.UDP, ipproto.SCTP:
@ -551,7 +559,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
if ok { if ok {
return Accept, "cached" return Accept, "cached"
} }
if f.matches6.match(q) { if f.matches6.match(q, f.srcIPHasCap) {
return Accept, "ok" return Accept, "ok"
} }
case ipproto.TSMP: case ipproto.TSMP:

View File

@ -40,14 +40,31 @@
testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy
) )
func m(srcs []netip.Prefix, dsts []NetPortRange, protos ...ipproto.Proto) Match { // m returnns a Match with the given srcs and dsts.
if protos == nil { //
// opts can be ipproto.Proto values (if none, defaultProtos is used)
// or tailcfg.NodeCapability values. Other values panic.
func m(srcs []netip.Prefix, dsts []NetPortRange, opts ...any) Match {
var protos []ipproto.Proto
var caps []tailcfg.NodeCapability
for _, o := range opts {
switch o := o.(type) {
case ipproto.Proto:
protos = append(protos, o)
case tailcfg.NodeCapability:
caps = append(caps, o)
default:
panic(fmt.Sprintf("unknown option type %T", o))
}
}
if len(protos) == 0 {
protos = defaultProtos protos = defaultProtos
} }
return Match{ return Match{
IPProto: views.SliceOf(protos), IPProto: views.SliceOf(protos),
Srcs: srcs, Srcs: srcs,
SrcsContains: ipset.NewContainsIPFunc(views.SliceOf(srcs)), SrcsContains: ipset.NewContainsIPFunc(views.SliceOf(srcs)),
SrcCaps: caps,
Dsts: dsts, Dsts: dsts,
} }
} }
@ -65,6 +82,7 @@ func newFilter(logf logger.Logf) *Filter {
m(nets("::/0"), netports("::/0:443")), m(nets("::/0"), netports("::/0:443")),
m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto), m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
m(nets("::/0"), netports("::/0:*"), testAllowedProto), m(nets("::/0"), netports("::/0:*"), testAllowedProto),
m(nil, netports("1.2.3.4:22"), tailcfg.NodeCapability("cap-hit-1234-ssh")),
} }
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
@ -79,11 +97,17 @@ func newFilter(logf logger.Logf) *Filter {
localNetsSet, _ := localNets.IPSet() localNetsSet, _ := localNets.IPSet()
logBSet, _ := logB.IPSet() logBSet, _ := logB.IPSet()
return New(matches, localNetsSet, logBSet, nil, logf) return New(matches, nil, localNetsSet, logBSet, nil, logf)
} }
func TestFilter(t *testing.T) { func TestFilter(t *testing.T) {
acl := newFilter(t.Logf) filt := newFilter(t.Logf)
ipWithCap := netip.MustParseAddr("10.0.0.1")
ipWithoutCap := netip.MustParseAddr("10.0.0.2")
filt.srcIPHasCap = func(ip netip.Addr, cap tailcfg.NodeCapability) bool {
return cap == "cap-hit-1234-ssh" && ip == ipWithCap
}
type InOut struct { type InOut struct {
want Response want Response
@ -139,21 +163,27 @@ type InOut struct {
{Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)}, {Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)},
{Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)}, {Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
{Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)}, {Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)},
// Test use of a node capability to grant access.
// 10.0.0.1 has the capability; 10.0.0.2 does not (see srcIPHasCap at top of func)
{Accept, parsed(ipproto.TCP, ipWithCap.String(), "1.2.3.4", 30000, 22)},
{Drop, parsed(ipproto.TCP, ipWithoutCap.String(), "1.2.3.4", 30000, 22)},
} }
for i, test := range tests { for i, test := range tests {
aclFunc := acl.runIn4 aclFunc := filt.runIn4
if test.p.IPVersion == 6 { if test.p.IPVersion == 6 {
aclFunc = acl.runIn6 aclFunc = filt.runIn6
} }
if got, why := aclFunc(&test.p); test.want != got { if got, why := aclFunc(&test.p); test.want != got {
t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p) t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
continue
} }
if test.p.IPProto == ipproto.TCP { if test.p.IPProto == ipproto.TCP {
var got Response var got Response
if test.p.IPVersion == 4 { if test.p.IPVersion == 4 {
got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port()) got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
} else { } else {
got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port()) got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
} }
if test.want != got { if test.want != got {
t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p) t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p)
@ -165,7 +195,7 @@ type InOut struct {
} }
} }
// Update UDP state // Update UDP state
_, _ = acl.runOut(&test.p) _, _ = filt.runOut(&test.p)
} }
} }
@ -264,13 +294,16 @@ func TestParseIPSet(t *testing.T) {
{"*", pfx("0.0.0.0/0", "::/0"), ""}, {"*", pfx("0.0.0.0/0", "::/0"), ""},
} }
for _, tt := range tests { for _, tt := range tests {
got, err := parseIPSet(tt.host) got, gotCap, err := parseIPSet(tt.host)
if err != nil { if err != nil {
if err.Error() == tt.wantErr { if err.Error() == tt.wantErr {
continue continue
} }
t.Errorf("parseIPSet(%q) error: %v; want error %q", tt.host, err, tt.wantErr) t.Errorf("parseIPSet(%q) error: %v; want error %q", tt.host, err, tt.wantErr)
} }
if gotCap != "" {
t.Errorf("parseIPSet(%q) cap: %q; want empty", tt.host, gotCap)
}
compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b }) compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b })
compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }) compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b })
if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" { if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" {
@ -278,6 +311,27 @@ func TestParseIPSet(t *testing.T) {
continue continue
} }
} }
capTests := []struct {
in string
want tailcfg.NodeCapability
}{
{"cap:foo", "foo"},
{"cap:people-in-8.8.8.0/24", "people-in-8.8.8.0/24"}, // test precedence of "/" search
}
for _, tt := range capTests {
pfxes, gotCap, err := parseIPSet(tt.in)
if err != nil {
t.Errorf("parseIPSet(%q) error: %v; want no error", tt.in, err)
continue
}
if gotCap != tt.want {
t.Errorf("parseIPSet(%q) cap: %q; want %q", tt.in, gotCap, tt.want)
}
if len(pfxes) != 0 {
t.Errorf("parseIPSet(%q) pfxes: %v; want empty", tt.in, pfxes)
}
}
} }
func BenchmarkFilter(b *testing.B) { func BenchmarkFilter(b *testing.B) {
@ -904,7 +958,7 @@ func TestPeerCaps(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
filt := New(mm, nil, nil, nil, t.Logf) filt := New(mm, nil, nil, nil, nil, t.Logf)
tests := []struct { tests := []struct {
name string name string
src, dst string // IP src, dst string // IP
@ -1037,7 +1091,7 @@ func benchmarkFile(b *testing.B, file string, opt benchOpt) {
logIPs.AddPrefix(tsaddr.CGNATRange()) logIPs.AddPrefix(tsaddr.CGNATRange())
logIPs.AddPrefix(tsaddr.TailscaleULARange()) logIPs.AddPrefix(tsaddr.TailscaleULARange())
f := New(matches, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard) f := New(matches, nil, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard)
var srcIP, dstIP netip.Addr var srcIP, dstIP netip.Addr
if opt.v4 { if opt.v4 {
srcIP = netip.MustParseAddr("1.2.3.4") srcIP = netip.MustParseAddr("1.2.3.4")

View File

@ -66,11 +66,28 @@ type CapMatch struct {
// Match matches packets from any IP address in Srcs to any ip:port in // Match matches packets from any IP address in Srcs to any ip:port in
// Dsts. // Dsts.
type Match struct { type Match struct {
IPProto views.Slice[ipproto.Proto] // required set (no default value at this layer) // IPProto is the set of IP protocol numbers for which this match applies.
Srcs []netip.Prefix // It is required. There is no default value at this layer.
// If empty, it doesn't match.
IPProto views.Slice[ipproto.Proto]
// Srcs is the set of source IP prefixes for which this match applies. A
// Match can match by either its source IP address being in Srcs (which
// SrcsContains tests) or if the source IP is of a known peer self address
// that contains a NodeCapability listed in SrcCaps.
Srcs []netip.Prefix
// SrcsContains is an optimized function that reports whether Addr is in
// Srcs, using the best search method for the size and shape of Srcs.
SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs
Dsts []NetPortRange // optional, if Srcs match
Caps []CapMatch // optional, if Srcs match // SrcCaps is an alternative way to match packets. If the peer's source IP
// has one of these capabilities, it's also permitted. The peers are only
// looked up by their self address (Node.Addresses) and not by subnet routes
// they advertise.
SrcCaps []tailcfg.NodeCapability
Dsts []NetPortRange // optional, if source matches
Caps []CapMatch // optional, if source match
} }
func (m Match) String() string { func (m Match) String() string {

View File

@ -23,6 +23,7 @@ func (src *Match) Clone() *Match {
*dst = *src *dst = *src
dst.IPProto = src.IPProto dst.IPProto = src.IPProto
dst.Srcs = append(src.Srcs[:0:0], src.Srcs...) dst.Srcs = append(src.Srcs[:0:0], src.Srcs...)
dst.SrcCaps = append(src.SrcCaps[:0:0], src.SrcCaps...)
dst.Dsts = append(src.Dsts[:0:0], src.Dsts...) dst.Dsts = append(src.Dsts[:0:0], src.Dsts...)
if src.Caps != nil { if src.Caps != nil {
dst.Caps = make([]CapMatch, len(src.Caps)) dst.Caps = make([]CapMatch, len(src.Caps))
@ -38,6 +39,7 @@ func (src *Match) Clone() *Match {
IPProto views.Slice[ipproto.Proto] IPProto views.Slice[ipproto.Proto]
Srcs []netip.Prefix Srcs []netip.Prefix
SrcsContains func(netip.Addr) bool SrcsContains func(netip.Addr) bool
SrcCaps []tailcfg.NodeCapability
Dsts []NetPortRange Dsts []NetPortRange
Caps []CapMatch Caps []CapMatch
}{}) }{})

View File

@ -4,19 +4,23 @@
package filter package filter
import ( import (
"net/netip"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/tailcfg"
"tailscale.com/types/views" "tailscale.com/types/views"
"tailscale.com/wgengine/filter/filtertype" "tailscale.com/wgengine/filter/filtertype"
) )
type matches []filtertype.Match type matches []filtertype.Match
func (ms matches) match(q *packet.Parsed) bool { func (ms matches) match(q *packet.Parsed, hasCap CapTestFunc) bool {
for _, m := range ms { for i := range ms {
m := &ms[i]
if !views.SliceContains(m.IPProto, q.IPProto) { if !views.SliceContains(m.IPProto, q.IPProto) {
continue continue
} }
if !m.SrcsContains(q.Src.Addr()) { if !srcMatches(m, q.Src.Addr(), hasCap) {
continue continue
} }
for _, dst := range m.Dsts { for _, dst := range m.Dsts {
@ -32,9 +36,33 @@ func (ms matches) match(q *packet.Parsed) bool {
return false return false
} }
func (ms matches) matchIPsOnly(q *packet.Parsed) bool { // srcMatches reports whether srcAddr matche the src requirements in m, either
// by Srcs (using SrcsContains), or by the node having a capability listed
// in SrcCaps using the provided hasCap function.
func srcMatches(m *filtertype.Match, srcAddr netip.Addr, hasCap CapTestFunc) bool {
if m.SrcsContains(srcAddr) {
return true
}
if hasCap != nil {
for _, c := range m.SrcCaps {
if hasCap(srcAddr, c) {
return true
}
}
}
return false
}
// CapTestFunc is the function signature of a function that tests whether srcIP
// has a given capability.
//
// It it used in the fast path of evaluating filter rules so should be fast.
type CapTestFunc = func(srcIP netip.Addr, cap tailcfg.NodeCapability) bool
func (ms matches) matchIPsOnly(q *packet.Parsed, hasCap CapTestFunc) bool {
srcAddr := q.Src.Addr()
for _, m := range ms { for _, m := range ms {
if !m.SrcsContains(q.Src.Addr()) { if !m.SrcsContains(srcAddr) {
continue continue
} }
for _, dst := range m.Dsts { for _, dst := range m.Dsts {
@ -43,6 +71,15 @@ func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
} }
} }
} }
if hasCap != nil {
for _, m := range ms {
for _, c := range m.SrcCaps {
if hasCap(srcAddr, c) {
return true
}
}
}
}
return false return false
} }

View File

@ -58,12 +58,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
} }
for _, s := range r.SrcIPs { for _, s := range r.SrcIPs {
nets, err := parseIPSet(s) nets, cap, err := parseIPSet(s)
if err != nil && erracc == nil { if err != nil && erracc == nil {
erracc = err erracc = err
continue continue
} }
m.Srcs = append(m.Srcs, nets...) m.Srcs = append(m.Srcs, nets...)
if cap != "" {
m.SrcCaps = append(m.SrcCaps, cap)
}
} }
m.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(m.Srcs)) m.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(m.Srcs))
@ -71,11 +74,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
if d.Bits != nil { if d.Bits != nil {
return nil, fmt.Errorf("unexpected DstBits; control plane should not send this to this client version") return nil, fmt.Errorf("unexpected DstBits; control plane should not send this to this client version")
} }
nets, err := parseIPSet(d.IP) nets, cap, err := parseIPSet(d.IP)
if err != nil && erracc == nil { if err != nil && erracc == nil {
erracc = err erracc = err
continue continue
} }
if cap != "" {
erracc = fmt.Errorf("unexpected capability %q in DstPorts", cap)
continue
}
for _, net := range nets { for _, net := range nets {
m.Dsts = append(m.Dsts, NetPortRange{ m.Dsts = append(m.Dsts, NetPortRange{
Net: net, Net: net,
@ -120,48 +127,52 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
// - the string "*" to match everything (both IPv4 & IPv6) // - the string "*" to match everything (both IPv4 & IPv6)
// - a CIDR (e.g. "192.168.0.0/16") // - a CIDR (e.g. "192.168.0.0/16")
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") // - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
// - "cap:<peer-node-capability>" to match a peer node capability
// //
// TODO(bradfitz): make this return an IPSet and plumb that all // TODO(bradfitz): make this return an IPSet and plumb that all
// around, and ultimately use a new version of IPSet.ContainsFunc like // around, and ultimately use a new version of IPSet.ContainsFunc like
// Contains16Func that works in [16]byte address, so we we can match // Contains16Func that works in [16]byte address, so we we can match
// at runtime without allocating? // at runtime without allocating?
func parseIPSet(arg string) ([]netip.Prefix, error) { func parseIPSet(arg string) (prefixes []netip.Prefix, peerCap tailcfg.NodeCapability, err error) {
if arg == "*" { if arg == "*" {
// User explicitly requested wildcard. // User explicitly requested wildcard.
return []netip.Prefix{ return []netip.Prefix{
netip.PrefixFrom(zeroIP4, 0), netip.PrefixFrom(zeroIP4, 0),
netip.PrefixFrom(zeroIP6, 0), netip.PrefixFrom(zeroIP6, 0),
}, nil }, "", nil
}
if cap, ok := strings.CutPrefix(arg, "cap:"); ok {
return nil, tailcfg.NodeCapability(cap), nil
} }
if strings.Contains(arg, "/") { if strings.Contains(arg, "/") {
pfx, err := netip.ParsePrefix(arg) pfx, err := netip.ParsePrefix(arg)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
if pfx != pfx.Masked() { if pfx != pfx.Masked() {
return nil, fmt.Errorf("%v contains non-network bits set", pfx) return nil, "", fmt.Errorf("%v contains non-network bits set", pfx)
} }
return []netip.Prefix{pfx}, nil return []netip.Prefix{pfx}, "", nil
} }
if strings.Count(arg, "-") == 1 { if strings.Count(arg, "-") == 1 {
ip1s, ip2s, _ := strings.Cut(arg, "-") ip1s, ip2s, _ := strings.Cut(arg, "-")
ip1, err := netip.ParseAddr(ip1s) ip1, err := netip.ParseAddr(ip1s)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
ip2, err := netip.ParseAddr(ip2s) ip2, err := netip.ParseAddr(ip2s)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
r := netipx.IPRangeFrom(ip1, ip2) r := netipx.IPRangeFrom(ip1, ip2)
if !r.IsValid() { if !r.IsValid() {
return nil, fmt.Errorf("invalid IP range %q", arg) return nil, "", fmt.Errorf("invalid IP range %q", arg)
} }
return r.Prefixes(), nil return r.Prefixes(), "", nil
} }
ip, err := netip.ParseAddr(arg) ip, err := netip.ParseAddr(arg)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid IP address %q", arg) return nil, "", fmt.Errorf("invalid IP address %q", arg)
} }
return []netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, nil return []netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, "", nil
} }