mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-26 03:25:35 +00:00
wgengine/filter: use netaddr types in public API.
We still use the packet.* alloc-free types in the data path, but the compilation from netaddr to packet happens within the filter package. Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
parent
7988f75b87
commit
b3634f020d
35
ipn/local.go
35
ipn/local.go
@ -546,7 +546,7 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap, prefs *Pre
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
localNets := wgCIDRsToFilter(netMap.Addresses, advRoutes)
|
localNets := wgCIDRsToNetaddr(netMap.Addresses, advRoutes)
|
||||||
|
|
||||||
if shieldsUp {
|
if shieldsUp {
|
||||||
b.logf("netmap packet filter: (shields up)")
|
b.logf("netmap packet filter: (shields up)")
|
||||||
@ -1266,14 +1266,14 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rs := &router.Config{
|
rs := &router.Config{
|
||||||
LocalAddrs: wgCIDRToNetaddr(addrs),
|
LocalAddrs: wgCIDRsToNetaddr(addrs),
|
||||||
SubnetRoutes: wgCIDRToNetaddr(prefs.AdvertiseRoutes),
|
SubnetRoutes: wgCIDRsToNetaddr(prefs.AdvertiseRoutes),
|
||||||
SNATSubnetRoutes: !prefs.NoSNAT,
|
SNATSubnetRoutes: !prefs.NoSNAT,
|
||||||
NetfilterMode: prefs.NetfilterMode,
|
NetfilterMode: prefs.NetfilterMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range cfg.Peers {
|
for _, peer := range cfg.Peers {
|
||||||
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
|
rs.Routes = append(rs.Routes, wgCIDRsToNetaddr(peer.AllowedIPs)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
|
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
|
||||||
@ -1284,35 +1284,20 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs) *router.Config {
|
|||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
// wgCIDRsToFilter converts lists of wgcfg.CIDR into a single list of
|
func wgCIDRsToNetaddr(cidrLists ...[]wgcfg.CIDR) (ret []netaddr.IPPrefix) {
|
||||||
// filter.Net.
|
|
||||||
func wgCIDRsToFilter(cidrLists ...[]wgcfg.CIDR) (ret []filter.Net) {
|
|
||||||
for _, cidrs := range cidrLists {
|
for _, cidrs := range cidrLists {
|
||||||
for _, cidr := range cidrs {
|
for _, cidr := range cidrs {
|
||||||
if !cidr.IP.Is4() {
|
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
|
||||||
continue
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
|
||||||
}
|
}
|
||||||
ret = append(ret, filter.Net{
|
ncidr.IP = ncidr.IP.Unmap()
|
||||||
IP: filter.NewIP(cidr.IP.IP()),
|
ret = append(ret, ncidr)
|
||||||
Mask: filter.Netmask(int(cidr.Mask)),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func wgCIDRToNetaddr(cidrs []wgcfg.CIDR) (ret []netaddr.IPPrefix) {
|
|
||||||
for _, cidr := range cidrs {
|
|
||||||
ncidr, ok := netaddr.FromStdIPNet(cidr.IPNet())
|
|
||||||
if !ok {
|
|
||||||
panic(fmt.Sprintf("conversion of %s from wgcfg to netaddr IPNet failed", cidr))
|
|
||||||
}
|
|
||||||
ncidr.IP = ncidr.IP.Unmap()
|
|
||||||
ret = append(ret, ncidr)
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyPrefsToHostinfo(hi *tailcfg.Hostinfo, prefs *Prefs) {
|
func applyPrefsToHostinfo(hi *tailcfg.Hostinfo, prefs *Prefs) {
|
||||||
if h := prefs.Hostname; h != "" {
|
if h := prefs.Hostname; h != "" {
|
||||||
hi.Hostname = h
|
hi.Hostname = h
|
||||||
|
@ -7,12 +7,12 @@
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/groupcache/lru"
|
"github.com/golang/groupcache/lru"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
|
"inet.af/netaddr"
|
||||||
"tailscale.com/net/packet"
|
"tailscale.com/net/packet"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
@ -26,16 +26,18 @@ type filterState struct {
|
|||||||
// Filter is a stateful packet filter.
|
// Filter is a stateful packet filter.
|
||||||
type Filter struct {
|
type Filter struct {
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
// localNets is the list of IP prefixes that we know to be "local"
|
// localNets is the list of IP prefixes that we know to be
|
||||||
// to this node. All packets coming in over tailscale must have a
|
// "local" to this node. All packets coming in over tailscale
|
||||||
// destination within localNets, regardless of the policy filter
|
// must have a destination within localNets, regardless of the
|
||||||
// below. A nil localNets rejects all incoming traffic.
|
// policy filter below. A nil localNets rejects all incoming
|
||||||
localNets []Net
|
// traffic.
|
||||||
// matches is a list of match->action rules applied to all packets
|
local4 []net4
|
||||||
// arriving over tailscale tunnels. Matches are checked in order,
|
// matches4 is a list of match->action rules applied to all
|
||||||
// and processing stops at the first matching rule. The default
|
// packets arriving over tailscale tunnels. Matches are
|
||||||
// policy if no rules match is to drop the packet.
|
// checked in order, and processing stops at the first
|
||||||
matches Matches
|
// matching rule. The default policy if no rules match is to
|
||||||
|
// drop the packet.
|
||||||
|
matches4 matches4
|
||||||
// state is the connection tracking state attached to this
|
// state is the connection tracking state attached to this
|
||||||
// filter. It is used to allow incoming traffic that is a response
|
// filter. It is used to allow incoming traffic that is a response
|
||||||
// to an outbound connection that this node made, even if those
|
// to an outbound connection that this node made, even if those
|
||||||
@ -87,12 +89,12 @@ type tuple struct {
|
|||||||
|
|
||||||
// MatchAllowAll matches all packets.
|
// MatchAllowAll matches all packets.
|
||||||
var MatchAllowAll = Matches{
|
var MatchAllowAll = Matches{
|
||||||
Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}},
|
Match{NetPortRangeAny, NetAny},
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAllowAll returns a packet filter that accepts everything to and
|
// NewAllowAll returns a packet filter that accepts everything to and
|
||||||
// from localNets.
|
// from localNets.
|
||||||
func NewAllowAll(localNets []Net, logf logger.Logf) *Filter {
|
func NewAllowAll(localNets []netaddr.IPPrefix, logf logger.Logf) *Filter {
|
||||||
return New(MatchAllowAll, localNets, nil, logf)
|
return New(MatchAllowAll, localNets, nil, logf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,7 +108,7 @@ func NewAllowNone(logf logger.Logf) *Filter {
|
|||||||
// by matches. If shareStateWith is non-nil, the returned filter
|
// by matches. If shareStateWith is non-nil, the returned filter
|
||||||
// shares state with the previous one, to enable rules to be changed
|
// shares state with the previous one, to enable rules to be changed
|
||||||
// at runtime without breaking existing flows.
|
// at runtime without breaking existing flows.
|
||||||
func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.Logf) *Filter {
|
func New(matches Matches, localNets []netaddr.IPPrefix, shareStateWith *Filter, logf logger.Logf) *Filter {
|
||||||
var state *filterState
|
var state *filterState
|
||||||
if shareStateWith != nil {
|
if shareStateWith != nil {
|
||||||
state = shareStateWith.state
|
state = shareStateWith.state
|
||||||
@ -116,10 +118,10 @@ func New(matches Matches, localNets []Net, shareStateWith *Filter, logf logger.L
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
f := &Filter{
|
f := &Filter{
|
||||||
logf: logf,
|
logf: logf,
|
||||||
matches: matches,
|
matches4: newMatches4(matches),
|
||||||
localNets: localNets,
|
local4: nets4FromIPPrefixes(localNets),
|
||||||
state: state,
|
state: state,
|
||||||
}
|
}
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
@ -179,29 +181,32 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) (Matches, error) {
|
|||||||
return mm, erracc
|
return mm, erracc
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIP(host string, defaultBits int) (Net, error) {
|
func parseIP(host string, defaultBits int) (netaddr.IPPrefix, error) {
|
||||||
ip := net.ParseIP(host)
|
if host == "*" {
|
||||||
if ip != nil && ip.IsUnspecified() {
|
// User explicitly requested wildcard dst ip.
|
||||||
// For clarity, reject 0.0.0.0 as an input
|
// TODO: ipv6
|
||||||
return NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
|
return netaddr.IPPrefix{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}, nil
|
||||||
} else if ip == nil && host == "*" {
|
|
||||||
// User explicitly requested wildcard dst ip
|
|
||||||
return NetAny, nil
|
|
||||||
} else {
|
|
||||||
if ip != nil {
|
|
||||||
ip = ip.To4()
|
|
||||||
}
|
|
||||||
if ip == nil || len(ip) != 4 {
|
|
||||||
return NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
|
|
||||||
}
|
|
||||||
if len(ip) == 4 && (defaultBits < 0 || defaultBits > 32) {
|
|
||||||
return NetNone, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
|
|
||||||
}
|
|
||||||
return Net{
|
|
||||||
IP: NewIP(ip),
|
|
||||||
Mask: Netmask(defaultBits),
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ip, err := netaddr.ParseIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IP address", host)
|
||||||
|
}
|
||||||
|
if ip == netaddr.IPv4(0, 0, 0, 0) {
|
||||||
|
// For clarity, reject 0.0.0.0 as an input
|
||||||
|
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
|
||||||
|
}
|
||||||
|
if !ip.Is4() {
|
||||||
|
// TODO: ipv6
|
||||||
|
return netaddr.IPPrefix{}, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
|
||||||
|
}
|
||||||
|
if defaultBits < 0 || defaultBits > 32 {
|
||||||
|
return netaddr.IPPrefix{}, fmt.Errorf("invalid CIDR size %d for host %q", defaultBits, host)
|
||||||
|
}
|
||||||
|
return netaddr.IPPrefix{
|
||||||
|
IP: ip,
|
||||||
|
Bits: uint8(defaultBits),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
|
// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
|
||||||
@ -266,7 +271,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
|||||||
// A compromised peer could try to send us packets for
|
// A compromised peer could try to send us packets for
|
||||||
// destinations we didn't explicitly advertise. This check is to
|
// destinations we didn't explicitly advertise. This check is to
|
||||||
// prevent that.
|
// prevent that.
|
||||||
if !ipInList(q.DstIP, f.localNets) {
|
if !ip4InList(q.DstIP, f.local4) {
|
||||||
return Drop, "destination not allowed"
|
return Drop, "destination not allowed"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,7 +289,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
|||||||
// related to an existing ICMP-Echo, TCP, or UDP
|
// related to an existing ICMP-Echo, TCP, or UDP
|
||||||
// session.
|
// session.
|
||||||
return Accept, "icmp response ok"
|
return Accept, "icmp response ok"
|
||||||
} else if matchIPWithoutPorts(f.matches, q) {
|
} else if f.matches4.matchIPsOnly(q) {
|
||||||
// If any port is open to an IP, allow ICMP to it.
|
// If any port is open to an IP, allow ICMP to it.
|
||||||
return Accept, "icmp ok"
|
return Accept, "icmp ok"
|
||||||
}
|
}
|
||||||
@ -300,7 +305,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
|||||||
if q.IPProto == packet.TCP && !q.IsTCPSyn() {
|
if q.IPProto == packet.TCP && !q.IsTCPSyn() {
|
||||||
return Accept, "tcp non-syn"
|
return Accept, "tcp non-syn"
|
||||||
}
|
}
|
||||||
if matchIPPorts(f.matches, q) {
|
if f.matches4.match(q) {
|
||||||
return Accept, "tcp ok"
|
return Accept, "tcp ok"
|
||||||
}
|
}
|
||||||
case packet.UDP:
|
case packet.UDP:
|
||||||
@ -313,7 +318,7 @@ func (f *Filter) runIn(q *packet.ParsedPacket) (r Response, why string) {
|
|||||||
if ok {
|
if ok {
|
||||||
return Accept, "udp cached"
|
return Accept, "udp cached"
|
||||||
}
|
}
|
||||||
if matchIPPorts(f.matches, q) {
|
if f.matches4.match(q) {
|
||||||
return Accept, "udp ok"
|
return Accept, "udp ok"
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -399,9 +404,9 @@ func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags, dir direction) Respons
|
|||||||
)
|
)
|
||||||
|
|
||||||
// omitDropLogging reports whether packet p, which has already been
|
// omitDropLogging reports whether packet p, which has already been
|
||||||
// deemded a packet to Drop, should bypass the [rate-limited] logging.
|
// deemed a packet to Drop, should bypass the [rate-limited] logging.
|
||||||
// We don't want to log scary & spammy reject warnings for packets that
|
// We don't want to log scary & spammy reject warnings for packets
|
||||||
// are totally normal, like IPv6 route announcements.
|
// that are totally normal, like IPv6 route announcements.
|
||||||
func omitDropLogging(p *packet.ParsedPacket, dir direction) bool {
|
func omitDropLogging(p *packet.ParsedPacket, dir direction) bool {
|
||||||
b := p.Buffer()
|
b := p.Buffer()
|
||||||
switch dir {
|
switch dir {
|
||||||
|
@ -8,10 +8,13 @@
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"inet.af/netaddr"
|
||||||
"tailscale.com/net/packet"
|
"tailscale.com/net/packet"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
)
|
)
|
||||||
@ -22,43 +25,91 @@
|
|||||||
var UDP = packet.UDP
|
var UDP = packet.UDP
|
||||||
var Fragment = packet.Fragment
|
var Fragment = packet.Fragment
|
||||||
|
|
||||||
func nets(ips []packet.IP4) []Net {
|
func pfx(s string) netaddr.IPPrefix {
|
||||||
out := make([]Net, 0, len(ips))
|
pfx, err := netaddr.ParseIPPrefix(s)
|
||||||
for _, ip := range ips {
|
if err != nil {
|
||||||
out = append(out, Net{ip, Netmask(32)})
|
panic(err)
|
||||||
}
|
}
|
||||||
return out
|
return pfx
|
||||||
}
|
}
|
||||||
|
|
||||||
func ippr(ip packet.IP4, start, end uint16) []NetPortRange {
|
func nets(nets ...string) (ret []netaddr.IPPrefix) {
|
||||||
return []NetPortRange{
|
for _, s := range nets {
|
||||||
NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}},
|
if i := strings.IndexByte(s, '/'); i == -1 {
|
||||||
|
ip, err := netaddr.ParseIP(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
bits := uint8(32)
|
||||||
|
if ip.Is6() {
|
||||||
|
bits = 128
|
||||||
|
}
|
||||||
|
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
|
||||||
|
} else {
|
||||||
|
pfx, err := netaddr.ParseIPPrefix(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ret = append(ret, pfx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func netpr(ip packet.IP4, bits int, start, end uint16) []NetPortRange {
|
func ports(s string) PortRange {
|
||||||
return []NetPortRange{
|
if s == "*" {
|
||||||
NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}},
|
return PortRangeAny
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var fs, ls string
|
||||||
|
i := strings.IndexByte(s, '-')
|
||||||
|
if i == -1 {
|
||||||
|
fs = s
|
||||||
|
ls = fs
|
||||||
|
} else {
|
||||||
|
fs = s[:i]
|
||||||
|
ls = s[i+1:]
|
||||||
|
}
|
||||||
|
first, err := strconv.ParseInt(fs, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||||
|
}
|
||||||
|
last, err := strconv.ParseInt(ls, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||||
|
}
|
||||||
|
return PortRange{uint16(first), uint16(last)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func netports(netPorts ...string) (ret []NetPortRange) {
|
||||||
|
for _, s := range netPorts {
|
||||||
|
i := strings.LastIndexByte(s, ':')
|
||||||
|
if i == -1 {
|
||||||
|
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||||
|
}
|
||||||
|
|
||||||
|
npr := NetPortRange{
|
||||||
|
Net: nets(s[:i])[0],
|
||||||
|
Ports: ports(s[i+1:]),
|
||||||
|
}
|
||||||
|
ret = append(ret, npr)
|
||||||
|
}
|
||||||
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
var matches = Matches{
|
var matches = Matches{
|
||||||
{Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: []NetPortRange{
|
{Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")},
|
||||||
NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}},
|
{Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")},
|
||||||
NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}},
|
{Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")},
|
||||||
}},
|
{Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")},
|
||||||
{Srcs: nets([]packet.IP4{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)},
|
{Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")},
|
||||||
{Srcs: nets([]packet.IP4{0x02020202}), Dsts: ippr(0x08010101, 22, 22)},
|
{Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")},
|
||||||
{Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
|
|
||||||
{Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
|
|
||||||
{Srcs: nets([]packet.IP4{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFilter(logf logger.Logf) *Filter {
|
func newFilter(logf logger.Logf) *Filter {
|
||||||
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
||||||
// 102.102.102.102, 119.119.119.119, 8.1.0.0/16
|
// 102.102.102.102, 119.119.119.119, 8.1.0.0/16
|
||||||
localNets := nets([]packet.IP4{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777})
|
localNets := nets("100.122.98.50", "1.2.3.4", "5.6.7.8", "102.102.102.102", "119.119.119.119", "8.1.0.0/16")
|
||||||
localNets = append(localNets, Net{packet.IP4(0x08010000), Netmask(16)})
|
|
||||||
|
|
||||||
return New(matches, localNets, nil, logf)
|
return New(matches, localNets, nil, logf)
|
||||||
}
|
}
|
||||||
@ -160,18 +211,19 @@ func TestNoAllocs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParseIP(t *testing.T) {
|
func TestParseIP(t *testing.T) {
|
||||||
|
var noaddr netaddr.IPPrefix
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
host string
|
host string
|
||||||
bits int
|
bits int
|
||||||
want Net
|
want netaddr.IPPrefix
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{"8.8.8.8", 24, Net{IP: packet.NewIP4(net.ParseIP("8.8.8.8")), Mask: packet.NewIP4(net.ParseIP("255.255.255.0"))}, ""},
|
{"8.8.8.8", 24, pfx("8.8.8.8/24"), ""},
|
||||||
{"8.8.8.8", 33, Net{}, `invalid CIDR size 33 for host "8.8.8.8"`},
|
{"8.8.8.8", 33, noaddr, `invalid CIDR size 33 for host "8.8.8.8"`},
|
||||||
{"8.8.8.8", -1, Net{}, `invalid CIDR size -1 for host "8.8.8.8"`},
|
{"8.8.8.8", -1, noaddr, `invalid CIDR size -1 for host "8.8.8.8"`},
|
||||||
{"0.0.0.0", 24, Net{}, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
|
{"0.0.0.0", 24, noaddr, `ports="0.0.0.0": to allow all IP addresses, use *:port, not 0.0.0.0:port`},
|
||||||
{"*", 24, NetAny, ""},
|
{"*", 24, pfx("0.0.0.0/0"), ""},
|
||||||
{"fe80::1", 128, NetNone, `ports="fe80::1": invalid IPv4 address`},
|
{"fe80::1", 128, pfx("255.255.255.255/32"), `ports="fe80::1": invalid IPv4 address`},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
got, err := parseIP(tt.host, tt.bits)
|
got, err := parseIP(tt.host, tt.bits)
|
||||||
@ -215,6 +267,7 @@ func BenchmarkFilter(b *testing.B) {
|
|||||||
|
|
||||||
for _, bench := range benches {
|
for _, bench := range benches {
|
||||||
b.Run(bench.name, func(b *testing.B) {
|
b.Run(bench.name, func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
q := &packet.ParsedPacket{}
|
q := &packet.ParsedPacket{}
|
||||||
q.Decode(bench.packet)
|
q.Decode(bench.packet)
|
||||||
|
@ -6,53 +6,17 @@
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/bits"
|
|
||||||
"net"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"tailscale.com/net/packet"
|
"inet.af/netaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewIP(ip net.IP) packet.IP4 {
|
// PortRange is a range of TCP and UDP ports.
|
||||||
return packet.NewIP4(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Net struct {
|
|
||||||
IP packet.IP4
|
|
||||||
Mask packet.IP4
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n Net) Includes(ip packet.IP4) bool {
|
|
||||||
return (n.IP & n.Mask) == (ip & n.Mask)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n Net) Bits() int {
|
|
||||||
return 32 - bits.TrailingZeros32(uint32(n.Mask))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n Net) String() string {
|
|
||||||
b := n.Bits()
|
|
||||||
if b == 32 {
|
|
||||||
return n.IP.String()
|
|
||||||
} else if b == 0 {
|
|
||||||
return "*"
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%s/%d", n.IP, b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var NetAny = Net{0, 0}
|
|
||||||
var NetNone = Net{^packet.IP4(0), ^packet.IP4(0)}
|
|
||||||
|
|
||||||
func Netmask(bits int) packet.IP4 {
|
|
||||||
b := ^uint32((1 << (32 - bits)) - 1)
|
|
||||||
return packet.IP4(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
type PortRange struct {
|
type PortRange struct {
|
||||||
First, Last uint16
|
First, Last uint16 // inclusive
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PortRangeAny represents all TCP and UDP ports.
|
||||||
var PortRangeAny = PortRange{0, 65535}
|
var PortRangeAny = PortRange{0, 65535}
|
||||||
|
|
||||||
func (pr PortRange) String() string {
|
func (pr PortRange) String() string {
|
||||||
@ -65,28 +29,40 @@ func (pr PortRange) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pr PortRange) contains(port uint16) bool {
|
||||||
|
return port >= pr.First && port <= pr.Last
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetAny matches all IP addresses.
|
||||||
|
// TODO: add ipv6.
|
||||||
|
var NetAny = []netaddr.IPPrefix{{IP: netaddr.IPv4(0, 0, 0, 0), Bits: 0}}
|
||||||
|
|
||||||
|
// NetPortRange combines an IP address prefix and PortRange.
|
||||||
type NetPortRange struct {
|
type NetPortRange struct {
|
||||||
Net Net
|
Net netaddr.IPPrefix
|
||||||
Ports PortRange
|
Ports PortRange
|
||||||
}
|
}
|
||||||
|
|
||||||
var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny}
|
func (npr NetPortRange) String() string {
|
||||||
|
return fmt.Sprintf("%v:%v", npr.Net, npr.Ports)
|
||||||
func (ipr NetPortRange) String() string {
|
|
||||||
return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var NetPortRangeAny = []NetPortRange{{Net: NetAny[0], Ports: PortRangeAny}}
|
||||||
|
|
||||||
|
// Match matches packets from any IP address in Srcs to any ip:port in
|
||||||
|
// Dsts.
|
||||||
type Match struct {
|
type Match struct {
|
||||||
Dsts []NetPortRange
|
Dsts []NetPortRange
|
||||||
Srcs []Net
|
Srcs []netaddr.IPPrefix
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clone returns a deep copy of m.
|
||||||
func (m Match) Clone() (res Match) {
|
func (m Match) Clone() (res Match) {
|
||||||
if m.Dsts != nil {
|
if m.Dsts != nil {
|
||||||
res.Dsts = append([]NetPortRange{}, m.Dsts...)
|
res.Dsts = append([]NetPortRange{}, m.Dsts...)
|
||||||
}
|
}
|
||||||
if m.Srcs != nil {
|
if m.Srcs != nil {
|
||||||
res.Srcs = append([]Net{}, m.Srcs...)
|
res.Srcs = append([]netaddr.IPPrefix{}, m.Srcs...)
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
@ -115,57 +91,13 @@ func (m Match) String() string {
|
|||||||
return fmt.Sprintf("%v=>%v", ss, ds)
|
return fmt.Sprintf("%v=>%v", ss, ds)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Matches is a list of packet matchers.
|
||||||
type Matches []Match
|
type Matches []Match
|
||||||
|
|
||||||
func (m Matches) Clone() (res Matches) {
|
// Clone returns a deep copy of ms.
|
||||||
for _, match := range m {
|
func (ms Matches) Clone() (res Matches) {
|
||||||
|
for _, match := range ms {
|
||||||
res = append(res, match.Clone())
|
res = append(res, match.Clone())
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipInList(ip packet.IP4, netlist []Net) bool {
|
|
||||||
for _, net := range netlist {
|
|
||||||
if net.Includes(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func matchIPPorts(mm Matches, q *packet.ParsedPacket) bool {
|
|
||||||
for _, acl := range mm {
|
|
||||||
for _, dst := range acl.Dsts {
|
|
||||||
if !dst.Net.Includes(q.DstIP) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !ipInList(q.SrcIP, acl.Srcs) {
|
|
||||||
// Skip other dests in this acl, since
|
|
||||||
// the src will never match.
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func matchIPWithoutPorts(mm Matches, q *packet.ParsedPacket) bool {
|
|
||||||
for _, acl := range mm {
|
|
||||||
for _, dst := range acl.Dsts {
|
|
||||||
if !dst.Net.Includes(q.DstIP) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !ipInList(q.SrcIP, acl.Srcs) {
|
|
||||||
// Skip other dests in this acl, since
|
|
||||||
// the src will never match.
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
151
wgengine/filter/match4.go
Normal file
151
wgengine/filter/match4.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package filter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/bits"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"inet.af/netaddr"
|
||||||
|
"tailscale.com/net/packet"
|
||||||
|
)
|
||||||
|
|
||||||
|
type net4 struct {
|
||||||
|
ip packet.IP4
|
||||||
|
mask packet.IP4
|
||||||
|
}
|
||||||
|
|
||||||
|
func net4FromIPPrefix(pfx netaddr.IPPrefix) net4 {
|
||||||
|
if !pfx.IP.Is4() {
|
||||||
|
panic("net4FromIPPrefix given non-ipv4 prefix")
|
||||||
|
}
|
||||||
|
return net4{
|
||||||
|
ip: packet.IP4FromNetaddr(pfx.IP),
|
||||||
|
mask: netmask4(pfx.Bits),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func nets4FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net4) {
|
||||||
|
for _, pfx := range pfxs {
|
||||||
|
if pfx.IP.Is4() {
|
||||||
|
ret = append(ret, net4FromIPPrefix(pfx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n net4) Contains(ip packet.IP4) bool {
|
||||||
|
return (n.ip & n.mask) == (ip & n.mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n net4) Bits() int {
|
||||||
|
return 32 - bits.TrailingZeros32(uint32(n.mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n net4) String() string {
|
||||||
|
b := n.Bits()
|
||||||
|
if b == 32 {
|
||||||
|
return n.ip.String()
|
||||||
|
} else if b == 0 {
|
||||||
|
return "*"
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("%s/%d", n.ip, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type npr4 struct {
|
||||||
|
net net4
|
||||||
|
ports PortRange
|
||||||
|
}
|
||||||
|
|
||||||
|
func (npr npr4) String() string {
|
||||||
|
return fmt.Sprintf("%s:%s", npr.net, npr.ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
type match4 struct {
|
||||||
|
dsts []npr4
|
||||||
|
srcs []net4
|
||||||
|
}
|
||||||
|
|
||||||
|
type matches4 []match4
|
||||||
|
|
||||||
|
func (ms matches4) String() string {
|
||||||
|
var b strings.Builder
|
||||||
|
for _, m := range ms {
|
||||||
|
fmt.Fprintf(&b, "%s => %s\n", m.srcs, m.dsts)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMatches4(ms Matches) (ret matches4) {
|
||||||
|
for _, m := range ms {
|
||||||
|
var m4 match4
|
||||||
|
for _, src := range m.Srcs {
|
||||||
|
if src.IP.Is4() {
|
||||||
|
m4.srcs = append(m4.srcs, net4FromIPPrefix(src))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, dst := range m.Dsts {
|
||||||
|
if dst.Net.IP.Is4() {
|
||||||
|
m4.dsts = append(m4.dsts, npr4{net4FromIPPrefix(dst.Net), dst.Ports})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(m4.srcs) > 0 && len(m4.dsts) > 0 {
|
||||||
|
ret = append(ret, m4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// match returns whether q's source IP and destination IP:port match
|
||||||
|
// any of ms.
|
||||||
|
func (ms matches4) match(q *packet.ParsedPacket) bool {
|
||||||
|
for _, m := range ms {
|
||||||
|
if !ip4InList(q.SrcIP, m.srcs) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, dst := range m.dsts {
|
||||||
|
if !dst.net.Contains(q.DstIP) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !dst.ports.contains(q.DstPort) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchIPsOnly returns whether q's source and destination IP match
|
||||||
|
// any of ms.
|
||||||
|
func (ms matches4) matchIPsOnly(q *packet.ParsedPacket) bool {
|
||||||
|
for _, m := range ms {
|
||||||
|
if !ip4InList(q.SrcIP, m.srcs) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, dst := range m.dsts {
|
||||||
|
if dst.net.Contains(q.DstIP) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func netmask4(bits uint8) packet.IP4 {
|
||||||
|
b := ^uint32((1 << (32 - bits)) - 1)
|
||||||
|
return packet.IP4(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ip4InList(ip packet.IP4, netlist []net4) bool {
|
||||||
|
for _, net := range netlist {
|
||||||
|
if net.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
@ -158,7 +158,7 @@ func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, der
|
|||||||
|
|
||||||
tun := tuntest.NewChannelTUN()
|
tun := tuntest.NewChannelTUN()
|
||||||
tsTun := tstun.WrapTUN(logf, tun.TUN())
|
tsTun := tstun.WrapTUN(logf, tun.TUN())
|
||||||
tsTun.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf))
|
tsTun.SetFilter(filter.NewAllowAll(filter.NetAny, logf))
|
||||||
|
|
||||||
dev := device.NewDevice(tsTun, &device.DeviceOptions{
|
dev := device.NewDevice(tsTun, &device.DeviceOptions{
|
||||||
Logger: &device.Logger{
|
Logger: &device.Logger{
|
||||||
|
@ -6,11 +6,15 @@
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/tailscale/wireguard-go/tun/tuntest"
|
"github.com/tailscale/wireguard-go/tun/tuntest"
|
||||||
|
"inet.af/netaddr"
|
||||||
"tailscale.com/net/packet"
|
"tailscale.com/net/packet"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
"tailscale.com/wgengine/filter"
|
"tailscale.com/wgengine/filter"
|
||||||
@ -29,35 +33,76 @@ func udp(src, dst packet.IP4, sport, dport uint16) []byte {
|
|||||||
return packet.Generate(header, []byte("udp_payload"))
|
return packet.Generate(header, []byte("udp_payload"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterNet(ip, mask packet.IP4) filter.Net {
|
func nets(nets ...string) (ret []netaddr.IPPrefix) {
|
||||||
return filter.Net{IP: ip, Mask: mask}
|
for _, s := range nets {
|
||||||
|
if i := strings.IndexByte(s, '/'); i == -1 {
|
||||||
|
ip, err := netaddr.ParseIP(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
bits := uint8(32)
|
||||||
|
if ip.Is6() {
|
||||||
|
bits = 128
|
||||||
|
}
|
||||||
|
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
|
||||||
|
} else {
|
||||||
|
pfx, err := netaddr.ParseIPPrefix(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ret = append(ret, pfx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func nets(ips []packet.IP4) []filter.Net {
|
func ports(s string) filter.PortRange {
|
||||||
out := make([]filter.Net, 0, len(ips))
|
if s == "*" {
|
||||||
for _, ip := range ips {
|
return filter.PortRangeAny
|
||||||
out = append(out, filterNet(ip, filter.Netmask(32)))
|
|
||||||
}
|
}
|
||||||
return out
|
|
||||||
|
var fs, ls string
|
||||||
|
i := strings.IndexByte(s, '-')
|
||||||
|
if i == -1 {
|
||||||
|
fs = s
|
||||||
|
ls = fs
|
||||||
|
} else {
|
||||||
|
fs = s[:i]
|
||||||
|
ls = s[i+1:]
|
||||||
|
}
|
||||||
|
first, err := strconv.ParseInt(fs, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||||
|
}
|
||||||
|
last, err := strconv.ParseInt(ls, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||||
|
}
|
||||||
|
return filter.PortRange{First: uint16(first), Last: uint16(last)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ippr(ip packet.IP4, start, end uint16) []filter.NetPortRange {
|
func netports(netPorts ...string) (ret []filter.NetPortRange) {
|
||||||
return []filter.NetPortRange{
|
for _, s := range netPorts {
|
||||||
filter.NetPortRange{
|
i := strings.LastIndexByte(s, ':')
|
||||||
Net: filterNet(ip, filter.Netmask(32)),
|
if i == -1 {
|
||||||
Ports: filter.PortRange{First: start, Last: end},
|
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||||
},
|
}
|
||||||
|
|
||||||
|
npr := filter.NetPortRange{
|
||||||
|
Net: nets(s[:i])[0],
|
||||||
|
Ports: ports(s[i+1:]),
|
||||||
|
}
|
||||||
|
ret = append(ret, npr)
|
||||||
}
|
}
|
||||||
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func setfilter(logf logger.Logf, tun *TUN) {
|
func setfilter(logf logger.Logf, tun *TUN) {
|
||||||
matches := filter.Matches{
|
matches := filter.Matches{
|
||||||
{Srcs: nets([]packet.IP4{0x05060708}), Dsts: ippr(0x01020304, 89, 90)},
|
{Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
|
||||||
{Srcs: nets([]packet.IP4{0x01020304}), Dsts: ippr(0x05060708, 98, 98)},
|
{Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
|
||||||
}
|
|
||||||
localNets := []filter.Net{
|
|
||||||
filterNet(packet.IP4(0x01020304), filter.Netmask(16)),
|
|
||||||
}
|
}
|
||||||
|
localNets := nets("1.2.0.0/16")
|
||||||
tun.SetFilter(filter.New(matches, localNets, nil, logf))
|
tun.SetFilter(filter.New(matches, localNets, nil, logf))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user