diff --git a/acls.go b/acls.go index c6b65a1e..71395791 100644 --- a/acls.go +++ b/acls.go @@ -117,7 +117,16 @@ func (h *Headscale) LoadACLPolicy(path string) error { } func (h *Headscale) UpdateACLRules() error { - rules, err := h.generateACLRules() + machines, err := h.ListMachines() + if err != nil { + return err + } + + if h.aclPolicy == nil { + return errEmptyPolicy + } + + rules, err := generateACLRules(machines, *h.aclPolicy, h.cfg.OIDC.StripEmaildomain) if err != nil { return err } @@ -141,26 +150,17 @@ func (h *Headscale) UpdateACLRules() error { return nil } -func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { +func generateACLRules(machines []Machine, aclPolicy ACLPolicy, stripEmaildomain bool) ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} - if h.aclPolicy == nil { - return nil, errEmptyPolicy - } - - machines, err := h.ListMachines() - if err != nil { - return nil, err - } - - for index, acl := range h.aclPolicy.ACLs { + for index, acl := range aclPolicy.ACLs { if acl.Action != "accept" { return nil, errInvalidAction } srcIPs := []string{} for innerIndex, src := range acl.Sources { - srcs, err := h.generateACLPolicySrcIP(machines, *h.aclPolicy, src) + srcs, err := generateACLPolicySrcIP(machines, aclPolicy, src, stripEmaildomain) if err != nil { log.Error(). Msgf("Error parsing ACL %d, Source %d", index, innerIndex) @@ -180,11 +180,12 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { destPorts := []tailcfg.NetPortRange{} for innerIndex, dest := range acl.Destinations { - dests, err := h.generateACLPolicyDest( + dests, err := generateACLPolicyDest( machines, - *h.aclPolicy, + aclPolicy, dest, needsWildcard, + stripEmaildomain, ) if err != nil { log.Error(). @@ -310,19 +311,21 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { }, nil } -func (h *Headscale) generateACLPolicySrcIP( +func generateACLPolicySrcIP( machines []Machine, aclPolicy ACLPolicy, src string, + stripEmaildomain bool, ) ([]string, error) { - return expandAlias(machines, aclPolicy, src, h.cfg.OIDC.StripEmaildomain) + return expandAlias(machines, aclPolicy, src, stripEmaildomain) } -func (h *Headscale) generateACLPolicyDest( +func generateACLPolicyDest( machines []Machine, aclPolicy ACLPolicy, dest string, needsWildcard bool, + stripEmaildomain bool, ) ([]tailcfg.NetPortRange, error) { tokens := strings.Split(dest, ":") if len(tokens) < expectedTokenItems || len(tokens) > 3 { @@ -346,7 +349,7 @@ func (h *Headscale) generateACLPolicyDest( machines, aclPolicy, alias, - h.cfg.OIDC.StripEmaildomain, + stripEmaildomain, ) if err != nil { return nil, err diff --git a/acls_test.go b/acls_test.go index 23e7f917..c11e074f 100644 --- a/acls_test.go +++ b/acls_test.go @@ -54,7 +54,7 @@ func (s *Suite) TestBasicRule(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson") c.Assert(err, check.IsNil) - rules, err := app.generateACLRules() + rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) } @@ -411,7 +411,7 @@ func (s *Suite) TestPortRange(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson") c.Assert(err, check.IsNil) - rules, err := app.generateACLRules() + rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -425,7 +425,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_protocols.hujson") c.Assert(err, check.IsNil) - rules, err := app.generateACLRules() + rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -439,7 +439,7 @@ func (s *Suite) TestPortWildcard(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") c.Assert(err, check.IsNil) - rules, err := app.generateACLRules() + rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -455,7 +455,7 @@ func (s *Suite) TestPortWildcardYAML(c *check.C) { err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.yaml") c.Assert(err, check.IsNil) - rules, err := app.generateACLRules() + rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -495,7 +495,10 @@ func (s *Suite) TestPortNamespace(c *check.C) { ) c.Assert(err, check.IsNil) - rules, err := app.generateACLRules() + machines, err := app.ListMachines() + c.Assert(err, check.IsNil) + + rules, err := generateACLRules(machines, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -535,7 +538,10 @@ func (s *Suite) TestPortGroup(c *check.C) { err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson") c.Assert(err, check.IsNil) - rules, err := app.generateACLRules() + machines, err := app.ListMachines() + c.Assert(err, check.IsNil) + + rules, err := generateACLRules(machines, *app.aclPolicy, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil)