wgengine/filter: use netaddr types in public API.

We still use the packet.* alloc-free types in the data path, but
the compilation from netaddr to packet happens within the filter
package.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson 2020-11-09 20:12:21 -08:00 committed by Dave Anderson
parent 7988f75b87
commit b3634f020d
7 changed files with 386 additions and 215 deletions

View File

@ -546,7 +546,7 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap, prefs *Pre
return
}
localNets := wgCIDRsToFilter(netMap.Addresses, advRoutes)
localNets := wgCIDRsToNetaddr(netMap.Addresses, advRoutes)
if shieldsUp {
b.logf("netmap packet filter: (shields up)")
@ -1266,14 +1266,14 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
}
rs := &router.Config{
LocalAddrs: wgCIDRToNetaddr(addrs),
SubnetRoutes: wgCIDRToNetaddr(prefs.AdvertiseRoutes),
LocalAddrs: wgCIDRsToNetaddr(addrs),
SubnetRoutes: wgCIDRsToNetaddr(prefs.AdvertiseRoutes),
SNATSubnetRoutes: !prefs.NoSNAT,
NetfilterMode: prefs.NetfilterMode,
}
for _, peer := range cfg.Peers {
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
rs.Routes = append(rs.Routes, wgCIDRsToNetaddr(peer.AllowedIPs)...)
}
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
@ -1284,35 +1284,20 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
return rs
}
// wgCIDRsToFilter converts lists of wgcfg.CIDR into a single list of
// filter.Net.
func wgCIDRsToFilter(cidrLists ...[]wgcfg.CIDR) (ret []filter.Net) {
func wgCIDRsToNetaddr(cidrLists ...[]wgcfg.CIDR) (ret []netaddr.IPPrefix) {
for _, cidrs := range cidrLists {
for _, cidr := range cidrs {
if !cidr.IP.Is4() {
continue
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
if !ok {
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
}
ret = append(ret, filter.Net{
IP: filter.NewIP(cidr.IP.IP()),
Mask: filter.Netmask(int(cidr.Mask)),
})
ncidr.IP = ncidr.IP.Unmap()
ret = append(ret, ncidr)
}
}
return ret
}
func wgCIDRToNetaddr(cidrs []wgcfg.CIDR) (ret []netaddr.IPPrefix) {
for _, cidr := range cidrs {
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
if !ok {
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
}
ncidr.IP = ncidr.IP.Unmap()
ret = append(ret, ncidr)
}
return ret
}
func applyPrefsToHostinfo(hi *tailcfg.Hostinfo, prefs *Prefs) {
if h := prefs.Hostname; h != "" {
hi.Hostname = h

View File

@ -7,12 +7,12 @@
import (
"fmt"
"net"
"sync"
"time"
"github.com/golang/groupcache/lru"
"golang.org/x/time/rate"
"inet.af/netaddr"
"tailscale.com/net/packet"
"tailscale.com/tailcfg"
"tailscale.com/types/logger"
@ -26,16 +26,18 @@ type filterState struct {
// Filter is a stateful packet filter.
type Filter struct {
logf logger.Logf
// localNets is the list of IP prefixes that we know to be "local"
// to this node. All packets coming in over tailscale must have a
// destination within localNets, regardless of the policy filter
// below. A nil localNets rejects all incoming traffic.
localNets []Net
// matches is a list of match->action rules applied to all packets
// arriving over tailscale tunnels. Matches are checked in order,
// and processing stops at the first matching rule. The default
// policy if no rules match is to drop the packet.
matches Matches
// localNets is the list of IP prefixes that we know to be
// "local" to this node. All packets coming in over tailscale
// must have a destination within localNets, regardless of the
// policy filter below. A nil localNets rejects all incoming
// traffic.
local4 []net4
// matches4 is a list of match->action rules applied to all
// packets arriving over tailscale tunnels. Matches are
// checked in order, and processing stops at the first
// matching rule. The default policy if no rules match is to
// drop the packet.
matches4 matches4
// state is the connection tracking state attached to this
// filter. It is used to allow incoming traffic that is a response
// to an outbound connection that this node made, even if those
@ -87,12 +89,12 @@ type tuple struct {
// MatchAllowAll matches all packets.
var MatchAllowAll = Matches{
Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}},
Match{NetPortRangeAny, NetAny},
}
// NewAllowAll returns a packet filter that accepts everything to and
// from localNets.
func NewAllowAll(localNets []Net, logf logger.Logf) *Filter {
func NewAllowAll(localNets []netaddr.IPPrefix, logf logger.Logf) *Filter {
return New(MatchAllowAll, localNets, nil, logf)
}
@ -106,7 +108,7 @@ func NewAllowNone(logf logger.Logf) *Filter {
// by matches. If shareStateWith is non-nil, the returned filter
// shares state with the previous one, to enable rules to be changed
// at runtime without breaking existing flows.
func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.Logf) *Filter {
func New(matches Matches, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter {
var state *filterState
if shareStateWith != nil {
state = shareStateWith.state
@ -116,10 +118,10 @@ func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.L
}
}
f := &Filter{
logf: logf,
matches: matches,
localNets: localNets,
state: state,
logf: logf,
matches4: newMatches4(matches),
local4: nets4FromIPPrefixes(localNets),
state: state,
}
return f
}
@ -179,29 +181,32 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) {
return mm, erracc
}
func parseIP(host string, defaultBits int) (Net, error) {
ip := net.ParseIP(host)
if ip != nil && ip.IsUnspecified() {
// For clarity, reject 0.0.0.0 as an input
return NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
} else if ip == nil && host == "*" {
// User explicitly requested wildcard dst ip
return NetAny, nil
} else {
if ip != nil {
ip = ip.To4()
}
if ip == nil || len(ip) != 4 {
return NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
}
if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) {
return NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
}
return Net{
IP: NewIP(ip),
Mask: Netmask(defaultBits),
}, nil
func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) {
if host == "*" {
// User explicitly requested wildcard dst ip.
// TODO: ipv6
return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil
}
ip, err := netaddr.ParseIP(host)
if err != nil {
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host)
}
if ip == netaddr.IPv4(0, 0, 0, 0) {
// For clarity, reject 0.0.0.0 as an input
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
}
if !ip.Is4() {
// TODO: ipv6
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
}
if defaultBits < 0 || defaultBits > 32 {
return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
}
return netaddr.IPPrefix{
IP: ip,
Bits: uint8(defaultBits),
}, nil
}
// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
@ -266,7 +271,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
// A compromised peer could try to send us packets for
// destinations we didn't explicitly advertise. This check is to
// prevent that.
if !ipInList(q.DstIP, f.localNets) {
if !ip4InList(q.DstIP, f.local4) {
return Drop, "destination not allowed"
}
@ -284,7 +289,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
// related to an existing ICMP-Echo, TCP, or UDP
// session.
return Accept, "icmp response ok"
} else if matchIPWithoutPorts(f.matches, q) {
} else if f.matches4.matchIPsOnly(q) {
// If any port is open to an IP, allow ICMP to it.
return Accept, "icmp ok"
}
@ -300,7 +305,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
if q.IPProto == packet.TCP && !q.IsTCPSyn() {
return Accept, "tcp non-syn"
}
if matchIPPorts(f.matches, q) {
if f.matches4.match(q) {
return Accept, "tcp ok"
}
case packet.UDP:
@ -313,7 +318,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
if ok {
return Accept, "udp cached"
}
if matchIPPorts(f.matches, q) {
if f.matches4.match(q) {
return Accept, "udp ok"
}
default:
@ -399,9 +404,9 @@ func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Respons
)
// omitDropLogging reports whether packet p, which has already been
// deemded a packet to Drop, should bypass the [rate-limited] logging.
// We don't want to log scary & spammy reject warnings for packets that
// are totally normal, like IPv6 route announcements.
// deemed a packet to Drop, should bypass the [rate-limited] logging.
// We don't want to log scary & spammy reject warnings for packets
// that are totally normal, like IPv6 route announcements.
func omitDropLogging(p *packet.ParsedPacket, dir direction) bool {
b := p.Buffer()
switch dir {

View File

@ -8,10 +8,13 @@
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
"testing"
"inet.af/netaddr"
"tailscale.com/net/packet"
"tailscale.com/types/logger"
)
@ -22,43 +25,91 @@
var UDP = packet.UDP
var Fragment = packet.Fragment
func nets(ips []packet.IP4) []Net {
out := make([]Net, 0, len(ips))
for _, ip := range ips {
out = append(out, Net{ip, Netmask(32)})
func pfx(s string) netaddr.IPPrefix {
pfx, err := netaddr.ParseIPPrefix(s)
if err != nil {
panic(err)
}
return out
return pfx
}
func ippr(ip packet.IP4, start, end uint16) []NetPortRange {
return []NetPortRange{
NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}},
func nets(nets ...string) (ret []netaddr.IPPrefix) {
for _, s := range nets {
if i := strings.IndexByte(s, '/'); i == -1 {
ip, err := netaddr.ParseIP(s)
if err != nil {
panic(err)
}
bits := uint8(32)
if ip.Is6() {
bits = 128
}
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
} else {
pfx, err := netaddr.ParseIPPrefix(s)
if err != nil {
panic(err)
}
ret = append(ret, pfx)
}
}
return ret
}
func netpr(ip packet.IP4, bits int, start, end uint16) []NetPortRange {
return []NetPortRange{
NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}},
func ports(s string) PortRange {
if s == "*" {
return PortRangeAny
}
var fs, ls string
i := strings.IndexByte(s, '-')
if i == -1 {
fs = s
ls = fs
} else {
fs = s[:i]
ls = s[i+1:]
}
first, err := strconv.ParseInt(fs, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
last, err := strconv.ParseInt(ls, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
return PortRange{uint16(first), uint16(last)}
}
func netports(netPorts ...string) (ret []NetPortRange) {
for _, s := range netPorts {
i := strings.LastIndexByte(s, ':')
if i == -1 {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
npr := NetPortRange{
Net: nets(s[:i])[0],
Ports: ports(s[i+1:]),
}
ret = append(ret, npr)
}
return ret
}
var matches = Matches{
{Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: []NetPortRange{
NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}},
NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}},
}},
{Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)},
{Srcs: nets([]packet.IP4{0x02020202}), Dsts: ippr(0x08010101, 22, 22)},
{Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
{Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
{Srcs: nets([]packet.IP4{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
{Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")},
{Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")},
{Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")},
{Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")},
{Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")},
{Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")},
}
func newFilter(logf logger.Logf) *Filter {
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
// 102.102.102.102, 119.119.119.119, 8.1.0.0/16
localNets := nets([]packet.IP4{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777})
localNets = append(localNets, Net{packet.IP4(0x08010000), Netmask(16)})
localNets := nets("100.122.98.50", "1.2.3.4", "5.6.7.8", "102.102.102.102", "119.119.119.119", "8.1.0.0/16")
return New(matches, localNets, nil, logf)
}
@ -160,18 +211,19 @@ func TestNoAllocs(t *testing.T) {
}
func TestParseIP(t *testing.T) {
var noaddr netaddr.IPPrefix
tests := []struct {
host string
bits int
want Net
want netaddr.IPPrefix
wantErr string
}{
{"8.8.8.8", 24, Net{IP: packet.NewIP4(net.ParseIP("8.8.8.8")), Mask: packet.NewIP4(net.ParseIP("255.255.255.0"))}, ""},
{"8.8.8.8", 33, Net{}, `invalid CIDR size 33 for host "8.8.8.8"`},
{"8.8.8.8", -1, Net{}, `invalid CIDR size -1 for host "8.8.8.8"`},
{"0.0.0.0", 24, Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
{"*", 24, NetAny, ""},
{"fe80::1", 128, NetNone, `ports="fe80::1": invalid IPv4 address`},
{"8.8.8.8", 24, pfx("8.8.8.8/24"), ""},
{"8.8.8.8", 33, noaddr, `invalid CIDR size 33 for host "8.8.8.8"`},
{"8.8.8.8", -1, noaddr, `invalid CIDR size -1 for host "8.8.8.8"`},
{"0.0.0.0", 24, noaddr, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
{"*", 24, pfx("0.0.0.0/0"), ""},
{"fe80::1", 128, pfx("255.255.255.255/32"), `ports="fe80::1": invalid IPv4 address`},
}
for _, tt := range tests {
got, err := parseIP(tt.host, tt.bits)
@ -215,6 +267,7 @@ func BenchmarkFilter(b *testing.B) {
for _, bench := range benches {
b.Run(bench.name, func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
q := &packet.ParsedPacket{}
q.Decode(bench.packet)

View File

@ -6,53 +6,17 @@
import (
"fmt"
"math/bits"
"net"
"strings"
"tailscale.com/net/packet"
"inet.af/netaddr"
)
func NewIP(ip net.IP) packet.IP4 {
return packet.NewIP4(ip)
}
type Net struct {
IP packet.IP4
Mask packet.IP4
}
func (n Net) Includes(ip packet.IP4) bool {
return (n.IP & n.Mask) == (ip & n.Mask)
}
func (n Net) Bits() int {
return 32 - bits.TrailingZeros32(uint32(n.Mask))
}
func (n Net) String() string {
b := n.Bits()
if b == 32 {
return n.IP.String()
} else if b == 0 {
return "*"
} else {
return fmt.Sprintf("%s/%d", n.IP, b)
}
}
var NetAny = Net{0, 0}
var NetNone = Net{^packet.IP4(0), ^packet.IP4(0)}
func Netmask(bits int) packet.IP4 {
b := ^uint32((1 << (32 - bits)) - 1)
return packet.IP4(b)
}
// PortRange is a range of TCP and UDP ports.
type PortRange struct {
First, Last uint16
First, Last uint16 // inclusive
}
// PortRangeAny represents all TCP and UDP ports.
var PortRangeAny = PortRange{0, 65535}
func (pr PortRange) String() string {
@ -65,28 +29,40 @@ func (pr PortRange) String() string {
}
}
func (pr PortRange) contains(port uint16) bool {
return port >= pr.First && port <= pr.Last
}
// NetAny matches all IP addresses.
// TODO: add ipv6.
var NetAny = []netaddr.IPPrefix{{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}}
// NetPortRange combines an IP address prefix and PortRange.
type NetPortRange struct {
Net Net
Net netaddr.IPPrefix
Ports PortRange
}
var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny}
func (ipr NetPortRange) String() string {
return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports)
func (npr NetPortRange) String() string {
return fmt.Sprintf("%v:%v", npr.Net, npr.Ports)
}
var NetPortRangeAny = []NetPortRange{{Net: NetAny[0], Ports: PortRangeAny}}
// Match matches packets from any IP address in Srcs to any ip:port in
// Dsts.
type Match struct {
Dsts []NetPortRange
Srcs []Net
Srcs []netaddr.IPPrefix
}
// Clone returns a deep copy of m.
func (m Match) Clone() (res Match) {
if m.Dsts != nil {
res.Dsts = append([]NetPortRange{}, m.Dsts...)
}
if m.Srcs != nil {
res.Srcs = append([]Net{}, m.Srcs...)
res.Srcs = append([]netaddr.IPPrefix{}, m.Srcs...)
}
return res
}
@ -115,57 +91,13 @@ func (m Match) String() string {
return fmt.Sprintf("%v=>%v", ss, ds)
}
// Matches is a list of packet matchers.
type Matches []Match
func (m Matches) Clone() (res Matches) {
for _, match := range m {
// Clone returns a deep copy of ms.
func (ms Matches) Clone() (res Matches) {
for _, match := range ms {
res = append(res, match.Clone())
}
return res
}
func ipInList(ip packet.IP4, netlist []Net) bool {
for _, net := range netlist {
if net.Includes(ip) {
return true
}
}
return false
}
func matchIPPorts(mm Matches, q *packet.ParsedPacket) bool {
for _, acl := range mm {
for _, dst := range acl.Dsts {
if !dst.Net.Includes(q.DstIP) {
continue
}
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
continue
}
if !ipInList(q.SrcIP, acl.Srcs) {
// Skip other dests in this acl, since
// the src will never match.
break
}
return true
}
}
return false
}
func matchIPWithoutPorts(mm Matches, q *packet.ParsedPacket) bool {
for _, acl := range mm {
for _, dst := range acl.Dsts {
if !dst.Net.Includes(q.DstIP) {
continue
}
if !ipInList(q.SrcIP, acl.Srcs) {
// Skip other dests in this acl, since
// the src will never match.
break
}
return true
}
}
return false
}

151
wgengine/filter/match4.go Normal file
View File

@ -0,0 +1,151 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import (
"fmt"
"math/bits"
"strings"
"inet.af/netaddr"
"tailscale.com/net/packet"
)
type net4 struct {
ip packet.IP4
mask packet.IP4
}
func net4FromIPPrefix(pfx netaddr.IPPrefix) net4 {
if !pfx.IP.Is4() {
panic("net4FromIPPrefix given non-ipv4 prefix")
}
return net4{
ip: packet.IP4FromNetaddr(pfx.IP),
mask: netmask4(pfx.Bits),
}
}
func nets4FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net4) {
for _, pfx := range pfxs {
if pfx.IP.Is4() {
ret = append(ret, net4FromIPPrefix(pfx))
}
}
return ret
}
func (n net4) Contains(ip packet.IP4) bool {
return (n.ip & n.mask) == (ip & n.mask)
}
func (n net4) Bits() int {
return 32 - bits.TrailingZeros32(uint32(n.mask))
}
func (n net4) String() string {
b := n.Bits()
if b == 32 {
return n.ip.String()
} else if b == 0 {
return "*"
} else {
return fmt.Sprintf("%s/%d", n.ip, b)
}
}
type npr4 struct {
net net4
ports PortRange
}
func (npr npr4) String() string {
return fmt.Sprintf("%s:%s", npr.net, npr.ports)
}
type match4 struct {
dsts []npr4
srcs []net4
}
type matches4 []match4
func (ms matches4) String() string {
var b strings.Builder
for _, m := range ms {
fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts)
}
return b.String()
}
func newMatches4(ms Matches) (ret matches4) {
for _, m := range ms {
var m4 match4
for _, src := range m.Srcs {
if src.IP.Is4() {
m4.srcs = append(m4.srcs, net4FromIPPrefix(src))
}
}
for _, dst := range m.Dsts {
if dst.Net.IP.Is4() {
m4.dsts = append(m4.dsts, npr4{net4FromIPPrefix(dst.Net), dst.Ports})
}
}
if len(m4.srcs) > 0 && len(m4.dsts) > 0 {
ret = append(ret, m4)
}
}
return ret
}
// match returns whether q's source IP and destination IP:port match
// any of ms.
func (ms matches4) match(q *packet.ParsedPacket) bool {
for _, m := range ms {
if !ip4InList(q.SrcIP, m.srcs) {
continue
}
for _, dst := range m.dsts {
if !dst.net.Contains(q.DstIP) {
continue
}
if !dst.ports.contains(q.DstPort) {
continue
}
return true
}
}
return false
}
// matchIPsOnly returns whether q's source and destination IP match
// any of ms.
func (ms matches4) matchIPsOnly(q *packet.ParsedPacket) bool {
for _, m := range ms {
if !ip4InList(q.SrcIP, m.srcs) {
continue
}
for _, dst := range m.dsts {
if dst.net.Contains(q.DstIP) {
return true
}
}
}
return false
}
func netmask4(bits uint8) packet.IP4 {
b := ^uint32((1 << (32 - bits)) - 1)
return packet.IP4(b)
}
func ip4InList(ip packet.IP4, netlist []net4) bool {
for _, net := range netlist {
if net.Contains(ip) {
return true
}
}
return false
}

View File

@ -158,7 +158,7 @@ func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, der
tun := tuntest.NewChannelTUN()
tsTun := tstun.WrapTUN(logf, tun.TUN())
tsTun.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf))
tsTun.SetFilter(filter.NewAllowAll(filter.NetAny, logf))
dev := device.NewDevice(tsTun, &device.DeviceOptions{
Logger: &device.Logger{

View File

@ -6,11 +6,15 @@
import (
"bytes"
"fmt"
"strconv"
"strings"
"sync/atomic"
"testing"
"unsafe"
"github.com/tailscale/wireguard-go/tun/tuntest"
"inet.af/netaddr"
"tailscale.com/net/packet"
"tailscale.com/types/logger"
"tailscale.com/wgengine/filter"
@ -29,35 +33,76 @@ func udp(src, dst packet.IP4, sport, dport uint16) []byte {
return packet.Generate(header, []byte("udp_payload"))
}
func filterNet(ip, mask packet.IP4) filter.Net {
return filter.Net{IP: ip, Mask: mask}
func nets(nets ...string) (ret []netaddr.IPPrefix) {
for _, s := range nets {
if i := strings.IndexByte(s, '/'); i == -1 {
ip, err := netaddr.ParseIP(s)
if err != nil {
panic(err)
}
bits := uint8(32)
if ip.Is6() {
bits = 128
}
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
} else {
pfx, err := netaddr.ParseIPPrefix(s)
if err != nil {
panic(err)
}
ret = append(ret, pfx)
}
}
return ret
}
func nets(ips []packet.IP4) []filter.Net {
out := make([]filter.Net, 0, len(ips))
for _, ip := range ips {
out = append(out, filterNet(ip, filter.Netmask(32)))
func ports(s string) filter.PortRange {
if s == "*" {
return filter.PortRangeAny
}
return out
var fs, ls string
i := strings.IndexByte(s, '-')
if i == -1 {
fs = s
ls = fs
} else {
fs = s[:i]
ls = s[i+1:]
}
first, err := strconv.ParseInt(fs, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
last, err := strconv.ParseInt(ls, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
return filter.PortRange{First: uint16(first), Last: uint16(last)}
}
func ippr(ip packet.IP4, start, end uint16) []filter.NetPortRange {
return []filter.NetPortRange{
filter.NetPortRange{
Net: filterNet(ip, filter.Netmask(32)),
Ports: filter.PortRange{First: start, Last: end},
},
func netports(netPorts ...string) (ret []filter.NetPortRange) {
for _, s := range netPorts {
i := strings.LastIndexByte(s, ':')
if i == -1 {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
npr := filter.NetPortRange{
Net: nets(s[:i])[0],
Ports: ports(s[i+1:]),
}
ret = append(ret, npr)
}
return ret
}
func setfilter(logf logger.Logf, tun *TUN) {
matches := filter.Matches{
{Srcs: nets([]packet.IP4{0x05060708}), Dsts: ippr(0x01020304, 89, 90)},
{Srcs: nets([]packet.IP4{0x01020304}), Dsts: ippr(0x05060708, 98, 98)},
}
localNets := []filter.Net{
filterNet(packet.IP4(0x01020304), filter.Netmask(16)),
{Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
{Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
}
localNets := nets("1.2.0.0/16")
tun.SetFilter(filter.New(matches, localNets, nil, logf))
}