mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-26 03:25:35 +00:00
wgengine/filter: use netaddr types in public API.
We still use the packet.* alloc-free types in the data path, but the compilation from netaddr to packet happens within the filter package. Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
parent
7988f75b87
commit
b3634f020d
35
ipn/local.go
35
ipn/local.go
@ -546,7 +546,7 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap, prefs *Pre
|
||||
return
|
||||
}
|
||||
|
||||
localNets := wgCIDRsToFilter(netMap.Addresses, advRoutes)
|
||||
localNets := wgCIDRsToNetaddr(netMap.Addresses, advRoutes)
|
||||
|
||||
if shieldsUp {
|
||||
b.logf("netmap packet filter: (shields up)")
|
||||
@ -1266,14 +1266,14 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
|
||||
}
|
||||
|
||||
rs := &router.Config{
|
||||
LocalAddrs: wgCIDRToNetaddr(addrs),
|
||||
SubnetRoutes: wgCIDRToNetaddr(prefs.AdvertiseRoutes),
|
||||
LocalAddrs: wgCIDRsToNetaddr(addrs),
|
||||
SubnetRoutes: wgCIDRsToNetaddr(prefs.AdvertiseRoutes),
|
||||
SNATSubnetRoutes: !prefs.NoSNAT,
|
||||
NetfilterMode: prefs.NetfilterMode,
|
||||
}
|
||||
|
||||
for _, peer := range cfg.Peers {
|
||||
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
|
||||
rs.Routes = append(rs.Routes, wgCIDRsToNetaddr(peer.AllowedIPs)...)
|
||||
}
|
||||
|
||||
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
|
||||
@ -1284,35 +1284,20 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
|
||||
return rs
|
||||
}
|
||||
|
||||
// wgCIDRsToFilter converts lists of wgcfg.CIDR into a single list of
|
||||
// filter.Net.
|
||||
func wgCIDRsToFilter(cidrLists ...[]wgcfg.CIDR) (ret []filter.Net) {
|
||||
func wgCIDRsToNetaddr(cidrLists ...[]wgcfg.CIDR) (ret []netaddr.IPPrefix) {
|
||||
for _, cidrs := range cidrLists {
|
||||
for _, cidr := range cidrs {
|
||||
if !cidr.IP.Is4() {
|
||||
continue
|
||||
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
|
||||
}
|
||||
ret = append(ret, filter.Net{
|
||||
IP: filter.NewIP(cidr.IP.IP()),
|
||||
Mask: filter.Netmask(int(cidr.Mask)),
|
||||
})
|
||||
ncidr.IP = ncidr.IP.Unmap()
|
||||
ret = append(ret, ncidr)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func wgCIDRToNetaddr(cidrs []wgcfg.CIDR) (ret []netaddr.IPPrefix) {
|
||||
for _, cidr := range cidrs {
|
||||
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
|
||||
}
|
||||
ncidr.IP = ncidr.IP.Unmap()
|
||||
ret = append(ret, ncidr)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func applyPrefsToHostinfo(hi *tailcfg.Hostinfo, prefs *Prefs) {
|
||||
if h := prefs.Hostname; h != "" {
|
||||
hi.Hostname = h
|
||||
|
@ -7,12 +7,12 @@
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/groupcache/lru"
|
||||
"golang.org/x/time/rate"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/logger"
|
||||
@ -26,16 +26,18 @@ type filterState struct {
|
||||
// Filter is a stateful packet filter.
|
||||
type Filter struct {
|
||||
logf logger.Logf
|
||||
// localNets is the list of IP prefixes that we know to be "local"
|
||||
// to this node. All packets coming in over tailscale must have a
|
||||
// destination within localNets, regardless of the policy filter
|
||||
// below. A nil localNets rejects all incoming traffic.
|
||||
localNets []Net
|
||||
// matches is a list of match->action rules applied to all packets
|
||||
// arriving over tailscale tunnels. Matches are checked in order,
|
||||
// and processing stops at the first matching rule. The default
|
||||
// policy if no rules match is to drop the packet.
|
||||
matches Matches
|
||||
// localNets is the list of IP prefixes that we know to be
|
||||
// "local" to this node. All packets coming in over tailscale
|
||||
// must have a destination within localNets, regardless of the
|
||||
// policy filter below. A nil localNets rejects all incoming
|
||||
// traffic.
|
||||
local4 []net4
|
||||
// matches4 is a list of match->action rules applied to all
|
||||
// packets arriving over tailscale tunnels. Matches are
|
||||
// checked in order, and processing stops at the first
|
||||
// matching rule. The default policy if no rules match is to
|
||||
// drop the packet.
|
||||
matches4 matches4
|
||||
// state is the connection tracking state attached to this
|
||||
// filter. It is used to allow incoming traffic that is a response
|
||||
// to an outbound connection that this node made, even if those
|
||||
@ -87,12 +89,12 @@ type tuple struct {
|
||||
|
||||
// MatchAllowAll matches all packets.
|
||||
var MatchAllowAll = Matches{
|
||||
Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}},
|
||||
Match{NetPortRangeAny, NetAny},
|
||||
}
|
||||
|
||||
// NewAllowAll returns a packet filter that accepts everything to and
|
||||
// from localNets.
|
||||
func NewAllowAll(localNets []Net, logf logger.Logf) *Filter {
|
||||
func NewAllowAll(localNets []netaddr.IPPrefix, logf logger.Logf) *Filter {
|
||||
return New(MatchAllowAll, localNets, nil, logf)
|
||||
}
|
||||
|
||||
@ -106,7 +108,7 @@ func NewAllowNone(logf logger.Logf) *Filter {
|
||||
// by matches. If shareStateWith is non-nil, the returned filter
|
||||
// shares state with the previous one, to enable rules to be changed
|
||||
// at runtime without breaking existing flows.
|
||||
func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.Logf) *Filter {
|
||||
func New(matches Matches, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter {
|
||||
var state *filterState
|
||||
if shareStateWith != nil {
|
||||
state = shareStateWith.state
|
||||
@ -116,10 +118,10 @@ func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.L
|
||||
}
|
||||
}
|
||||
f := &Filter{
|
||||
logf: logf,
|
||||
matches: matches,
|
||||
localNets: localNets,
|
||||
state: state,
|
||||
logf: logf,
|
||||
matches4: newMatches4(matches),
|
||||
local4: nets4FromIPPrefixes(localNets),
|
||||
state: state,
|
||||
}
|
||||
return f
|
||||
}
|
||||
@ -179,29 +181,32 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) {
|
||||
return mm, erracc
|
||||
}
|
||||
|
||||
func parseIP(host string, defaultBits int) (Net, error) {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.IsUnspecified() {
|
||||
// For clarity, reject 0.0.0.0 as an input
|
||||
return NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
|
||||
} else if ip == nil && host == "*" {
|
||||
// User explicitly requested wildcard dst ip
|
||||
return NetAny, nil
|
||||
} else {
|
||||
if ip != nil {
|
||||
ip = ip.To4()
|
||||
}
|
||||
if ip == nil || len(ip) != 4 {
|
||||
return NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
|
||||
}
|
||||
if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) {
|
||||
return NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
|
||||
}
|
||||
return Net{
|
||||
IP: NewIP(ip),
|
||||
Mask: Netmask(defaultBits),
|
||||
}, nil
|
||||
func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) {
|
||||
if host == "*" {
|
||||
// User explicitly requested wildcard dst ip.
|
||||
// TODO: ipv6
|
||||
return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil
|
||||
}
|
||||
|
||||
ip, err := netaddr.ParseIP(host)
|
||||
if err != nil {
|
||||
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host)
|
||||
}
|
||||
if ip == netaddr.IPv4(0, 0, 0, 0) {
|
||||
// For clarity, reject 0.0.0.0 as an input
|
||||
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
|
||||
}
|
||||
if !ip.Is4() {
|
||||
// TODO: ipv6
|
||||
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
|
||||
}
|
||||
if defaultBits < 0 || defaultBits > 32 {
|
||||
return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
|
||||
}
|
||||
return netaddr.IPPrefix{
|
||||
IP: ip,
|
||||
Bits: uint8(defaultBits),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
|
||||
@ -266,7 +271,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
||||
// A compromised peer could try to send us packets for
|
||||
// destinations we didn't explicitly advertise. This check is to
|
||||
// prevent that.
|
||||
if !ipInList(q.DstIP, f.localNets) {
|
||||
if !ip4InList(q.DstIP, f.local4) {
|
||||
return Drop, "destination not allowed"
|
||||
}
|
||||
|
||||
@ -284,7 +289,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
||||
// related to an existing ICMP-Echo, TCP, or UDP
|
||||
// session.
|
||||
return Accept, "icmp response ok"
|
||||
} else if matchIPWithoutPorts(f.matches, q) {
|
||||
} else if f.matches4.matchIPsOnly(q) {
|
||||
// If any port is open to an IP, allow ICMP to it.
|
||||
return Accept, "icmp ok"
|
||||
}
|
||||
@ -300,7 +305,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
||||
if q.IPProto == packet.TCP && !q.IsTCPSyn() {
|
||||
return Accept, "tcp non-syn"
|
||||
}
|
||||
if matchIPPorts(f.matches, q) {
|
||||
if f.matches4.match(q) {
|
||||
return Accept, "tcp ok"
|
||||
}
|
||||
case packet.UDP:
|
||||
@ -313,7 +318,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
||||
if ok {
|
||||
return Accept, "udp cached"
|
||||
}
|
||||
if matchIPPorts(f.matches, q) {
|
||||
if f.matches4.match(q) {
|
||||
return Accept, "udp ok"
|
||||
}
|
||||
default:
|
||||
@ -399,9 +404,9 @@ func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Respons
|
||||
)
|
||||
|
||||
// omitDropLogging reports whether packet p, which has already been
|
||||
// deemded a packet to Drop, should bypass the [rate-limited] logging.
|
||||
// We don't want to log scary & spammy reject warnings for packets that
|
||||
// are totally normal, like IPv6 route announcements.
|
||||
// deemed a packet to Drop, should bypass the [rate-limited] logging.
|
||||
// We don't want to log scary & spammy reject warnings for packets
|
||||
// that are totally normal, like IPv6 route announcements.
|
||||
func omitDropLogging(p *packet.ParsedPacket, dir direction) bool {
|
||||
b := p.Buffer()
|
||||
switch dir {
|
||||
|
@ -8,10 +8,13 @@
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
@ -22,43 +25,91 @@
|
||||
var UDP = packet.UDP
|
||||
var Fragment = packet.Fragment
|
||||
|
||||
func nets(ips []packet.IP4) []Net {
|
||||
out := make([]Net, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
out = append(out, Net{ip, Netmask(32)})
|
||||
func pfx(s string) netaddr.IPPrefix {
|
||||
pfx, err := netaddr.ParseIPPrefix(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return out
|
||||
return pfx
|
||||
}
|
||||
|
||||
func ippr(ip packet.IP4, start, end uint16) []NetPortRange {
|
||||
return []NetPortRange{
|
||||
NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}},
|
||||
func nets(nets ...string) (ret []netaddr.IPPrefix) {
|
||||
for _, s := range nets {
|
||||
if i := strings.IndexByte(s, '/'); i == -1 {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bits := uint8(32)
|
||||
if ip.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
|
||||
} else {
|
||||
pfx, err := netaddr.ParseIPPrefix(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret = append(ret, pfx)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func netpr(ip packet.IP4, bits int, start, end uint16) []NetPortRange {
|
||||
return []NetPortRange{
|
||||
NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}},
|
||||
func ports(s string) PortRange {
|
||||
if s == "*" {
|
||||
return PortRangeAny
|
||||
}
|
||||
|
||||
var fs, ls string
|
||||
i := strings.IndexByte(s, '-')
|
||||
if i == -1 {
|
||||
fs = s
|
||||
ls = fs
|
||||
} else {
|
||||
fs = s[:i]
|
||||
ls = s[i+1:]
|
||||
}
|
||||
first, err := strconv.ParseInt(fs, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
last, err := strconv.ParseInt(ls, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
return PortRange{uint16(first), uint16(last)}
|
||||
}
|
||||
|
||||
func netports(netPorts ...string) (ret []NetPortRange) {
|
||||
for _, s := range netPorts {
|
||||
i := strings.LastIndexByte(s, ':')
|
||||
if i == -1 {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
|
||||
npr := NetPortRange{
|
||||
Net: nets(s[:i])[0],
|
||||
Ports: ports(s[i+1:]),
|
||||
}
|
||||
ret = append(ret, npr)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
var matches = Matches{
|
||||
{Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: []NetPortRange{
|
||||
NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}},
|
||||
NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}},
|
||||
}},
|
||||
{Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)},
|
||||
{Srcs: nets([]packet.IP4{0x02020202}), Dsts: ippr(0x08010101, 22, 22)},
|
||||
{Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
|
||||
{Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
|
||||
{Srcs: nets([]packet.IP4{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
|
||||
{Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")},
|
||||
{Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")},
|
||||
{Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")},
|
||||
{Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")},
|
||||
{Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")},
|
||||
{Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")},
|
||||
}
|
||||
|
||||
func newFilter(logf logger.Logf) *Filter {
|
||||
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
||||
// 102.102.102.102, 119.119.119.119, 8.1.0.0/16
|
||||
localNets := nets([]packet.IP4{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777})
|
||||
localNets = append(localNets, Net{packet.IP4(0x08010000), Netmask(16)})
|
||||
localNets := nets("100.122.98.50", "1.2.3.4", "5.6.7.8", "102.102.102.102", "119.119.119.119", "8.1.0.0/16")
|
||||
|
||||
return New(matches, localNets, nil, logf)
|
||||
}
|
||||
@ -160,18 +211,19 @@ func TestNoAllocs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseIP(t *testing.T) {
|
||||
var noaddr netaddr.IPPrefix
|
||||
tests := []struct {
|
||||
host string
|
||||
bits int
|
||||
want Net
|
||||
want netaddr.IPPrefix
|
||||
wantErr string
|
||||
}{
|
||||
{"8.8.8.8", 24, Net{IP: packet.NewIP4(net.ParseIP("8.8.8.8")), Mask: packet.NewIP4(net.ParseIP("255.255.255.0"))}, ""},
|
||||
{"8.8.8.8", 33, Net{}, `invalid CIDR size 33 for host "8.8.8.8"`},
|
||||
{"8.8.8.8", -1, Net{}, `invalid CIDR size -1 for host "8.8.8.8"`},
|
||||
{"0.0.0.0", 24, Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
|
||||
{"*", 24, NetAny, ""},
|
||||
{"fe80::1", 128, NetNone, `ports="fe80::1": invalid IPv4 address`},
|
||||
{"8.8.8.8", 24, pfx("8.8.8.8/24"), ""},
|
||||
{"8.8.8.8", 33, noaddr, `invalid CIDR size 33 for host "8.8.8.8"`},
|
||||
{"8.8.8.8", -1, noaddr, `invalid CIDR size -1 for host "8.8.8.8"`},
|
||||
{"0.0.0.0", 24, noaddr, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
|
||||
{"*", 24, pfx("0.0.0.0/0"), ""},
|
||||
{"fe80::1", 128, pfx("255.255.255.255/32"), `ports="fe80::1": invalid IPv4 address`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, err := parseIP(tt.host, tt.bits)
|
||||
@ -215,6 +267,7 @@ func BenchmarkFilter(b *testing.B) {
|
||||
|
||||
for _, bench := range benches {
|
||||
b.Run(bench.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
q := &packet.ParsedPacket{}
|
||||
q.Decode(bench.packet)
|
||||
|
@ -6,53 +6,17 @@
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/net/packet"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
func NewIP(ip net.IP) packet.IP4 {
|
||||
return packet.NewIP4(ip)
|
||||
}
|
||||
|
||||
type Net struct {
|
||||
IP packet.IP4
|
||||
Mask packet.IP4
|
||||
}
|
||||
|
||||
func (n Net) Includes(ip packet.IP4) bool {
|
||||
return (n.IP & n.Mask) == (ip & n.Mask)
|
||||
}
|
||||
|
||||
func (n Net) Bits() int {
|
||||
return 32 - bits.TrailingZeros32(uint32(n.Mask))
|
||||
}
|
||||
|
||||
func (n Net) String() string {
|
||||
b := n.Bits()
|
||||
if b == 32 {
|
||||
return n.IP.String()
|
||||
} else if b == 0 {
|
||||
return "*"
|
||||
} else {
|
||||
return fmt.Sprintf("%s/%d", n.IP, b)
|
||||
}
|
||||
}
|
||||
|
||||
var NetAny = Net{0, 0}
|
||||
var NetNone = Net{^packet.IP4(0), ^packet.IP4(0)}
|
||||
|
||||
func Netmask(bits int) packet.IP4 {
|
||||
b := ^uint32((1 << (32 - bits)) - 1)
|
||||
return packet.IP4(b)
|
||||
}
|
||||
|
||||
// PortRange is a range of TCP and UDP ports.
|
||||
type PortRange struct {
|
||||
First, Last uint16
|
||||
First, Last uint16 // inclusive
|
||||
}
|
||||
|
||||
// PortRangeAny represents all TCP and UDP ports.
|
||||
var PortRangeAny = PortRange{0, 65535}
|
||||
|
||||
func (pr PortRange) String() string {
|
||||
@ -65,28 +29,40 @@ func (pr PortRange) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
func (pr PortRange) contains(port uint16) bool {
|
||||
return port >= pr.First && port <= pr.Last
|
||||
}
|
||||
|
||||
// NetAny matches all IP addresses.
|
||||
// TODO: add ipv6.
|
||||
var NetAny = []netaddr.IPPrefix{{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}}
|
||||
|
||||
// NetPortRange combines an IP address prefix and PortRange.
|
||||
type NetPortRange struct {
|
||||
Net Net
|
||||
Net netaddr.IPPrefix
|
||||
Ports PortRange
|
||||
}
|
||||
|
||||
var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny}
|
||||
|
||||
func (ipr NetPortRange) String() string {
|
||||
return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports)
|
||||
func (npr NetPortRange) String() string {
|
||||
return fmt.Sprintf("%v:%v", npr.Net, npr.Ports)
|
||||
}
|
||||
|
||||
var NetPortRangeAny = []NetPortRange{{Net: NetAny[0], Ports: PortRangeAny}}
|
||||
|
||||
// Match matches packets from any IP address in Srcs to any ip:port in
|
||||
// Dsts.
|
||||
type Match struct {
|
||||
Dsts []NetPortRange
|
||||
Srcs []Net
|
||||
Srcs []netaddr.IPPrefix
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of m.
|
||||
func (m Match) Clone() (res Match) {
|
||||
if m.Dsts != nil {
|
||||
res.Dsts = append([]NetPortRange{}, m.Dsts...)
|
||||
}
|
||||
if m.Srcs != nil {
|
||||
res.Srcs = append([]Net{}, m.Srcs...)
|
||||
res.Srcs = append([]netaddr.IPPrefix{}, m.Srcs...)
|
||||
}
|
||||
return res
|
||||
}
|
||||
@ -115,57 +91,13 @@ func (m Match) String() string {
|
||||
return fmt.Sprintf("%v=>%v", ss, ds)
|
||||
}
|
||||
|
||||
// Matches is a list of packet matchers.
|
||||
type Matches []Match
|
||||
|
||||
func (m Matches) Clone() (res Matches) {
|
||||
for _, match := range m {
|
||||
// Clone returns a deep copy of ms.
|
||||
func (ms Matches) Clone() (res Matches) {
|
||||
for _, match := range ms {
|
||||
res = append(res, match.Clone())
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func ipInList(ip packet.IP4, netlist []Net) bool {
|
||||
for _, net := range netlist {
|
||||
if net.Includes(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchIPPorts(mm Matches, q *packet.ParsedPacket) bool {
|
||||
for _, acl := range mm {
|
||||
for _, dst := range acl.Dsts {
|
||||
if !dst.Net.Includes(q.DstIP) {
|
||||
continue
|
||||
}
|
||||
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
|
||||
continue
|
||||
}
|
||||
if !ipInList(q.SrcIP, acl.Srcs) {
|
||||
// Skip other dests in this acl, since
|
||||
// the src will never match.
|
||||
break
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchIPWithoutPorts(mm Matches, q *packet.ParsedPacket) bool {
|
||||
for _, acl := range mm {
|
||||
for _, dst := range acl.Dsts {
|
||||
if !dst.Net.Includes(q.DstIP) {
|
||||
continue
|
||||
}
|
||||
if !ipInList(q.SrcIP, acl.Srcs) {
|
||||
// Skip other dests in this acl, since
|
||||
// the src will never match.
|
||||
break
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
151
wgengine/filter/match4.go
Normal file
151
wgengine/filter/match4.go
Normal file
@ -0,0 +1,151 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"strings"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
)
|
||||
|
||||
type net4 struct {
|
||||
ip packet.IP4
|
||||
mask packet.IP4
|
||||
}
|
||||
|
||||
func net4FromIPPrefix(pfx netaddr.IPPrefix) net4 {
|
||||
if !pfx.IP.Is4() {
|
||||
panic("net4FromIPPrefix given non-ipv4 prefix")
|
||||
}
|
||||
return net4{
|
||||
ip: packet.IP4FromNetaddr(pfx.IP),
|
||||
mask: netmask4(pfx.Bits),
|
||||
}
|
||||
}
|
||||
|
||||
func nets4FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net4) {
|
||||
for _, pfx := range pfxs {
|
||||
if pfx.IP.Is4() {
|
||||
ret = append(ret, net4FromIPPrefix(pfx))
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (n net4) Contains(ip packet.IP4) bool {
|
||||
return (n.ip & n.mask) == (ip & n.mask)
|
||||
}
|
||||
|
||||
func (n net4) Bits() int {
|
||||
return 32 - bits.TrailingZeros32(uint32(n.mask))
|
||||
}
|
||||
|
||||
func (n net4) String() string {
|
||||
b := n.Bits()
|
||||
if b == 32 {
|
||||
return n.ip.String()
|
||||
} else if b == 0 {
|
||||
return "*"
|
||||
} else {
|
||||
return fmt.Sprintf("%s/%d", n.ip, b)
|
||||
}
|
||||
}
|
||||
|
||||
type npr4 struct {
|
||||
net net4
|
||||
ports PortRange
|
||||
}
|
||||
|
||||
func (npr npr4) String() string {
|
||||
return fmt.Sprintf("%s:%s", npr.net, npr.ports)
|
||||
}
|
||||
|
||||
type match4 struct {
|
||||
dsts []npr4
|
||||
srcs []net4
|
||||
}
|
||||
|
||||
type matches4 []match4
|
||||
|
||||
func (ms matches4) String() string {
|
||||
var b strings.Builder
|
||||
for _, m := range ms {
|
||||
fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func newMatches4(ms Matches) (ret matches4) {
|
||||
for _, m := range ms {
|
||||
var m4 match4
|
||||
for _, src := range m.Srcs {
|
||||
if src.IP.Is4() {
|
||||
m4.srcs = append(m4.srcs, net4FromIPPrefix(src))
|
||||
}
|
||||
}
|
||||
for _, dst := range m.Dsts {
|
||||
if dst.Net.IP.Is4() {
|
||||
m4.dsts = append(m4.dsts, npr4{net4FromIPPrefix(dst.Net), dst.Ports})
|
||||
}
|
||||
}
|
||||
if len(m4.srcs) > 0 && len(m4.dsts) > 0 {
|
||||
ret = append(ret, m4)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// match returns whether q's source IP and destination IP:port match
|
||||
// any of ms.
|
||||
func (ms matches4) match(q *packet.ParsedPacket) bool {
|
||||
for _, m := range ms {
|
||||
if !ip4InList(q.SrcIP, m.srcs) {
|
||||
continue
|
||||
}
|
||||
for _, dst := range m.dsts {
|
||||
if !dst.net.Contains(q.DstIP) {
|
||||
continue
|
||||
}
|
||||
if !dst.ports.contains(q.DstPort) {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchIPsOnly returns whether q's source and destination IP match
|
||||
// any of ms.
|
||||
func (ms matches4) matchIPsOnly(q *packet.ParsedPacket) bool {
|
||||
for _, m := range ms {
|
||||
if !ip4InList(q.SrcIP, m.srcs) {
|
||||
continue
|
||||
}
|
||||
for _, dst := range m.dsts {
|
||||
if dst.net.Contains(q.DstIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func netmask4(bits uint8) packet.IP4 {
|
||||
b := ^uint32((1 << (32 - bits)) - 1)
|
||||
return packet.IP4(b)
|
||||
}
|
||||
|
||||
func ip4InList(ip packet.IP4, netlist []net4) bool {
|
||||
for _, net := range netlist {
|
||||
if net.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
@ -158,7 +158,7 @@ func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, der
|
||||
|
||||
tun := tuntest.NewChannelTUN()
|
||||
tsTun := tstun.WrapTUN(logf, tun.TUN())
|
||||
tsTun.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf))
|
||||
tsTun.SetFilter(filter.NewAllowAll(filter.NetAny, logf))
|
||||
|
||||
dev := device.NewDevice(tsTun, &device.DeviceOptions{
|
||||
Logger: &device.Logger{
|
||||
|
@ -6,11 +6,15 @@
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tun/tuntest"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/filter"
|
||||
@ -29,35 +33,76 @@ func udp(src, dst packet.IP4, sport, dport uint16) []byte {
|
||||
return packet.Generate(header, []byte("udp_payload"))
|
||||
}
|
||||
|
||||
func filterNet(ip, mask packet.IP4) filter.Net {
|
||||
return filter.Net{IP: ip, Mask: mask}
|
||||
func nets(nets ...string) (ret []netaddr.IPPrefix) {
|
||||
for _, s := range nets {
|
||||
if i := strings.IndexByte(s, '/'); i == -1 {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bits := uint8(32)
|
||||
if ip.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
|
||||
} else {
|
||||
pfx, err := netaddr.ParseIPPrefix(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret = append(ret, pfx)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func nets(ips []packet.IP4) []filter.Net {
|
||||
out := make([]filter.Net, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
out = append(out, filterNet(ip, filter.Netmask(32)))
|
||||
func ports(s string) filter.PortRange {
|
||||
if s == "*" {
|
||||
return filter.PortRangeAny
|
||||
}
|
||||
return out
|
||||
|
||||
var fs, ls string
|
||||
i := strings.IndexByte(s, '-')
|
||||
if i == -1 {
|
||||
fs = s
|
||||
ls = fs
|
||||
} else {
|
||||
fs = s[:i]
|
||||
ls = s[i+1:]
|
||||
}
|
||||
first, err := strconv.ParseInt(fs, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
last, err := strconv.ParseInt(ls, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
return filter.PortRange{First: uint16(first), Last: uint16(last)}
|
||||
}
|
||||
|
||||
func ippr(ip packet.IP4, start, end uint16) []filter.NetPortRange {
|
||||
return []filter.NetPortRange{
|
||||
filter.NetPortRange{
|
||||
Net: filterNet(ip, filter.Netmask(32)),
|
||||
Ports: filter.PortRange{First: start, Last: end},
|
||||
},
|
||||
func netports(netPorts ...string) (ret []filter.NetPortRange) {
|
||||
for _, s := range netPorts {
|
||||
i := strings.LastIndexByte(s, ':')
|
||||
if i == -1 {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
|
||||
npr := filter.NetPortRange{
|
||||
Net: nets(s[:i])[0],
|
||||
Ports: ports(s[i+1:]),
|
||||
}
|
||||
ret = append(ret, npr)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func setfilter(logf logger.Logf, tun *TUN) {
|
||||
matches := filter.Matches{
|
||||
{Srcs: nets([]packet.IP4{0x05060708}), Dsts: ippr(0x01020304, 89, 90)},
|
||||
{Srcs: nets([]packet.IP4{0x01020304}), Dsts: ippr(0x05060708, 98, 98)},
|
||||
}
|
||||
localNets := []filter.Net{
|
||||
filterNet(packet.IP4(0x01020304), filter.Netmask(16)),
|
||||
{Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
|
||||
{Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
|
||||
}
|
||||
localNets := nets("1.2.0.0/16")
|
||||
tun.SetFilter(filter.New(matches, localNets, nil, logf))
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user