diff --git a/acls.go b/acls.go index c5e4864a..3f7ce06c 100644 --- a/acls.go +++ b/acls.go @@ -128,7 +128,7 @@ func (h *Headscale) UpdateACLRules() error { return errEmptyPolicy } - rules, err := generateACLRules(machines, *h.aclPolicy, h.cfg.OIDC.StripEmaildomain) + rules, err := generateFilterRules(machines, *h.aclPolicy, h.cfg.OIDC.StripEmaildomain) if err != nil { return err } @@ -224,24 +224,29 @@ func expandACLPeerAddr(srcIP string) []string { return []string{srcIP} } -func generateACLRules( +// generateFilterRules takes a set of machines and an ACLPolicy and generates a +// set of Tailscale compatible FilterRules used to allow traffic on clients. +func generateFilterRules( machines []Machine, - aclPolicy ACLPolicy, + pol ACLPolicy, stripEmaildomain bool, ) ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} - for index, acl := range aclPolicy.ACLs { + for index, acl := range pol.ACLs { if acl.Action != "accept" { return nil, errInvalidAction } srcIPs := []string{} - for innerIndex, src := range acl.Sources { - srcs, err := generateACLPolicySrc(machines, aclPolicy, src, stripEmaildomain) + for srcIndex, src := range acl.Sources { + srcs, err := pol.getIPsFromSource(src, machines, stripEmaildomain) if err != nil { log.Error(). - Msgf("Error parsing ACL %d, Source %d", index, innerIndex) + Interface("src", src). + Int("ACL index", index). + Int("Src index", srcIndex). + Msgf("Error parsing ACL") return nil, err } @@ -257,17 +262,19 @@ func generateACLRules( } destPorts := []tailcfg.NetPortRange{} - for innerIndex, dest := range acl.Destinations { - dests, err := generateACLPolicyDest( - machines, - aclPolicy, + for destIndex, dest := range acl.Destinations { + dests, err := pol.getNetPortRangeFromDestination( dest, + machines, needsWildcard, stripEmaildomain, ) if err != nil { log.Error(). - Msgf("Error parsing ACL %d, Destination %d", index, innerIndex) + Interface("dest", dest). + Int("ACL index", index). + Int("dest index", destIndex). + Msgf("Error parsing ACL") return nil, err } @@ -388,19 +395,21 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { }, nil } -func generateACLPolicySrc( - machines []Machine, - pol ACLPolicy, +// getIPsFromSource returns a set of Source IPs that would be associated +// with the given src alias. +func (pol *ACLPolicy) getIPsFromSource( src string, + machines []Machine, stripEmaildomain bool, ) ([]string, error) { return pol.expandAlias(machines, src, stripEmaildomain) } -func generateACLPolicyDest( - machines []Machine, - pol ACLPolicy, +// getNetPortRangeFromDestination returns a set of tailcfg.NetPortRange +// which are associated with the dest alias. +func (pol *ACLPolicy) getNetPortRangeFromDestination( dest string, + machines []Machine, needsWildcard bool, stripEmaildomain bool, ) ([]tailcfg.NetPortRange, error) { diff --git a/acls_test.go b/acls_test.go index 5df19082..cd801fba 100644 --- a/acls_test.go +++ b/acls_test.go @@ -54,7 +54,7 @@ func (s *Suite) TestBasicRule(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson") c.Assert(err, check.IsNil) - rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) + rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) } @@ -411,7 +411,7 @@ func (s *Suite) TestPortRange(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson") c.Assert(err, check.IsNil) - rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) + rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -425,7 +425,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_protocols.hujson") c.Assert(err, check.IsNil) - rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) + rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -439,7 +439,7 @@ func (s *Suite) TestPortWildcard(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") c.Assert(err, check.IsNil) - rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) + rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -455,7 +455,7 @@ func (s *Suite) TestPortWildcardYAML(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.yaml") c.Assert(err, check.IsNil) - rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) + rules, err := generateFilterRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -498,7 +498,7 @@ func (s *Suite) TestPortUser(c *check.C) { machines, err := app.ListMachines() c.Assert(err, check.IsNil) - rules, err := generateACLRules(machines, *app.aclPolicy, false) + rules, err := generateFilterRules(machines, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -541,7 +541,7 @@ func (s *Suite) TestPortGroup(c *check.C) { machines, err := app.ListMachines() c.Assert(err, check.IsNil) - rules, err := generateACLRules(machines, *app.aclPolicy, false) + rules, err := generateFilterRules(machines, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil)