mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
wgengine/filter: support FilterRules matching on srcIP node caps [capver 100]
See #12542 for background. Updates #12542 Change-Id: Ida312f700affc00d17681dc7551ee9672eeb1789 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
07063bc5c7
commit
5ec01bf3ce
@ -250,9 +250,9 @@ type LocalBackend struct {
|
|||||||
// delta node mutations as they come in (with mu held). The map values can
|
// delta node mutations as they come in (with mu held). The map values can
|
||||||
// be given out to callers, but the map itself must not escape the LocalBackend.
|
// be given out to callers, but the map itself must not escape the LocalBackend.
|
||||||
peers map[tailcfg.NodeID]tailcfg.NodeView
|
peers map[tailcfg.NodeID]tailcfg.NodeView
|
||||||
nodeByAddr map[netip.Addr]tailcfg.NodeID
|
nodeByAddr map[netip.Addr]tailcfg.NodeID // by Node.Addresses only (not subnet routes)
|
||||||
nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil
|
nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil
|
||||||
activeLogin string // last logged LoginName from netMap
|
activeLogin string // last logged LoginName from netMap
|
||||||
engineStatus ipn.EngineStatus
|
engineStatus ipn.EngineStatus
|
||||||
endpoints []tailcfg.Endpoint
|
endpoints []tailcfg.Endpoint
|
||||||
blocked bool
|
blocked bool
|
||||||
@ -2021,7 +2021,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P
|
|||||||
b.setFilter(filter.NewShieldsUpFilter(localNets, logNets, oldFilter, b.logf))
|
b.setFilter(filter.NewShieldsUpFilter(localNets, logNets, oldFilter, b.logf))
|
||||||
} else {
|
} else {
|
||||||
b.logf("[v1] netmap packet filter: %v filters", len(packetFilter))
|
b.logf("[v1] netmap packet filter: %v filters", len(packetFilter))
|
||||||
b.setFilter(filter.New(packetFilter, localNets, logNets, oldFilter, b.logf))
|
b.setFilter(filter.New(packetFilter, b.srcIPHasCapForFilter, localNets, logNets, oldFilter, b.logf))
|
||||||
}
|
}
|
||||||
// The filter for a jailed node is the exact same as a ShieldsUp filter.
|
// The filter for a jailed node is the exact same as a ShieldsUp filter.
|
||||||
oldJailedFilter := b.e.GetJailedFilter()
|
oldJailedFilter := b.e.GetJailedFilter()
|
||||||
@ -6839,3 +6839,28 @@ func (b *LocalBackend) startAutoUpdate(logPrefix string) (retErr error) {
|
|||||||
}()
|
}()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// srcIPHasCapForFilter is called by the packet filter when evaluating firewall
|
||||||
|
// rules that require a source IP to have a certain node capability.
|
||||||
|
//
|
||||||
|
// TODO(bradfitz): optimize this later if/when it matters.
|
||||||
|
func (b *LocalBackend) srcIPHasCapForFilter(srcIP netip.Addr, cap tailcfg.NodeCapability) bool {
|
||||||
|
if cap == "" {
|
||||||
|
// Shouldn't happen, but just in case.
|
||||||
|
// But the empty cap also shouldn't be found in Node.CapMap.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
|
nodeID, ok := b.nodeByAddr[srcIP]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
n, ok := b.peers[nodeID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return n.HasCap(cap)
|
||||||
|
}
|
||||||
|
@ -168,7 +168,7 @@ func setfilter(logf logger.Logf, tun *Wrapper) {
|
|||||||
var sb netipx.IPSetBuilder
|
var sb netipx.IPSetBuilder
|
||||||
sb.AddPrefix(netip.MustParsePrefix("1.2.0.0/16"))
|
sb.AddPrefix(netip.MustParsePrefix("1.2.0.0/16"))
|
||||||
ipSet, _ := sb.IPSet()
|
ipSet, _ := sb.IPSet()
|
||||||
tun.SetFilter(filter.New(matches, ipSet, ipSet, nil, logf))
|
tun.SetFilter(filter.New(matches, nil, ipSet, ipSet, nil, logf))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
|
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
|
||||||
|
@ -140,7 +140,8 @@
|
|||||||
// - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers
|
// - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers
|
||||||
// - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information
|
// - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information
|
||||||
// - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT
|
// - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT
|
||||||
const CurrentCapabilityVersion CapabilityVersion = 99
|
// - 100: 2024-06-18: Client supports filtertype.Match.SrcCaps (issue #12542)
|
||||||
|
const CurrentCapabilityVersion CapabilityVersion = 100
|
||||||
|
|
||||||
type StableID string
|
type StableID string
|
||||||
|
|
||||||
@ -1480,6 +1481,7 @@ type FilterRule struct {
|
|||||||
// * the string "*" to match everything (both IPv4 & IPv6)
|
// * the string "*" to match everything (both IPv4 & IPv6)
|
||||||
// * a CIDR (e.g. "192.168.0.0/16")
|
// * a CIDR (e.g. "192.168.0.0/16")
|
||||||
// * a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
// * a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
||||||
|
// * a string "cap:<capability>" with NodeCapMap cap name
|
||||||
SrcIPs []string
|
SrcIPs []string
|
||||||
|
|
||||||
// SrcBits is deprecated; it was the old way to specify a CIDR
|
// SrcBits is deprecated; it was the old way to specify a CIDR
|
||||||
|
@ -42,6 +42,10 @@ type Filter struct {
|
|||||||
logIPs4 func(netip.Addr) bool
|
logIPs4 func(netip.Addr) bool
|
||||||
logIPs6 func(netip.Addr) bool
|
logIPs6 func(netip.Addr) bool
|
||||||
|
|
||||||
|
// srcIPHasCap optionally specifies a function that reports
|
||||||
|
// whether a given source IP address has a given capability.
|
||||||
|
srcIPHasCap CapTestFunc
|
||||||
|
|
||||||
// matches4 and matches6 are lists of match->action rules
|
// matches4 and matches6 are lists of match->action rules
|
||||||
// applied to all packets arriving over tailscale
|
// applied to all packets arriving over tailscale
|
||||||
// tunnels. Matches are checked in order, and processing stops
|
// tunnels. Matches are checked in order, and processing stops
|
||||||
@ -157,12 +161,12 @@ func NewAllowAllForTest(logf logger.Logf) *Filter {
|
|||||||
sb.AddPrefix(any4)
|
sb.AddPrefix(any4)
|
||||||
sb.AddPrefix(any6)
|
sb.AddPrefix(any6)
|
||||||
ipSet, _ := sb.IPSet()
|
ipSet, _ := sb.IPSet()
|
||||||
return New(ms, ipSet, ipSet, nil, logf)
|
return New(ms, nil, ipSet, ipSet, nil, logf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAllowNone returns a packet filter that rejects everything.
|
// NewAllowNone returns a packet filter that rejects everything.
|
||||||
func NewAllowNone(logf logger.Logf, logIPs *netipx.IPSet) *Filter {
|
func NewAllowNone(logf logger.Logf, logIPs *netipx.IPSet) *Filter {
|
||||||
return New(nil, &netipx.IPSet{}, logIPs, nil, logf)
|
return New(nil, nil, &netipx.IPSet{}, logIPs, nil, logf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewShieldsUpFilter returns a packet filter that rejects incoming connections.
|
// NewShieldsUpFilter returns a packet filter that rejects incoming connections.
|
||||||
@ -174,17 +178,20 @@ func NewShieldsUpFilter(localNets *netipx.IPSet, logIPs *netipx.IPSet, shareStat
|
|||||||
if shareStateWith != nil && !shareStateWith.shieldsUp {
|
if shareStateWith != nil && !shareStateWith.shieldsUp {
|
||||||
shareStateWith = nil
|
shareStateWith = nil
|
||||||
}
|
}
|
||||||
f := New(nil, localNets, logIPs, shareStateWith, logf)
|
f := New(nil, nil, localNets, logIPs, shareStateWith, logf)
|
||||||
f.shieldsUp = true
|
f.shieldsUp = true
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new packet filter. The filter enforces that incoming
|
// New creates a new packet filter. The filter enforces that incoming packets
|
||||||
// packets must be destined to an IP in localNets, and must be allowed
|
// must be destined to an IP in localNets, and must be allowed by matches.
|
||||||
// by matches. If shareStateWith is non-nil, the returned filter
|
// The optional capTest func is used to evaluate a Match that uses capabilities.
|
||||||
// shares state with the previous one, to enable changing rules at
|
// If nil, such matches will always fail.
|
||||||
// runtime without breaking existing stateful flows.
|
//
|
||||||
func New(matches []Match, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter {
|
// If shareStateWith is non-nil, the returned filter shares state with the
|
||||||
|
// previous one, to enable changing rules at runtime without breaking existing
|
||||||
|
// stateful flows.
|
||||||
|
func New(matches []Match, capTest CapTestFunc, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter {
|
||||||
var state *filterState
|
var state *filterState
|
||||||
if shareStateWith != nil {
|
if shareStateWith != nil {
|
||||||
state = shareStateWith.state
|
state = shareStateWith.state
|
||||||
@ -229,6 +236,7 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches {
|
|||||||
for _, m := range ms {
|
for _, m := range ms {
|
||||||
var retm Match
|
var retm Match
|
||||||
retm.IPProto = m.IPProto
|
retm.IPProto = m.IPProto
|
||||||
|
retm.SrcCaps = m.SrcCaps
|
||||||
for _, src := range m.Srcs {
|
for _, src := range m.Srcs {
|
||||||
if keep(src.Addr()) {
|
if keep(src.Addr()) {
|
||||||
retm.Srcs = append(retm.Srcs, src)
|
retm.Srcs = append(retm.Srcs, src)
|
||||||
@ -240,7 +248,7 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches {
|
|||||||
retm.Dsts = append(retm.Dsts, dst)
|
retm.Dsts = append(retm.Dsts, dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 {
|
if (len(retm.Srcs) > 0 || len(retm.SrcCaps) > 0) && len(retm.Dsts) > 0 {
|
||||||
retm.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(retm.Srcs))
|
retm.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(retm.Srcs))
|
||||||
ret = append(ret, retm)
|
ret = append(ret, retm)
|
||||||
}
|
}
|
||||||
@ -462,7 +470,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
|
|||||||
// related to an existing ICMP-Echo, TCP, or UDP
|
// related to an existing ICMP-Echo, TCP, or UDP
|
||||||
// session.
|
// session.
|
||||||
return Accept, "icmp response ok"
|
return Accept, "icmp response ok"
|
||||||
} else if f.matches4.matchIPsOnly(q) {
|
} else if f.matches4.matchIPsOnly(q, f.srcIPHasCap) {
|
||||||
// If any port is open to an IP, allow ICMP to it.
|
// If any port is open to an IP, allow ICMP to it.
|
||||||
return Accept, "icmp ok"
|
return Accept, "icmp ok"
|
||||||
}
|
}
|
||||||
@ -478,7 +486,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
|
|||||||
if !q.IsTCPSyn() {
|
if !q.IsTCPSyn() {
|
||||||
return Accept, "tcp non-syn"
|
return Accept, "tcp non-syn"
|
||||||
}
|
}
|
||||||
if f.matches4.match(q) {
|
if f.matches4.match(q, f.srcIPHasCap) {
|
||||||
return Accept, "tcp ok"
|
return Accept, "tcp ok"
|
||||||
}
|
}
|
||||||
case ipproto.UDP, ipproto.SCTP:
|
case ipproto.UDP, ipproto.SCTP:
|
||||||
@ -491,7 +499,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
|
|||||||
if ok {
|
if ok {
|
||||||
return Accept, "cached"
|
return Accept, "cached"
|
||||||
}
|
}
|
||||||
if f.matches4.match(q) {
|
if f.matches4.match(q, f.srcIPHasCap) {
|
||||||
return Accept, "ok"
|
return Accept, "ok"
|
||||||
}
|
}
|
||||||
case ipproto.TSMP:
|
case ipproto.TSMP:
|
||||||
@ -522,7 +530,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
|
|||||||
// related to an existing ICMP-Echo, TCP, or UDP
|
// related to an existing ICMP-Echo, TCP, or UDP
|
||||||
// session.
|
// session.
|
||||||
return Accept, "icmp response ok"
|
return Accept, "icmp response ok"
|
||||||
} else if f.matches6.matchIPsOnly(q) {
|
} else if f.matches6.matchIPsOnly(q, f.srcIPHasCap) {
|
||||||
// If any port is open to an IP, allow ICMP to it.
|
// If any port is open to an IP, allow ICMP to it.
|
||||||
return Accept, "icmp ok"
|
return Accept, "icmp ok"
|
||||||
}
|
}
|
||||||
@ -538,7 +546,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
|
|||||||
if q.IPProto == ipproto.TCP && !q.IsTCPSyn() {
|
if q.IPProto == ipproto.TCP && !q.IsTCPSyn() {
|
||||||
return Accept, "tcp non-syn"
|
return Accept, "tcp non-syn"
|
||||||
}
|
}
|
||||||
if f.matches6.match(q) {
|
if f.matches6.match(q, f.srcIPHasCap) {
|
||||||
return Accept, "tcp ok"
|
return Accept, "tcp ok"
|
||||||
}
|
}
|
||||||
case ipproto.UDP, ipproto.SCTP:
|
case ipproto.UDP, ipproto.SCTP:
|
||||||
@ -551,7 +559,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
|
|||||||
if ok {
|
if ok {
|
||||||
return Accept, "cached"
|
return Accept, "cached"
|
||||||
}
|
}
|
||||||
if f.matches6.match(q) {
|
if f.matches6.match(q, f.srcIPHasCap) {
|
||||||
return Accept, "ok"
|
return Accept, "ok"
|
||||||
}
|
}
|
||||||
case ipproto.TSMP:
|
case ipproto.TSMP:
|
||||||
|
@ -40,14 +40,31 @@
|
|||||||
testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy
|
testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy
|
||||||
)
|
)
|
||||||
|
|
||||||
func m(srcs []netip.Prefix, dsts []NetPortRange, protos ...ipproto.Proto) Match {
|
// m returnns a Match with the given srcs and dsts.
|
||||||
if protos == nil {
|
//
|
||||||
|
// opts can be ipproto.Proto values (if none, defaultProtos is used)
|
||||||
|
// or tailcfg.NodeCapability values. Other values panic.
|
||||||
|
func m(srcs []netip.Prefix, dsts []NetPortRange, opts ...any) Match {
|
||||||
|
var protos []ipproto.Proto
|
||||||
|
var caps []tailcfg.NodeCapability
|
||||||
|
for _, o := range opts {
|
||||||
|
switch o := o.(type) {
|
||||||
|
case ipproto.Proto:
|
||||||
|
protos = append(protos, o)
|
||||||
|
case tailcfg.NodeCapability:
|
||||||
|
caps = append(caps, o)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unknown option type %T", o))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(protos) == 0 {
|
||||||
protos = defaultProtos
|
protos = defaultProtos
|
||||||
}
|
}
|
||||||
return Match{
|
return Match{
|
||||||
IPProto: views.SliceOf(protos),
|
IPProto: views.SliceOf(protos),
|
||||||
Srcs: srcs,
|
Srcs: srcs,
|
||||||
SrcsContains: ipset.NewContainsIPFunc(views.SliceOf(srcs)),
|
SrcsContains: ipset.NewContainsIPFunc(views.SliceOf(srcs)),
|
||||||
|
SrcCaps: caps,
|
||||||
Dsts: dsts,
|
Dsts: dsts,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -65,6 +82,7 @@ func newFilter(logf logger.Logf) *Filter {
|
|||||||
m(nets("::/0"), netports("::/0:443")),
|
m(nets("::/0"), netports("::/0:443")),
|
||||||
m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
|
m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
|
||||||
m(nets("::/0"), netports("::/0:*"), testAllowedProto),
|
m(nets("::/0"), netports("::/0:*"), testAllowedProto),
|
||||||
|
m(nil, netports("1.2.3.4:22"), tailcfg.NodeCapability("cap-hit-1234-ssh")),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
||||||
@ -79,11 +97,17 @@ func newFilter(logf logger.Logf) *Filter {
|
|||||||
localNetsSet, _ := localNets.IPSet()
|
localNetsSet, _ := localNets.IPSet()
|
||||||
logBSet, _ := logB.IPSet()
|
logBSet, _ := logB.IPSet()
|
||||||
|
|
||||||
return New(matches, localNetsSet, logBSet, nil, logf)
|
return New(matches, nil, localNetsSet, logBSet, nil, logf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFilter(t *testing.T) {
|
func TestFilter(t *testing.T) {
|
||||||
acl := newFilter(t.Logf)
|
filt := newFilter(t.Logf)
|
||||||
|
|
||||||
|
ipWithCap := netip.MustParseAddr("10.0.0.1")
|
||||||
|
ipWithoutCap := netip.MustParseAddr("10.0.0.2")
|
||||||
|
filt.srcIPHasCap = func(ip netip.Addr, cap tailcfg.NodeCapability) bool {
|
||||||
|
return cap == "cap-hit-1234-ssh" && ip == ipWithCap
|
||||||
|
}
|
||||||
|
|
||||||
type InOut struct {
|
type InOut struct {
|
||||||
want Response
|
want Response
|
||||||
@ -139,21 +163,27 @@ type InOut struct {
|
|||||||
{Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)},
|
{Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)},
|
||||||
{Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
|
{Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
|
||||||
{Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)},
|
{Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)},
|
||||||
|
|
||||||
|
// Test use of a node capability to grant access.
|
||||||
|
// 10.0.0.1 has the capability; 10.0.0.2 does not (see srcIPHasCap at top of func)
|
||||||
|
{Accept, parsed(ipproto.TCP, ipWithCap.String(), "1.2.3.4", 30000, 22)},
|
||||||
|
{Drop, parsed(ipproto.TCP, ipWithoutCap.String(), "1.2.3.4", 30000, 22)},
|
||||||
}
|
}
|
||||||
for i, test := range tests {
|
for i, test := range tests {
|
||||||
aclFunc := acl.runIn4
|
aclFunc := filt.runIn4
|
||||||
if test.p.IPVersion == 6 {
|
if test.p.IPVersion == 6 {
|
||||||
aclFunc = acl.runIn6
|
aclFunc = filt.runIn6
|
||||||
}
|
}
|
||||||
if got, why := aclFunc(&test.p); test.want != got {
|
if got, why := aclFunc(&test.p); test.want != got {
|
||||||
t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
|
t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
if test.p.IPProto == ipproto.TCP {
|
if test.p.IPProto == ipproto.TCP {
|
||||||
var got Response
|
var got Response
|
||||||
if test.p.IPVersion == 4 {
|
if test.p.IPVersion == 4 {
|
||||||
got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
||||||
} else {
|
} else {
|
||||||
got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
||||||
}
|
}
|
||||||
if test.want != got {
|
if test.want != got {
|
||||||
t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p)
|
t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p)
|
||||||
@ -165,7 +195,7 @@ type InOut struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Update UDP state
|
// Update UDP state
|
||||||
_, _ = acl.runOut(&test.p)
|
_, _ = filt.runOut(&test.p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -264,13 +294,16 @@ func TestParseIPSet(t *testing.T) {
|
|||||||
{"*", pfx("0.0.0.0/0", "::/0"), ""},
|
{"*", pfx("0.0.0.0/0", "::/0"), ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
got, err := parseIPSet(tt.host)
|
got, gotCap, err := parseIPSet(tt.host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == tt.wantErr {
|
if err.Error() == tt.wantErr {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t.Errorf("parseIPSet(%q) error: %v; want error %q", tt.host, err, tt.wantErr)
|
t.Errorf("parseIPSet(%q) error: %v; want error %q", tt.host, err, tt.wantErr)
|
||||||
}
|
}
|
||||||
|
if gotCap != "" {
|
||||||
|
t.Errorf("parseIPSet(%q) cap: %q; want empty", tt.host, gotCap)
|
||||||
|
}
|
||||||
compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b })
|
compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b })
|
||||||
compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b })
|
compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b })
|
||||||
if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" {
|
if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" {
|
||||||
@ -278,6 +311,27 @@ func TestParseIPSet(t *testing.T) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
capTests := []struct {
|
||||||
|
in string
|
||||||
|
want tailcfg.NodeCapability
|
||||||
|
}{
|
||||||
|
{"cap:foo", "foo"},
|
||||||
|
{"cap:people-in-8.8.8.0/24", "people-in-8.8.8.0/24"}, // test precedence of "/" search
|
||||||
|
}
|
||||||
|
for _, tt := range capTests {
|
||||||
|
pfxes, gotCap, err := parseIPSet(tt.in)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("parseIPSet(%q) error: %v; want no error", tt.in, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if gotCap != tt.want {
|
||||||
|
t.Errorf("parseIPSet(%q) cap: %q; want %q", tt.in, gotCap, tt.want)
|
||||||
|
}
|
||||||
|
if len(pfxes) != 0 {
|
||||||
|
t.Errorf("parseIPSet(%q) pfxes: %v; want empty", tt.in, pfxes)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkFilter(b *testing.B) {
|
func BenchmarkFilter(b *testing.B) {
|
||||||
@ -904,7 +958,7 @@ func TestPeerCaps(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
filt := New(mm, nil, nil, nil, t.Logf)
|
filt := New(mm, nil, nil, nil, nil, t.Logf)
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
src, dst string // IP
|
src, dst string // IP
|
||||||
@ -1037,7 +1091,7 @@ func benchmarkFile(b *testing.B, file string, opt benchOpt) {
|
|||||||
logIPs.AddPrefix(tsaddr.CGNATRange())
|
logIPs.AddPrefix(tsaddr.CGNATRange())
|
||||||
logIPs.AddPrefix(tsaddr.TailscaleULARange())
|
logIPs.AddPrefix(tsaddr.TailscaleULARange())
|
||||||
|
|
||||||
f := New(matches, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard)
|
f := New(matches, nil, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard)
|
||||||
var srcIP, dstIP netip.Addr
|
var srcIP, dstIP netip.Addr
|
||||||
if opt.v4 {
|
if opt.v4 {
|
||||||
srcIP = netip.MustParseAddr("1.2.3.4")
|
srcIP = netip.MustParseAddr("1.2.3.4")
|
||||||
|
@ -66,11 +66,28 @@ type CapMatch struct {
|
|||||||
// Match matches packets from any IP address in Srcs to any ip:port in
|
// Match matches packets from any IP address in Srcs to any ip:port in
|
||||||
// Dsts.
|
// Dsts.
|
||||||
type Match struct {
|
type Match struct {
|
||||||
IPProto views.Slice[ipproto.Proto] // required set (no default value at this layer)
|
// IPProto is the set of IP protocol numbers for which this match applies.
|
||||||
Srcs []netip.Prefix
|
// It is required. There is no default value at this layer.
|
||||||
|
// If empty, it doesn't match.
|
||||||
|
IPProto views.Slice[ipproto.Proto]
|
||||||
|
|
||||||
|
// Srcs is the set of source IP prefixes for which this match applies. A
|
||||||
|
// Match can match by either its source IP address being in Srcs (which
|
||||||
|
// SrcsContains tests) or if the source IP is of a known peer self address
|
||||||
|
// that contains a NodeCapability listed in SrcCaps.
|
||||||
|
Srcs []netip.Prefix
|
||||||
|
// SrcsContains is an optimized function that reports whether Addr is in
|
||||||
|
// Srcs, using the best search method for the size and shape of Srcs.
|
||||||
SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs
|
SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs
|
||||||
Dsts []NetPortRange // optional, if Srcs match
|
|
||||||
Caps []CapMatch // optional, if Srcs match
|
// SrcCaps is an alternative way to match packets. If the peer's source IP
|
||||||
|
// has one of these capabilities, it's also permitted. The peers are only
|
||||||
|
// looked up by their self address (Node.Addresses) and not by subnet routes
|
||||||
|
// they advertise.
|
||||||
|
SrcCaps []tailcfg.NodeCapability
|
||||||
|
|
||||||
|
Dsts []NetPortRange // optional, if source matches
|
||||||
|
Caps []CapMatch // optional, if source match
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Match) String() string {
|
func (m Match) String() string {
|
||||||
|
@ -23,6 +23,7 @@ func (src *Match) Clone() *Match {
|
|||||||
*dst = *src
|
*dst = *src
|
||||||
dst.IPProto = src.IPProto
|
dst.IPProto = src.IPProto
|
||||||
dst.Srcs = append(src.Srcs[:0:0], src.Srcs...)
|
dst.Srcs = append(src.Srcs[:0:0], src.Srcs...)
|
||||||
|
dst.SrcCaps = append(src.SrcCaps[:0:0], src.SrcCaps...)
|
||||||
dst.Dsts = append(src.Dsts[:0:0], src.Dsts...)
|
dst.Dsts = append(src.Dsts[:0:0], src.Dsts...)
|
||||||
if src.Caps != nil {
|
if src.Caps != nil {
|
||||||
dst.Caps = make([]CapMatch, len(src.Caps))
|
dst.Caps = make([]CapMatch, len(src.Caps))
|
||||||
@ -38,6 +39,7 @@ func (src *Match) Clone() *Match {
|
|||||||
IPProto views.Slice[ipproto.Proto]
|
IPProto views.Slice[ipproto.Proto]
|
||||||
Srcs []netip.Prefix
|
Srcs []netip.Prefix
|
||||||
SrcsContains func(netip.Addr) bool
|
SrcsContains func(netip.Addr) bool
|
||||||
|
SrcCaps []tailcfg.NodeCapability
|
||||||
Dsts []NetPortRange
|
Dsts []NetPortRange
|
||||||
Caps []CapMatch
|
Caps []CapMatch
|
||||||
}{})
|
}{})
|
||||||
|
@ -4,19 +4,23 @@
|
|||||||
package filter
|
package filter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"tailscale.com/net/packet"
|
"tailscale.com/net/packet"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/views"
|
"tailscale.com/types/views"
|
||||||
"tailscale.com/wgengine/filter/filtertype"
|
"tailscale.com/wgengine/filter/filtertype"
|
||||||
)
|
)
|
||||||
|
|
||||||
type matches []filtertype.Match
|
type matches []filtertype.Match
|
||||||
|
|
||||||
func (ms matches) match(q *packet.Parsed) bool {
|
func (ms matches) match(q *packet.Parsed, hasCap CapTestFunc) bool {
|
||||||
for _, m := range ms {
|
for i := range ms {
|
||||||
|
m := &ms[i]
|
||||||
if !views.SliceContains(m.IPProto, q.IPProto) {
|
if !views.SliceContains(m.IPProto, q.IPProto) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !m.SrcsContains(q.Src.Addr()) {
|
if !srcMatches(m, q.Src.Addr(), hasCap) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, dst := range m.Dsts {
|
for _, dst := range m.Dsts {
|
||||||
@ -32,9 +36,33 @@ func (ms matches) match(q *packet.Parsed) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
|
// srcMatches reports whether srcAddr matche the src requirements in m, either
|
||||||
|
// by Srcs (using SrcsContains), or by the node having a capability listed
|
||||||
|
// in SrcCaps using the provided hasCap function.
|
||||||
|
func srcMatches(m *filtertype.Match, srcAddr netip.Addr, hasCap CapTestFunc) bool {
|
||||||
|
if m.SrcsContains(srcAddr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if hasCap != nil {
|
||||||
|
for _, c := range m.SrcCaps {
|
||||||
|
if hasCap(srcAddr, c) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// CapTestFunc is the function signature of a function that tests whether srcIP
|
||||||
|
// has a given capability.
|
||||||
|
//
|
||||||
|
// It it used in the fast path of evaluating filter rules so should be fast.
|
||||||
|
type CapTestFunc = func(srcIP netip.Addr, cap tailcfg.NodeCapability) bool
|
||||||
|
|
||||||
|
func (ms matches) matchIPsOnly(q *packet.Parsed, hasCap CapTestFunc) bool {
|
||||||
|
srcAddr := q.Src.Addr()
|
||||||
for _, m := range ms {
|
for _, m := range ms {
|
||||||
if !m.SrcsContains(q.Src.Addr()) {
|
if !m.SrcsContains(srcAddr) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, dst := range m.Dsts {
|
for _, dst := range m.Dsts {
|
||||||
@ -43,6 +71,15 @@ func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if hasCap != nil {
|
||||||
|
for _, m := range ms {
|
||||||
|
for _, c := range m.SrcCaps {
|
||||||
|
if hasCap(srcAddr, c) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,12 +58,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range r.SrcIPs {
|
for _, s := range r.SrcIPs {
|
||||||
nets, err := parseIPSet(s)
|
nets, cap, err := parseIPSet(s)
|
||||||
if err != nil && erracc == nil {
|
if err != nil && erracc == nil {
|
||||||
erracc = err
|
erracc = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m.Srcs = append(m.Srcs, nets...)
|
m.Srcs = append(m.Srcs, nets...)
|
||||||
|
if cap != "" {
|
||||||
|
m.SrcCaps = append(m.SrcCaps, cap)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
m.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(m.Srcs))
|
m.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(m.Srcs))
|
||||||
|
|
||||||
@ -71,11 +74,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
|
|||||||
if d.Bits != nil {
|
if d.Bits != nil {
|
||||||
return nil, fmt.Errorf("unexpected DstBits; control plane should not send this to this client version")
|
return nil, fmt.Errorf("unexpected DstBits; control plane should not send this to this client version")
|
||||||
}
|
}
|
||||||
nets, err := parseIPSet(d.IP)
|
nets, cap, err := parseIPSet(d.IP)
|
||||||
if err != nil && erracc == nil {
|
if err != nil && erracc == nil {
|
||||||
erracc = err
|
erracc = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if cap != "" {
|
||||||
|
erracc = fmt.Errorf("unexpected capability %q in DstPorts", cap)
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, net := range nets {
|
for _, net := range nets {
|
||||||
m.Dsts = append(m.Dsts, NetPortRange{
|
m.Dsts = append(m.Dsts, NetPortRange{
|
||||||
Net: net,
|
Net: net,
|
||||||
@ -120,48 +127,52 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
|
|||||||
// - the string "*" to match everything (both IPv4 & IPv6)
|
// - the string "*" to match everything (both IPv4 & IPv6)
|
||||||
// - a CIDR (e.g. "192.168.0.0/16")
|
// - a CIDR (e.g. "192.168.0.0/16")
|
||||||
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
||||||
|
// - "cap:<peer-node-capability>" to match a peer node capability
|
||||||
//
|
//
|
||||||
// TODO(bradfitz): make this return an IPSet and plumb that all
|
// TODO(bradfitz): make this return an IPSet and plumb that all
|
||||||
// around, and ultimately use a new version of IPSet.ContainsFunc like
|
// around, and ultimately use a new version of IPSet.ContainsFunc like
|
||||||
// Contains16Func that works in [16]byte address, so we we can match
|
// Contains16Func that works in [16]byte address, so we we can match
|
||||||
// at runtime without allocating?
|
// at runtime without allocating?
|
||||||
func parseIPSet(arg string) ([]netip.Prefix, error) {
|
func parseIPSet(arg string) (prefixes []netip.Prefix, peerCap tailcfg.NodeCapability, err error) {
|
||||||
if arg == "*" {
|
if arg == "*" {
|
||||||
// User explicitly requested wildcard.
|
// User explicitly requested wildcard.
|
||||||
return []netip.Prefix{
|
return []netip.Prefix{
|
||||||
netip.PrefixFrom(zeroIP4, 0),
|
netip.PrefixFrom(zeroIP4, 0),
|
||||||
netip.PrefixFrom(zeroIP6, 0),
|
netip.PrefixFrom(zeroIP6, 0),
|
||||||
}, nil
|
}, "", nil
|
||||||
|
}
|
||||||
|
if cap, ok := strings.CutPrefix(arg, "cap:"); ok {
|
||||||
|
return nil, tailcfg.NodeCapability(cap), nil
|
||||||
}
|
}
|
||||||
if strings.Contains(arg, "/") {
|
if strings.Contains(arg, "/") {
|
||||||
pfx, err := netip.ParsePrefix(arg)
|
pfx, err := netip.ParsePrefix(arg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
if pfx != pfx.Masked() {
|
if pfx != pfx.Masked() {
|
||||||
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
|
return nil, "", fmt.Errorf("%v contains non-network bits set", pfx)
|
||||||
}
|
}
|
||||||
return []netip.Prefix{pfx}, nil
|
return []netip.Prefix{pfx}, "", nil
|
||||||
}
|
}
|
||||||
if strings.Count(arg, "-") == 1 {
|
if strings.Count(arg, "-") == 1 {
|
||||||
ip1s, ip2s, _ := strings.Cut(arg, "-")
|
ip1s, ip2s, _ := strings.Cut(arg, "-")
|
||||||
ip1, err := netip.ParseAddr(ip1s)
|
ip1, err := netip.ParseAddr(ip1s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
ip2, err := netip.ParseAddr(ip2s)
|
ip2, err := netip.ParseAddr(ip2s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
r := netipx.IPRangeFrom(ip1, ip2)
|
r := netipx.IPRangeFrom(ip1, ip2)
|
||||||
if !r.IsValid() {
|
if !r.IsValid() {
|
||||||
return nil, fmt.Errorf("invalid IP range %q", arg)
|
return nil, "", fmt.Errorf("invalid IP range %q", arg)
|
||||||
}
|
}
|
||||||
return r.Prefixes(), nil
|
return r.Prefixes(), "", nil
|
||||||
}
|
}
|
||||||
ip, err := netip.ParseAddr(arg)
|
ip, err := netip.ParseAddr(arg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid IP address %q", arg)
|
return nil, "", fmt.Errorf("invalid IP address %q", arg)
|
||||||
}
|
}
|
||||||
return []netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, nil
|
return []netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, "", nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user