diff --git a/utils.go b/utils.go index c1a39bb8..004bf306 100644 --- a/utils.go +++ b/utils.go @@ -12,7 +12,6 @@ import ( "encoding/json" "fmt" "net" - "sort" "strings" "github.com/rs/zerolog/log" @@ -190,7 +189,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, erro } } -func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) { +func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) { // FIXME: This really deserves a better data model, // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. @@ -206,7 +205,7 @@ func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) { var machineAddresses MachineAddresses err := machineAddresses.Scan(slice) if err != nil { - return netaddr.IPSet{}, fmt.Errorf( + return &netaddr.IPSet{}, fmt.Errorf( "failed to read ip from database: %w", err, ) @@ -221,7 +220,15 @@ func (h *Headscale) getUsedIPs() (netaddr.IPSet, error) { Interface("addresses", ips). Msg("Parsed ip addresses that has been allocated from databases") - return netaddr.IPSet{}, nil + ipSet, err := ips.IPSet() + if err != nil { + return &netaddr.IPSet{}, fmt.Errorf( + "failed to build IP Set: %w", + err, + ) + } + + return ipSet, nil } func containsString(ss []string, s string) bool { diff --git a/utils_test.go b/utils_test.go index feb44d5a..896040c8 100644 --- a/utils_test.go +++ b/utils_test.go @@ -48,9 +48,12 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(err, check.IsNil) expected := netaddr.MustParseIP("10.27.0.1") + expectedIPSetBuilder := netaddr.IPSetBuilder{} + expectedIPSetBuilder.Add(expected) + expectedIPSet, _ := expectedIPSetBuilder.IPSet() - c.Assert(len(usedIps), check.Equals, 1) - c.Assert(usedIps[0], check.Equals, expected) + c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) + c.Assert(usedIps.Contains(expected), check.Equals, true) machine1, err := app.GetMachineByID(0) c.Assert(err, check.IsNil) @@ -64,6 +67,8 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(err, check.IsNil) for index := 1; index <= 350; index++ { + app.ipAllocationMutex.Lock() + ips, err := app.getAvailableIPs() c.Assert(err, check.IsNil) @@ -86,17 +91,30 @@ func (s *Suite) TestGetMultiIp(c *check.C) { IPAddresses: ips, } app.db.Save(&machine) + + app.ipAllocationMutex.Unlock() } usedIps, err := app.getUsedIPs() - c.Assert(err, check.IsNil) - c.Assert(len(usedIps), check.Equals, 350) + expected0 := netaddr.MustParseIP("10.27.0.1") + expected9 := netaddr.MustParseIP("10.27.0.10") + expected300 := netaddr.MustParseIP("10.27.0.45") - c.Assert(usedIps[0], check.Equals, netaddr.MustParseIP("10.27.0.1")) - c.Assert(usedIps[9], check.Equals, netaddr.MustParseIP("10.27.0.10")) - c.Assert(usedIps[300], check.Equals, netaddr.MustParseIP("10.27.1.45")) + notExpectedIPSetBuilder := netaddr.IPSetBuilder{} + notExpectedIPSetBuilder.Add(expected0) + notExpectedIPSetBuilder.Add(expected9) + notExpectedIPSetBuilder.Add(expected300) + notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet() + c.Assert(err, check.IsNil) + + // We actually expect it to be a lot larger + c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false) + + c.Assert(usedIps.Contains(expected0), check.Equals, true) + c.Assert(usedIps.Contains(expected9), check.Equals, true) + c.Assert(usedIps.Contains(expected300), check.Equals, true) // Check that we can read back the IPs machine1, err := app.GetMachineByID(1)