wgengine/filter: use netaddr types in public API.

We still use the packet.* alloc-free types in the data path, but
the compilation from netaddr to packet happens within the filter
package.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2020-11-09 20:12:21 -08:00
committed by Dave Anderson
parent 7988f75b87
commit b3634f020d
7 changed files with 386 additions and 215 deletions

View File

@@ -6,11 +6,15 @@ package tstun
import (
"bytes"
"fmt"
"strconv"
"strings"
"sync/atomic"
"testing"
"unsafe"
"github.com/tailscale/wireguard-go/tun/tuntest"
"inet.af/netaddr"
"tailscale.com/net/packet"
"tailscale.com/types/logger"
"tailscale.com/wgengine/filter"
@@ -29,35 +33,76 @@ func udp(src, dst packet.IP4, sport, dport uint16) []byte {
return packet.Generate(header, []byte("udp_payload"))
}
func filterNet(ip, mask packet.IP4) filter.Net {
return filter.Net{IP: ip, Mask: mask}
func nets(nets ...string) (ret []netaddr.IPPrefix) {
for _, s := range nets {
if i := strings.IndexByte(s, '/'); i == -1 {
ip, err := netaddr.ParseIP(s)
if err != nil {
panic(err)
}
bits := uint8(32)
if ip.Is6() {
bits = 128
}
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
} else {
pfx, err := netaddr.ParseIPPrefix(s)
if err != nil {
panic(err)
}
ret = append(ret, pfx)
}
}
return ret
}
func nets(ips []packet.IP4) []filter.Net {
out := make([]filter.Net, 0, len(ips))
for _, ip := range ips {
out = append(out, filterNet(ip, filter.Netmask(32)))
func ports(s string) filter.PortRange {
if s == "*" {
return filter.PortRangeAny
}
return out
var fs, ls string
i := strings.IndexByte(s, '-')
if i == -1 {
fs = s
ls = fs
} else {
fs = s[:i]
ls = s[i+1:]
}
first, err := strconv.ParseInt(fs, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
last, err := strconv.ParseInt(ls, 10, 16)
if err != nil {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
return filter.PortRange{First: uint16(first), Last: uint16(last)}
}
func ippr(ip packet.IP4, start, end uint16) []filter.NetPortRange {
return []filter.NetPortRange{
filter.NetPortRange{
Net: filterNet(ip, filter.Netmask(32)),
Ports: filter.PortRange{First: start, Last: end},
},
func netports(netPorts ...string) (ret []filter.NetPortRange) {
for _, s := range netPorts {
i := strings.LastIndexByte(s, ':')
if i == -1 {
panic(fmt.Sprintf("invalid NetPortRange %q", s))
}
npr := filter.NetPortRange{
Net: nets(s[:i])[0],
Ports: ports(s[i+1:]),
}
ret = append(ret, npr)
}
return ret
}
func setfilter(logf logger.Logf, tun *TUN) {
matches := filter.Matches{
{Srcs: nets([]packet.IP4{0x05060708}), Dsts: ippr(0x01020304, 89, 90)},
{Srcs: nets([]packet.IP4{0x01020304}), Dsts: ippr(0x05060708, 98, 98)},
}
localNets := []filter.Net{
filterNet(packet.IP4(0x01020304), filter.Netmask(16)),
{Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
{Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
}
localNets := nets("1.2.0.0/16")
tun.SetFilter(filter.New(matches, localNets, nil, logf))
}