wgengine/filter: let unknown IPProto match if IP okay & match allows all ports

RELNOTE=yes

Change-Id: I96eaf3cf550cee7bb6cdb4ad81fc761e280a1b2a
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
(cherry picked from commit 69de3bf7bf)
This commit is contained in:
Brad Fitzpatrick 2021-12-04 11:52:39 -08:00
parent 6c44133d8f
commit 972bcccc36
3 changed files with 96 additions and 18 deletions

View File

@ -382,6 +382,9 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
case ipproto.TSMP: case ipproto.TSMP:
return Accept, "tsmp ok" return Accept, "tsmp ok"
default: default:
if f.matches4.matchProtoAndIPsOnlyIfAllPorts(q) {
return Accept, "otherproto ok"
}
return Drop, "Unknown proto" return Drop, "Unknown proto"
} }
return Drop, "no rules matched" return Drop, "no rules matched"
@ -439,6 +442,9 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
case ipproto.TSMP: case ipproto.TSMP:
return Accept, "tsmp ok" return Accept, "tsmp ok"
default: default:
if f.matches6.matchProtoAndIPsOnlyIfAllPorts(q) {
return Accept, "otherproto ok"
}
return Drop, "Unknown proto" return Drop, "Unknown proto"
} }
return Drop, "no rules matched" return Drop, "no rules matched"

View File

@ -23,8 +23,14 @@
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
func newFilter(logf logger.Logf) *Filter { // testAllowedProto is an IP protocol number we treat as allowed for
m := func(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match { // these tests.
const (
testAllowedProto ipproto.Proto = 116
testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy
)
func m(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match {
if protos == nil { if protos == nil {
protos = defaultProtos protos = defaultProtos
} }
@ -34,6 +40,8 @@ func newFilter(logf logger.Logf) *Filter {
Dsts: dsts, Dsts: dsts,
} }
} }
func newFilter(logf logger.Logf) *Filter {
matches := []Match{ matches := []Match{
m(nets("8.1.1.1", "8.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24")), m(nets("8.1.1.1", "8.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24")),
m(nets("9.1.1.1", "9.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24"), ipproto.SCTP), m(nets("9.1.1.1", "9.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24"), ipproto.SCTP),
@ -44,6 +52,8 @@ func newFilter(logf logger.Logf) *Filter {
m(nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), netports("1.2.3.4:999")), m(nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), netports("1.2.3.4:999")),
m(nets("::1", "::2"), netports("2001::1:22", "2001::2:22")), m(nets("::1", "::2"), netports("2001::1:22", "2001::2:22")),
m(nets("::/0"), netports("::/0:443")), m(nets("::/0"), netports("::/0:443")),
m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
m(nets("::/0"), netports("::/0:*"), testAllowedProto),
} }
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
@ -112,6 +122,12 @@ type InOut struct {
{Drop, parsed(ipproto.SCTP, "8.1.1.1", "1.2.3.4", 999, 22)}, {Drop, parsed(ipproto.SCTP, "8.1.1.1", "1.2.3.4", 999, 22)},
// But SCTP is allowed for 9.1.1.1 // But SCTP is allowed for 9.1.1.1
{Accept, parsed(ipproto.SCTP, "9.1.1.1", "1.2.3.4", 999, 22)}, {Accept, parsed(ipproto.SCTP, "9.1.1.1", "1.2.3.4", 999, 22)},
// Unknown protocol is allowed if all its ports are allowed.
{Accept, parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
{Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)},
{Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
{Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)},
} }
for i, test := range tests { for i, test := range tests {
aclFunc := acl.runIn4 aclFunc := acl.runIn4
@ -534,13 +550,7 @@ func TestLoggingPrivacy(t *testing.T) {
} }
} }
func mustIP(s string) netaddr.IP { var mustIP = netaddr.MustParseIP
ip, err := netaddr.ParseIP(s)
if err != nil {
panic(err)
}
return ip
}
func parsed(proto ipproto.Proto, src, dst string, sport, dport uint16) packet.Parsed { func parsed(proto ipproto.Proto, src, dst string, sport, dport uint16) packet.Parsed {
sip, dip := mustIP(src), mustIP(dst) sip, dip := mustIP(src), mustIP(dst)
@ -689,7 +699,7 @@ func nets(nets ...string) (ret []netaddr.IPPrefix) {
func ports(s string) PortRange { func ports(s string) PortRange {
if s == "*" { if s == "*" {
return PortRange{First: 0, Last: 65535} return allPorts
} }
var fs, ls string var fs, ls string
@ -815,3 +825,40 @@ func TestMatchesFromFilterRules(t *testing.T) {
}) })
} }
} }
func TestMatchesMatchProtoAndIPsOnlyIfAllPorts(t *testing.T) {
tests := []struct {
name string
m Match
p packet.Parsed
want bool
}{
{
name: "all_ports_okay",
m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
p: parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0),
want: true,
},
{
name: "all_ports_match_but_packet_wrong_proto",
m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
p: parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0),
want: false,
},
{
name: "ports_requirements_dont_match_unknown_proto",
m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:12345"), testAllowedProto),
p: parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := matches{tt.m}
got := matches.matchProtoAndIPsOnlyIfAllPorts(&tt.p)
if got != tt.want {
t.Errorf("got = %v; want %v", got, tt.want)
}
})
}
}

View File

@ -20,6 +20,8 @@ type PortRange struct {
First, Last uint16 // inclusive First, Last uint16 // inclusive
} }
var allPorts = PortRange{0, 0xffff}
func (pr PortRange) String() string { func (pr PortRange) String() string {
if pr.First == 0 && pr.Last == 65535 { if pr.First == 0 && pr.Last == 65535 {
return "*" return "*"
@ -115,6 +117,29 @@ func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
return false return false
} }
// matchProtoAndIPsOnlyIfAllPorts reports q matches any Match in ms where the
// Match if for the right IP Protocol and IP address, but ports are
// ignored, as long as the match is for the entire uint16 port range.
func (ms matches) matchProtoAndIPsOnlyIfAllPorts(q *packet.Parsed) bool {
for _, m := range ms {
if !protoInList(q.IPProto, m.IPProto) {
continue
}
if !ipInList(q.Src.IP(), m.Srcs) {
continue
}
for _, dst := range m.Dsts {
if dst.Ports != allPorts {
continue
}
if dst.Net.Contains(q.Dst.IP()) {
return true
}
}
}
return false
}
func ipInList(ip netaddr.IP, netlist []netaddr.IPPrefix) bool { func ipInList(ip netaddr.IP, netlist []netaddr.IPPrefix) bool {
for _, net := range netlist { for _, net := range netlist {
if net.Contains(ip) { if net.Contains(ip) {