diff --git a/acls.go b/acls.go index 9dd1260d..3d6b1945 100644 --- a/acls.go +++ b/acls.go @@ -20,7 +20,6 @@ const ( errInvalidUserSection = Error("invalid user section") errInvalidGroup = Error("invalid group") errInvalidTag = Error("invalid tag") - errInvalidNamespace = Error("invalid namespace") errInvalidPortFormat = Error("invalid port format") ) @@ -69,6 +68,7 @@ func (h *Headscale) LoadACLPolicy(path string) error { } h.aclPolicy = &policy + return h.UpdateACLRules() } @@ -79,6 +79,7 @@ func (h *Headscale) UpdateACLRules() error { } log.Trace().Interface("ACL", rules).Msg("ACL rules generated") h.aclRules = rules + return nil } @@ -182,7 +183,7 @@ func (h *Headscale) generateACLPolicyDestPorts( // - a namespace // - a group // - a tag -// and transform these in IPAddresses +// and transform these in IPAddresses. func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]string, error) { ips := []string{} if alias == "*" { @@ -200,6 +201,7 @@ func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]strin ips = append(ips, node.IPAddresses.ToStringSlice()...) } } + return ips, nil } @@ -225,6 +227,7 @@ func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]strin } } } + return ips, nil } @@ -276,6 +279,7 @@ func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace for _, machine := range nodes { if len(machine.HostInfo) == 0 { out = append(out, machine) + continue } hi, err := machine.GetHostInfo() @@ -286,6 +290,7 @@ func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace for _, t := range hi.RequestTags { if containsString(tags, t) { found = true + break } } @@ -293,6 +298,7 @@ func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace out = append(out, machine) } } + return out, nil } @@ -346,42 +352,45 @@ func listMachinesInNamespace(machines []Machine, namespace string) []Machine { out = append(out, machine) } } + return out } // expandTagOwners will return a list of namespace. An owner can be either a namespace or a group -// a group cannot be composed of groups +// a group cannot be composed of groups. func expandTagOwners(aclPolicy ACLPolicy, tag string) ([]string, error) { var owners []string ows, ok := aclPolicy.TagOwners[tag] if !ok { return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, tag) } - for _, ow := range ows { - if strings.HasPrefix(ow, "group:") { - gs, err := expandGroup(aclPolicy, ow) + for _, owner := range ows { + if strings.HasPrefix(owner, "group:") { + gs, err := expandGroup(aclPolicy, owner) if err != nil { return []string{}, err } owners = append(owners, gs...) } else { - owners = append(owners, ow) + owners = append(owners, owner) } } + return owners, nil } // expandGroup will return the list of namespace inside the group -// after some validation +// after some validation. func expandGroup(aclPolicy ACLPolicy, group string) ([]string, error) { - gs, ok := aclPolicy.Groups[group] + groups, ok := aclPolicy.Groups[group] if !ok { return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup) } - for _, g := range gs { + for _, g := range groups { if strings.HasPrefix(g, "group:") { return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup) } } - return gs, nil + + return groups, nil } diff --git a/acls_test.go b/acls_test.go index 8bedb47d..786de0fc 100644 --- a/acls_test.go +++ b/acls_test.go @@ -94,7 +94,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { // this test should validate that we can expand a group in a TagOWner section and // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Users section +// the tag is matched in the Users section. func (s *Suite) TestValidExpandTagOwnersInUsers(c *check.C) { namespace, err := app.CreateNamespace("foo") c.Assert(err, check.IsNil) @@ -104,7 +104,7 @@ func (s *Suite) TestValidExpandTagOwnersInUsers(c *check.C) { _, err = app.GetMachine("foo", "testmachine") c.Assert(err, check.NotNil) - b := []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}") + hostInfo := []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}") machine := Machine{ ID: 0, MachineKey: "foo", @@ -116,7 +116,7 @@ func (s *Suite) TestValidExpandTagOwnersInUsers(c *check.C) { Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: datatypes.JSON(b), + HostInfo: datatypes.JSON(hostInfo), } app.db.Save(&machine) @@ -136,7 +136,7 @@ func (s *Suite) TestValidExpandTagOwnersInUsers(c *check.C) { // this test should validate that we can expand a group in a TagOWner section and // match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Ports section +// the tag is matched in the Ports section. func (s *Suite) TestValidExpandTagOwnersInPorts(c *check.C) { namespace, err := app.CreateNamespace("foo") c.Assert(err, check.IsNil) @@ -146,7 +146,7 @@ func (s *Suite) TestValidExpandTagOwnersInPorts(c *check.C) { _, err = app.GetMachine("foo", "testmachine") c.Assert(err, check.NotNil) - b := []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}") + hostInfo := []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}") machine := Machine{ ID: 1, MachineKey: "foo", @@ -158,7 +158,7 @@ func (s *Suite) TestValidExpandTagOwnersInPorts(c *check.C) { Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: datatypes.JSON(b), + HostInfo: datatypes.JSON(hostInfo), } app.db.Save(&machine) @@ -178,7 +178,7 @@ func (s *Suite) TestValidExpandTagOwnersInPorts(c *check.C) { // need a test with: // tag on a host that isn't owned by a tag owners. So the namespace -// of the host should be valid +// of the host should be valid. func (s *Suite) TestInvalidTagValidNamespace(c *check.C) { namespace, err := app.CreateNamespace("foo") c.Assert(err, check.IsNil) @@ -188,7 +188,7 @@ func (s *Suite) TestInvalidTagValidNamespace(c *check.C) { _, err = app.GetMachine("foo", "testmachine") c.Assert(err, check.NotNil) - b := []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:foo\"]}") + hostInfo := []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:foo\"]}") machine := Machine{ ID: 1, MachineKey: "foo", @@ -200,7 +200,7 @@ func (s *Suite) TestInvalidTagValidNamespace(c *check.C) { Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: datatypes.JSON(b), + HostInfo: datatypes.JSON(hostInfo), } app.db.Save(&machine) @@ -229,7 +229,7 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) { _, err = app.GetMachine("foo", "webserver") c.Assert(err, check.NotNil) - b := []byte("{\"OS\":\"centos\",\"Hostname\":\"webserver\",\"RequestTags\":[\"tag:webapp\"]}") + hostInfo := []byte("{\"OS\":\"centos\",\"Hostname\":\"webserver\",\"RequestTags\":[\"tag:webapp\"]}") machine := Machine{ ID: 1, MachineKey: "foo", @@ -241,11 +241,11 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) { Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: datatypes.JSON(b), + HostInfo: datatypes.JSON(hostInfo), } app.db.Save(&machine) _, err = app.GetMachine("foo", "user") - b = []byte("{\"OS\":\"debian\",\"Hostname\":\"user\"}") + hostInfo = []byte("{\"OS\":\"debian\",\"Hostname\":\"user\"}") c.Assert(err, check.NotNil) machine = Machine{ ID: 2, @@ -258,7 +258,7 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) { Registered: true, RegisterMethod: RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), - HostInfo: datatypes.JSON(b), + HostInfo: datatypes.JSON(hostInfo), } app.db.Save(&machine) @@ -430,15 +430,16 @@ func Test_expandGroup(t *testing.T) { wantErr: true, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := expandGroup(tt.args.aclPolicy, tt.args.group) - if (err != nil) != tt.wantErr { - t.Errorf("expandGroup() error = %v, wantErr %v", err, tt.wantErr) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := expandGroup(test.args.aclPolicy, test.args.group) + if (err != nil) != test.wantErr { + t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr) + return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("expandGroup() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("expandGroup() = %v, want %v", got, test.want) } }) } @@ -514,15 +515,16 @@ func Test_expandTagOwners(t *testing.T) { wantErr: true, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := expandTagOwners(tt.args.aclPolicy, tt.args.tag) - if (err != nil) != tt.wantErr { - t.Errorf("expandTagOwners() error = %v, wantErr %v", err, tt.wantErr) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := expandTagOwners(test.args.aclPolicy, test.args.tag) + if (err != nil) != test.wantErr { + t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr) + return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("expandTagOwners() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("expandTagOwners() = %v, want %v", got, test.want) } }) } @@ -595,15 +597,16 @@ func Test_expandPorts(t *testing.T) { wantErr: true, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := expandPorts(tt.args.portsStr) - if (err != nil) != tt.wantErr { - t.Errorf("expandPorts() error = %v, wantErr %v", err, tt.wantErr) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := expandPorts(test.args.portsStr) + if (err != nil) != test.wantErr { + t.Errorf("expandPorts() error = %v, wantErr %v", err, test.wantErr) + return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("expandPorts() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("expandPorts() = %v, want %v", got, test.want) } }) } @@ -824,15 +827,16 @@ func Test_expandAlias(t *testing.T) { wantErr: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := expandAlias(tt.args.machines, tt.args.aclPolicy, tt.args.alias) - if (err != nil) != tt.wantErr { - t.Errorf("expandAlias() error = %v, wantErr %v", err, tt.wantErr) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := expandAlias(test.args.machines, test.args.aclPolicy, test.args.alias) + if (err != nil) != test.wantErr { + t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr) + return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("expandAlias() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("expandAlias() = %v, want %v", got, test.want) } }) } @@ -889,15 +893,16 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { wantErr: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := excludeCorrectlyTaggedNodes(tt.args.aclPolicy, tt.args.nodes, tt.args.namespace) - if (err != nil) != tt.wantErr { - t.Errorf("excludeCorrectlyTaggedNodes() error = %v, wantErr %v", err, tt.wantErr) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := excludeCorrectlyTaggedNodes(test.args.aclPolicy, test.args.nodes, test.args.namespace) + if (err != nil) != test.wantErr { + t.Errorf("excludeCorrectlyTaggedNodes() error = %v, wantErr %v", err, test.wantErr) + return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want) } }) } diff --git a/api.go b/api.go index 020ded01..073be5e2 100644 --- a/api.go +++ b/api.go @@ -261,7 +261,16 @@ func (h *Headscale) getMapResponse( var respBody []byte if req.Compress == "zstd" { - src, _ := json.Marshal(resp) + src, err := json.Marshal(resp) + if err != nil { + log.Error(). + Caller(). + Str("func", "getMapResponse"). + Err(err). + Msg("Failed to marshal response for the client") + + return nil, err + } encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) @@ -290,7 +299,16 @@ func (h *Headscale) getMapKeepAliveResponse( var respBody []byte var err error if mapRequest.Compress == "zstd" { - src, _ := json.Marshal(mapResponse) + src, err := json.Marshal(mapResponse) + if err != nil { + log.Error(). + Caller(). + Str("func", "getMapKeepAliveResponse"). + Err(err). + Msg("Failed to marshal keepalive response for the client") + + return nil, err + } encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) respBody = h.privateKey.SealTo(machineKey, srcCompressed) diff --git a/dns.go b/dns.go index 37daa884..085a14e2 100644 --- a/dns.go +++ b/dns.go @@ -165,7 +165,7 @@ func getMapResponseDNSConfig( dnsConfig.Domains, fmt.Sprintf( "%s.%s", - strings.Replace(machine.Namespace.Name, "@", ".", -1), // Replace @ with . for valid domain for machine + strings.ReplaceAll(machine.Namespace.Name, "@", "."), // Replace @ with . for valid domain for machine baseDomain, ), ) @@ -176,7 +176,7 @@ func getMapResponseDNSConfig( namespaceSet.Add(p.Namespace) } for _, namespace := range namespaceSet.List() { - dnsRoute := fmt.Sprintf("%s.%s", namespace.(Namespace).Name, baseDomain) + var dnsRoute string = fmt.Sprintf("%v.%v", namespace.(Namespace).Name, baseDomain) dnsConfig.Routes[dnsRoute] = nil } } else { diff --git a/machine.go b/machine.go index 01fc9276..cb70bf12 100644 --- a/machine.go +++ b/machine.go @@ -138,6 +138,7 @@ func containsAddresses(inputs []string, addrs MachineAddresses) bool { return true } } + return false } @@ -174,20 +175,20 @@ func (h *Headscale) getFilteredByACLPeers(machine *Machine) (Machines, error) { // In order to do this we would need to be able to identify that node A want to talk to node B but that Node B doesn't know // how to talk to node A and then add the peering resource. - for _, m := range machines { + for _, mchn := range machines { for _, rule := range h.aclRules { var dst []string for _, d := range rule.DstPorts { dst = append(dst, d.IP) } - if (containsAddresses(rule.SrcIPs, machine.IPAddresses) && (containsAddresses(dst, m.IPAddresses) || containsString(dst, "*"))) || - (containsAddresses(rule.SrcIPs, m.IPAddresses) && containsAddresses(dst, machine.IPAddresses)) { - mMachines[m.ID] = m + if (containsAddresses(rule.SrcIPs, machine.IPAddresses) && (containsAddresses(dst, mchn.IPAddresses) || containsString(dst, "*"))) || + (containsAddresses(rule.SrcIPs, mchn.IPAddresses) && containsAddresses(dst, machine.IPAddresses)) { + mMachines[mchn.ID] = mchn } } } - var authorizedMachines Machines + authorizedMachines := make([]Machine, 0, len(mMachines)) for _, m := range mMachines { authorizedMachines = append(authorizedMachines, m) } @@ -694,7 +695,7 @@ func (machine Machine) toNode( hostname = fmt.Sprintf( "%s.%s.%s", machine.Name, - strings.Replace(machine.Namespace.Name, "@", ".", -1), // Replace @ with . for valid domain for machine + strings.ReplaceAll(machine.Namespace.Name, "@", "."), // Replace @ with . for valid domain for machine baseDomain, ) } else { diff --git a/machine_test.go b/machine_test.go index 203289ac..df00067b 100644 --- a/machine_test.go +++ b/machine_test.go @@ -161,7 +161,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { key *PreAuthKey } - var stor []base + stor := make([]base, 0) for _, name := range []string{"test", "admin"} { namespace, err := app.CreateNamespace(name) @@ -169,7 +169,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) stor = append(stor, base{namespace, pak}) - } _, err := app.GetMachineByID(0) diff --git a/poll.go b/poll.go index c2a51d11..f00f7484 100644 --- a/poll.go +++ b/poll.go @@ -85,7 +85,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { Str("machine", machine.Name). Msg("Found machine in database") - hostinfo, _ := json.Marshal(req.Hostinfo) + hostinfo, err := json.Marshal(req.Hostinfo) + if err != nil { + return + } machine.Name = req.Hostinfo.Hostname machine.HostInfo = datatypes.JSON(hostinfo) machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey) @@ -106,7 +109,17 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { // The intended use is for clients to discover the DERP map at start-up // before their first real endpoint update. if !req.ReadOnly { - endpoints, _ := json.Marshal(req.Endpoints) + endpoints, err := json.Marshal(req.Endpoints) + if err != nil { + log.Error(). + Caller(). + Str("func", "PollNetMapHandler"). + Err(err). + Msg("Failed to mashal requested endpoints for the client") + ctx.String(http.StatusInternalServerError, ":(") + + return + } machine.Endpoints = datatypes.JSON(endpoints) machine.LastSeen = &now }