diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index d3ad12fd..35580aac 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -128,7 +128,7 @@ func GenerateFilterRules( return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.generateFilterRules(append(peers, *machine), stripEmailDomain) + rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -152,10 +152,12 @@ func GenerateFilterRules( // generateFilterRules takes a set of machines and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) generateFilterRules( - machines types.Machines, + machine *types.Machine, + peers types.Machines, stripEmailDomain bool, ) ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} + machines := append(peers, *machine) for index, acl := range pol.ACLs { if acl.Action != "accept" { diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index d7f5932a..5652e8c6 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -199,7 +199,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.generateFilterRules(types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -230,7 +230,7 @@ func (s *Suite) TestBasicRule(c *check.C) { pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - rules, err := pol.generateFilterRules(types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) } @@ -310,7 +310,7 @@ func (s *Suite) TestPortRange(c *check.C) { c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -366,7 +366,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) { c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -401,7 +401,7 @@ func (s *Suite) TestPortWildcard(c *check.C) { c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -428,7 +428,7 @@ acls: c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -459,7 +459,7 @@ acls: c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -1620,7 +1620,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { pol ACLPolicy } type args struct { - machines types.Machines + machine types.Machine + peers types.Machines stripEmailDomain bool } tests := []struct { @@ -1651,7 +1652,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machines: types.Machines{}, + machine: types.Machine{}, + peers: types.Machines{}, stripEmailDomain: true, }, want: []tailcfg.FilterRule{ @@ -1691,14 +1693,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machines: types.Machines{ - { - IPAddresses: types.MachineAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - }, - User: types.User{Name: "mickael"}, + machine: types.Machine{ + IPAddresses: types.MachineAddresses{ + netip.MustParseAddr("100.64.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), }, + User: types.User{Name: "mickael"}, + }, + peers: types.Machines{ { IPAddresses: types.MachineAddresses{ netip.MustParseAddr("100.64.0.2"), @@ -1739,7 +1741,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.field.pol.generateFilterRules( - tt.args.machines, + &tt.args.machine, + tt.args.peers, tt.args.stripEmailDomain, ) if (err != nil) != tt.wantErr {