Improve ACLs by adding protocol parsing support

This commit is contained in:
Juan Font Alonso 2022-06-08 17:43:59 +02:00
parent 3e353004b8
commit ab1aac9f3e
3 changed files with 93 additions and 14 deletions

70
acls.go
View File

@ -23,6 +23,7 @@ const (
errInvalidGroup = Error("invalid group")
errInvalidTag = Error("invalid tag")
errInvalidPortFormat = Error("invalid port format")
errWildcardIsNeeded = Error("wildcard as port is required for the procotol")
)
const (
@ -134,9 +135,17 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
srcIPs = append(srcIPs, srcs...)
}
protocols, needsWildcard, err := parseProtocol(acl.Protocol)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d. protocol unknown %s", index, acl.Protocol)
return nil, err
}
destPorts := []tailcfg.NetPortRange{}
for innerIndex, dest := range acl.Destinations {
dests, err := h.generateACLPolicyDest(machines, *h.aclPolicy, dest)
dests, err := h.generateACLPolicyDest(machines, *h.aclPolicy, dest, needsWildcard)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d, Destination %d", index, innerIndex)
@ -149,6 +158,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules = append(rules, tailcfg.FilterRule{
SrcIPs: srcIPs,
DstPorts: destPorts,
IPProto: protocols,
})
}
@ -167,6 +177,7 @@ func (h *Headscale) generateACLPolicyDest(
machines []Machine,
aclPolicy ACLPolicy,
dest string,
needsWildcard bool,
) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(dest, ":")
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
@ -195,7 +206,7 @@ func (h *Headscale) generateACLPolicyDest(
if err != nil {
return nil, err
}
ports, err := expandPorts(tokens[len(tokens)-1])
ports, err := expandPorts(tokens[len(tokens)-1], needsWildcard)
if err != nil {
return nil, err
}
@ -214,6 +225,54 @@ func (h *Headscale) generateACLPolicyDest(
return dests, nil
}
// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return []int{1, 58, 6, 17}, false, nil
case "igmp":
return []int{2}, true, nil
case "ipv4", "ip-in-ip":
return []int{4}, true, nil
case "tcp":
return []int{6}, false, nil
case "egp":
return []int{8}, true, nil
case "igp":
return []int{9}, true, nil
case "udp":
return []int{17}, false, nil
case "gre":
return []int{47}, true, nil
case "esp":
return []int{50}, true, nil
case "ah":
return []int{51}, true, nil
case "sctp":
return []int{132}, false, nil
case "icmp":
return []int{1, 58}, true, nil
default:
protocolNumber, err := strconv.Atoi(protocol)
if err != nil {
return nil, false, err
}
needsWildcard := protocolNumber != 6 && protocolNumber != 17 && protocolNumber != 132
return []int{protocolNumber}, needsWildcard, nil
}
}
// expandalias has an input of either
// - a namespace
// - a group
@ -268,6 +327,7 @@ func expandAlias(
alias,
)
}
return ips, nil
} else {
return ips, err
@ -359,13 +419,17 @@ func excludeCorrectlyTaggedNodes(
return out
}
func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, error) {
if portsStr == "*" {
return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
}, nil
}
if needsWildcard {
return nil, errWildcardIsNeeded
}
ports := []tailcfg.PortRange{}
for _, portStr := range strings.Split(portsStr, ",") {
rang := strings.Split(portStr, "-")

View File

@ -628,7 +628,8 @@ func Test_expandTagOwners(t *testing.T) {
func Test_expandPorts(t *testing.T) {
type args struct {
portsStr string
portsStr string
needsWildcard bool
}
tests := []struct {
name string
@ -638,15 +639,29 @@ func Test_expandPorts(t *testing.T) {
}{
{
name: "wildcard",
args: args{portsStr: "*"},
args: args{portsStr: "*", needsWildcard: true},
want: &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
},
wantErr: false,
},
{
name: "needs wildcard but does not require it",
args: args{portsStr: "*", needsWildcard: false},
want: &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
},
wantErr: false,
},
{
name: "needs wildcard but gets port",
args: args{portsStr: "80,443", needsWildcard: true},
want: nil,
wantErr: true,
},
{
name: "two Destinations",
args: args{portsStr: "80,443"},
args: args{portsStr: "80,443", needsWildcard: false},
want: &[]tailcfg.PortRange{
{First: 80, Last: 80},
{First: 443, Last: 443},
@ -655,7 +670,7 @@ func Test_expandPorts(t *testing.T) {
},
{
name: "a range and a port",
args: args{portsStr: "80-1024,443"},
args: args{portsStr: "80-1024,443", needsWildcard: false},
want: &[]tailcfg.PortRange{
{First: 80, Last: 1024},
{First: 443, Last: 443},
@ -664,38 +679,38 @@ func Test_expandPorts(t *testing.T) {
},
{
name: "out of bounds",
args: args{portsStr: "854038"},
args: args{portsStr: "854038", needsWildcard: false},
want: nil,
wantErr: true,
},
{
name: "wrong port",
args: args{portsStr: "85a38"},
args: args{portsStr: "85a38", needsWildcard: false},
want: nil,
wantErr: true,
},
{
name: "wrong port in first",
args: args{portsStr: "a-80"},
args: args{portsStr: "a-80", needsWildcard: false},
want: nil,
wantErr: true,
},
{
name: "wrong port in last",
args: args{portsStr: "80-85a38"},
args: args{portsStr: "80-85a38", needsWildcard: false},
want: nil,
wantErr: true,
},
{
name: "wrong port format",
args: args{portsStr: "80-85a38-3"},
args: args{portsStr: "80-85a38-3", needsWildcard: false},
want: nil,
wantErr: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got, err := expandPorts(test.args.portsStr)
got, err := expandPorts(test.args.portsStr, test.args.needsWildcard)
if (err != nil) != test.wantErr {
t.Errorf("expandPorts() error = %v, wantErr %v", err, test.wantErr)

View File

@ -21,7 +21,7 @@ type ACLPolicy struct {
// ACL is a basic rule for the ACL Policy.
type ACL struct {
Action string `json:"action" yaml:"action"`
Protocol string `json:"protocol" yaml:"protocol"`
Protocol string `json:"proto" yaml:"proto"`
Sources []string `json:"src" yaml:"src"`
Destinations []string `json:"dst" yaml:"dst"`
}