From 69de3bf7bfddb37b4c0e076c93115f82a51ec407 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 4 Dec 2021 11:52:39 -0800 Subject: [PATCH] wgengine/filter: let unknown IPProto match if IP okay & match allows all ports RELNOTE=yes Change-Id: I96eaf3cf550cee7bb6cdb4ad81fc761e280a1b2a Signed-off-by: Brad Fitzpatrick --- wgengine/filter/filter.go | 6 +++ wgengine/filter/filter_test.go | 83 ++++++++++++++++++++++++++-------- wgengine/filter/match.go | 25 ++++++++++ 3 files changed, 96 insertions(+), 18 deletions(-) diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index d87066c9e..8cf5dfa50 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -384,6 +384,9 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { case ipproto.TSMP: return Accept, "tsmp ok" default: + if f.matches4.matchProtoAndIPsOnlyIfAllPorts(q) { + return Accept, "otherproto ok" + } return Drop, "Unknown proto" } return Drop, "no rules matched" @@ -441,6 +444,9 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { case ipproto.TSMP: return Accept, "tsmp ok" default: + if f.matches6.matchProtoAndIPsOnlyIfAllPorts(q) { + return Accept, "otherproto ok" + } return Drop, "Unknown proto" } return Drop, "no rules matched" diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 9b8e7a708..a09927c0a 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -23,17 +23,25 @@ "tailscale.com/types/logger" ) -func newFilter(logf logger.Logf) *Filter { - m := func(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match { - if protos == nil { - protos = defaultProtos - } - return Match{ - IPProto: protos, - Srcs: srcs, - Dsts: dsts, - } +// testAllowedProto is an IP protocol number we treat as allowed for +// these tests. +const ( + testAllowedProto ipproto.Proto = 116 + testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy +) + +func m(srcs []netaddr.IPPrefix, dsts []NetPortRange, protos ...ipproto.Proto) Match { + if protos == nil { + protos = defaultProtos } + return Match{ + IPProto: protos, + Srcs: srcs, + Dsts: dsts, + } +} + +func newFilter(logf logger.Logf) *Filter { matches := []Match{ m(nets("8.1.1.1", "8.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24")), m(nets("9.1.1.1", "9.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24"), ipproto.SCTP), @@ -44,6 +52,8 @@ func newFilter(logf logger.Logf) *Filter { m(nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), netports("1.2.3.4:999")), m(nets("::1", "::2"), netports("2001::1:22", "2001::2:22")), m(nets("::/0"), netports("::/0:443")), + m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto), + m(nets("::/0"), netports("::/0:*"), testAllowedProto), } // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, @@ -112,6 +122,12 @@ type InOut struct { {Drop, parsed(ipproto.SCTP, "8.1.1.1", "1.2.3.4", 999, 22)}, // But SCTP is allowed for 9.1.1.1 {Accept, parsed(ipproto.SCTP, "9.1.1.1", "1.2.3.4", 999, 22)}, + + // Unknown protocol is allowed if all its ports are allowed. + {Accept, parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0)}, + {Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)}, + {Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)}, + {Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)}, } for i, test := range tests { aclFunc := acl.runIn4 @@ -534,13 +550,7 @@ func TestLoggingPrivacy(t *testing.T) { } } -func mustIP(s string) netaddr.IP { - ip, err := netaddr.ParseIP(s) - if err != nil { - panic(err) - } - return ip -} +var mustIP = netaddr.MustParseIP func parsed(proto ipproto.Proto, src, dst string, sport, dport uint16) packet.Parsed { sip, dip := mustIP(src), mustIP(dst) @@ -689,7 +699,7 @@ func nets(nets ...string) (ret []netaddr.IPPrefix) { func ports(s string) PortRange { if s == "*" { - return PortRange{First: 0, Last: 65535} + return allPorts } var fs, ls string @@ -825,3 +835,40 @@ func TestNewAllowAllForTest(t *testing.T) { t.Fatalf("unexpected drop verdict: %v", res) } } + +func TestMatchesMatchProtoAndIPsOnlyIfAllPorts(t *testing.T) { + tests := []struct { + name string + m Match + p packet.Parsed + want bool + }{ + { + name: "all_ports_okay", + m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto), + p: parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0), + want: true, + }, + { + name: "all_ports_match_but_packet_wrong_proto", + m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto), + p: parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0), + want: false, + }, + { + name: "ports_requirements_dont_match_unknown_proto", + m: m(nets("0.0.0.0/0"), netports("0.0.0.0/0:12345"), testAllowedProto), + p: parsed(testAllowedProto, "1.2.3.4", "5.6.7.8", 0, 0), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := matches{tt.m} + got := matches.matchProtoAndIPsOnlyIfAllPorts(&tt.p) + if got != tt.want { + t.Errorf("got = %v; want %v", got, tt.want) + } + }) + } +} diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 74807dda7..dae60870e 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -20,6 +20,8 @@ type PortRange struct { First, Last uint16 // inclusive } +var allPorts = PortRange{0, 0xffff} + func (pr PortRange) String() string { if pr.First == 0 && pr.Last == 65535 { return "*" @@ -115,6 +117,29 @@ func (ms matches) matchIPsOnly(q *packet.Parsed) bool { return false } +// matchProtoAndIPsOnlyIfAllPorts reports q matches any Match in ms where the +// Match if for the right IP Protocol and IP address, but ports are +// ignored, as long as the match is for the entire uint16 port range. +func (ms matches) matchProtoAndIPsOnlyIfAllPorts(q *packet.Parsed) bool { + for _, m := range ms { + if !protoInList(q.IPProto, m.IPProto) { + continue + } + if !ipInList(q.Src.IP(), m.Srcs) { + continue + } + for _, dst := range m.Dsts { + if dst.Ports != allPorts { + continue + } + if dst.Net.Contains(q.Dst.IP()) { + return true + } + } + } + return false +} + func ipInList(ip netaddr.IP, netlist []netaddr.IPPrefix) bool { for _, net := range netlist { if net.Contains(ip) {