wgengine/filter: support subnet mask rules, not just /32 IPs.

This depends on improved support from the control server, to send the
new subnet width (Bits) fields. If these are missing, we fall back to
assuming their value is /32.

Conversely, if the server sends Bits fields to an older client, it will
interpret them as /32 addresses. Since the only rules we allow are
"accept" rules, this will be narrower or equal to the intended rule, so
older clients will simply reject hosts on the wider subnet (fail
closed).

With this change, the internal filter.Matches format has diverged
from the wire format used by controlclient, so move the wire format
into tailcfg and convert it to filter.Matches in controlclient.

Signed-off-by: Avery Pennarun <apenwarr@tailscale.com>
This commit is contained in:
Avery Pennarun 2020-04-30 01:49:17 -04:00
parent d6c34368e8
commit 65fbb9c303
7 changed files with 202 additions and 77 deletions

View File

@ -593,7 +593,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
DNS: resp.DNS, DNS: resp.DNS,
DNSDomains: resp.SearchPaths, DNSDomains: resp.SearchPaths,
Hostinfo: resp.Node.Hostinfo, Hostinfo: resp.Node.Hostinfo,
PacketFilter: resp.PacketFilter, PacketFilter: c.parsePacketFilter(resp.PacketFilter),
} }
for _, profile := range resp.UserProfiles { for _, profile := range resp.UserProfiles {
nm.UserProfiles[profile.ID] = profile nm.UserProfiles[profile.ID] = profile

View File

@ -0,0 +1,80 @@
package controlclient
import (
"fmt"
"net"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/filter"
)
func parseIP(host string, defaultBits int) (filter.Net, error) {
ip := net.ParseIP(host)
if ip != nil && ip.IsUnspecified() {
// For clarity, reject 0.0.0.0 as an input
return filter.NetNone, fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
} else if ip == nil && host == "*" {
// User explicitly requested wildcard dst ip
return filter.NetAny, nil
} else {
if ip != nil {
ip = ip.To4()
}
if ip == nil || len(ip) != 4 {
return filter.NetNone, fmt.Errorf("ports=%#v: invalid IPv4 address", host)
}
return filter.Net{
IP: filter.NewIP(ip),
Mask: filter.Netmask(defaultBits),
}, nil
}
}
// Parse a backward-compatible FilterRule used by control's wire format,
// producing the most current filter.Matches format.
func (c *Direct) parsePacketFilter(pf []tailcfg.FilterRule) filter.Matches {
mm := make([]filter.Match, 0, len(pf))
var erracc error
for _, r := range pf {
m := filter.Match{}
for i, s := range r.SrcIPs {
bits := 32
if len(r.SrcBits) > i {
bits = r.SrcBits[i]
}
net, err := parseIP(s, bits)
if err != nil && erracc == nil {
erracc = err
continue
}
m.Srcs = append(m.Srcs, net)
}
for _, d := range r.DstPorts {
bits := 32
if d.Bits != nil {
bits = *d.Bits
}
net, err := parseIP(d.IP, bits)
if err != nil && erracc == nil {
erracc = err
continue
}
m.Dsts = append(m.Dsts, filter.NetPortRange{
Net: net,
Ports: filter.PortRange{
First: d.Ports.First,
Last: d.Ports.Last,
},
})
}
mm = append(mm, m)
}
if erracc != nil {
c.logf("parsePacketFilter: %s\n", erracc)
}
return mm
}

View File

@ -15,7 +15,6 @@
"github.com/tailscale/wireguard-go/wgcfg" "github.com/tailscale/wireguard-go/wgcfg"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"tailscale.com/types/opt" "tailscale.com/types/opt"
"tailscale.com/wgengine/filter"
) )
type ID int64 type ID int64
@ -404,6 +403,40 @@ type MapRequest struct {
Hostinfo *Hostinfo Hostinfo *Hostinfo
} }
// PortRange represents a range of UDP or TCP port numbers.
type PortRange struct {
First uint16
Last uint16
}
var PortRangeAny = PortRange{0, 65535}
// NetPortRange represents a single subnet:portrange.
type NetPortRange struct {
IP string
Bits *int // backward compatibility: if missing, means "all" bits
Ports PortRange
}
// FilterRule represents one rule in a packet filter.
type FilterRule struct {
SrcIPs []string
SrcBits []int
DstPorts []NetPortRange
}
var FilterAllowAll = []FilterRule{
FilterRule{
SrcIPs: []string{"*"},
SrcBits: nil,
DstPorts: []NetPortRange{NetPortRange{
IP: "*",
Bits: nil,
Ports: PortRange{0, 65535},
}},
},
}
type MapResponse struct { type MapResponse struct {
KeepAlive bool // if set, all other fields are ignored KeepAlive bool // if set, all other fields are ignored
@ -415,7 +448,7 @@ type MapResponse struct {
// ACLs // ACLs
Domain string Domain string
PacketFilter filter.Matches PacketFilter []FilterRule
UserProfiles []UserProfile UserProfiles []UserProfile
Roles []Role Roles []Role
// TODO: Groups []Group // TODO: Groups []Group

View File

@ -71,7 +71,7 @@ type tuple struct {
// MatchAllowAll matches all packets. // MatchAllowAll matches all packets.
var MatchAllowAll = Matches{ var MatchAllowAll = Matches{
Match{[]IPPortRange{IPPortRangeAny}, []IP{IPAny}}, Match{[]NetPortRange{NetPortRangeAny}, []Net{NetAny}},
} }
// NewAllowAll returns a packet filter that accepts everything. // NewAllowAll returns a packet filter that accepts everything.

View File

@ -21,23 +21,37 @@
var UDP = packet.UDP var UDP = packet.UDP
var Fragment = packet.Fragment var Fragment = packet.Fragment
func ippr(ip IP, start, end uint16) []IPPortRange { func nets(ips []IP) []Net {
return []IPPortRange{ out := make([]Net, 0, len(ips))
IPPortRange{ip, PortRange{start, end}}, for _, ip := range ips {
out = append(out, Net{ip, Netmask(32)})
}
return out
}
func ippr(ip IP, start, end uint16) []NetPortRange {
return []NetPortRange{
NetPortRange{Net{ip, Netmask(32)}, PortRange{start, end}},
}
}
func netpr(ip IP, bits int, start, end uint16) []NetPortRange {
return []NetPortRange{
NetPortRange{Net{ip, Netmask(bits)}, PortRange{start, end}},
} }
} }
func TestFilter(t *testing.T) { func TestFilter(t *testing.T) {
mm := Matches{ mm := Matches{
{SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: []IPPortRange{ {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{
IPPortRange{0x01020304, PortRange{22, 22}}, NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}},
IPPortRange{0x05060708, PortRange{23, 24}}, NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}},
}}, }},
{SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: ippr(0x05060708, 27, 28)}, {Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)},
{SrcIPs: []IP{0x02020202}, DstPorts: ippr(0x08010101, 22, 22)}, {Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)},
{SrcIPs: []IP{0}, DstPorts: ippr(0x647a6232, 0, 65535)}, {Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
{SrcIPs: []IP{0}, DstPorts: ippr(0, 443, 443)}, {Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
{SrcIPs: []IP{0x99010101, 0x99010102, 0x99030303}, DstPorts: ippr(0x01020304, 999, 999)}, {Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
} }
acl := New(mm, nil) acl := New(mm, nil)

View File

@ -6,6 +6,8 @@
import ( import (
"fmt" "fmt"
"math/bits"
"net"
"strings" "strings"
"tailscale.com/wgengine/packet" "tailscale.com/wgengine/packet"
@ -13,9 +15,42 @@
type IP = packet.IP type IP = packet.IP
const IPAny = IP(0) func NewIP(ip net.IP) IP {
return packet.NewIP(ip)
}
var NewIP = packet.NewIP type Net struct {
IP IP
Mask IP
}
func (n Net) Includes(ip IP) bool {
return (n.IP & n.Mask) == (ip & n.Mask)
}
func (n Net) Bits() int {
return 32 - bits.TrailingZeros32(uint32(n.Mask))
}
func (n Net) String() string {
b := n.Bits()
if b == 32 {
return n.IP.String()
} else if b == 0 {
return "*"
} else {
return fmt.Sprintf("%s/%d", n.IP, b)
}
}
var NetAny = Net{0, 0}
var NetNone = Net{^IP(0), ^IP(0)}
func Netmask(bits int) IP {
var b uint32
b = ^uint32((1 << (32 - bits)) - 1)
return IP(b)
}
type PortRange struct { type PortRange struct {
First, Last uint16 First, Last uint16
@ -33,39 +68,39 @@ func (pr PortRange) String() string {
} }
} }
type IPPortRange struct { type NetPortRange struct {
IP IP Net Net
Ports PortRange Ports PortRange
} }
var IPPortRangeAny = IPPortRange{IPAny, PortRangeAny} var NetPortRangeAny = NetPortRange{NetAny, PortRangeAny}
func (ipr IPPortRange) String() string { func (ipr NetPortRange) String() string {
return fmt.Sprintf("%v:%v", ipr.IP, ipr.Ports) return fmt.Sprintf("%v:%v", ipr.Net, ipr.Ports)
} }
type Match struct { type Match struct {
DstPorts []IPPortRange Dsts []NetPortRange
SrcIPs []IP Srcs []Net
} }
func (m Match) Clone() (res Match) { func (m Match) Clone() (res Match) {
if m.DstPorts != nil { if m.Dsts != nil {
res.DstPorts = append([]IPPortRange{}, m.DstPorts...) res.Dsts = append([]NetPortRange{}, m.Dsts...)
} }
if m.SrcIPs != nil { if m.Srcs != nil {
res.SrcIPs = append([]IP{}, m.SrcIPs...) res.Srcs = append([]Net{}, m.Srcs...)
} }
return res return res
} }
func (m Match) String() string { func (m Match) String() string {
srcs := []string{} srcs := []string{}
for _, srcip := range m.SrcIPs { for _, src := range m.Srcs {
srcs = append(srcs, srcip.String()) srcs = append(srcs, src.String())
} }
dsts := []string{} dsts := []string{}
for _, dst := range m.DstPorts { for _, dst := range m.Dsts {
dsts = append(dsts, dst.String()) dsts = append(dsts, dst.String())
} }
@ -92,9 +127,9 @@ func (m Matches) Clone() (res Matches) {
return res return res
} }
func ipInList(ip IP, iplist []IP) bool { func ipInList(ip IP, netlist []Net) bool {
for _, ipp := range iplist { for _, net := range netlist {
if ipp == IPAny || ipp == ip { if net.Includes(ip) {
return true return true
} }
} }
@ -103,14 +138,14 @@ func ipInList(ip IP, iplist []IP) bool {
func matchIPPorts(mm Matches, q *packet.QDecode) bool { func matchIPPorts(mm Matches, q *packet.QDecode) bool {
for _, acl := range mm { for _, acl := range mm {
for _, dst := range acl.DstPorts { for _, dst := range acl.Dsts {
if dst.IP != IPAny && dst.IP != q.DstIP { if !dst.Net.Includes(q.DstIP) {
continue continue
} }
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last { if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
continue continue
} }
if !ipInList(q.SrcIP, acl.SrcIPs) { if !ipInList(q.SrcIP, acl.Srcs) {
// Skip other dests in this acl, since // Skip other dests in this acl, since
// the src will never match. // the src will never match.
break break
@ -123,11 +158,11 @@ func matchIPPorts(mm Matches, q *packet.QDecode) bool {
func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool { func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool {
for _, acl := range mm { for _, acl := range mm {
for _, dst := range acl.DstPorts { for _, dst := range acl.Dsts {
if dst.IP != IPAny && dst.IP != q.DstIP { if !dst.Net.Includes(q.DstIP) {
continue continue
} }
if !ipInList(q.SrcIP, acl.SrcIPs) { if !ipInList(q.SrcIP, acl.Srcs) {
// Skip other dests in this acl, since // Skip other dests in this acl, since
// the src will never match. // the src will never match.
break break

View File

@ -6,7 +6,6 @@
import ( import (
"encoding/binary" "encoding/binary"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net" "net"
@ -43,8 +42,6 @@ func (p IPProto) String() string {
type IP uint32 type IP uint32
const IPAny = IP(0)
func NewIP(b net.IP) IP { func NewIP(b net.IP) IP {
b4 := b.To4() b4 := b.To4()
if b4 == nil { if b4 == nil {
@ -54,45 +51,11 @@ func NewIP(b net.IP) IP {
} }
func (ip IP) String() string { func (ip IP) String() string {
if ip == 0 {
return "*"
}
b := make([]byte, 4) b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(ip)) binary.BigEndian.PutUint32(b, uint32(ip))
return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3]) return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3])
} }
func (ipp *IP) MarshalJSON() ([]byte, error) {
s := "\"" + (*ipp).String() + "\""
return []byte(s), nil
}
func (ipp *IP) UnmarshalJSON(b []byte) error {
var hostp *string
err := json.Unmarshal(b, &hostp)
if err != nil {
return err
}
host := *hostp
ip := net.ParseIP(host)
if ip != nil && ip.IsUnspecified() {
// For clarity, reject 0.0.0.0 as an input
return fmt.Errorf("ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
} else if ip == nil && host == "*" {
// User explicitly requested wildcard dst ip
*ipp = IPAny
} else {
if ip != nil {
ip = ip.To4()
}
if ip == nil || len(ip) != 4 {
return fmt.Errorf("ports=%#v: invalid IPv4 address", host)
}
*ipp = NewIP(ip)
}
return nil
}
const ( const (
EchoReply uint8 = 0x00 EchoReply uint8 = 0x00
EchoRequest uint8 = 0x08 EchoRequest uint8 = 0x08