diff --git a/README.md b/README.md index 5f90c831..16448534 100644 --- a/README.md +++ b/README.md @@ -113,9 +113,10 @@ Headscale's configuration file is named `config.json` or `config.yaml`. Headscal ``` "server_url": "http://192.168.1.12:8080", "listen_addr": "0.0.0.0:8080", + "ip_prefix": "100.64.0.0/10" ``` -`server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on. +`server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on. `ip_prefix` is the IP prefix (range) in which IP addresses for nodes will be allocated. ``` "private_key_path": "private.key", diff --git a/api.go b/api.go index 088c337f..97ec4d41 100644 --- a/api.go +++ b/api.go @@ -445,6 +445,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, log.Println(err) return } + log.Printf("Assigning %s to %s", ip, m.Name) m.AuthKeyID = uint(pak.ID) m.IPAddress = ip.String() diff --git a/app.go b/app.go index fa9b011b..e3978212 100644 --- a/app.go +++ b/app.go @@ -13,6 +13,7 @@ import ( "github.com/gin-gonic/gin" "golang.org/x/crypto/acme/autocert" "gorm.io/gorm" + "inet.af/netaddr" "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) @@ -24,6 +25,7 @@ type Config struct { PrivateKeyPath string DerpMap *tailcfg.DERPMap EphemeralNodeInactivityTimeout time.Duration + IPPrefix netaddr.IPPrefix DBtype string DBpath string diff --git a/app_test.go b/app_test.go index ad633334..5e53f1cc 100644 --- a/app_test.go +++ b/app_test.go @@ -6,6 +6,7 @@ import ( "testing" "gopkg.in/check.v1" + "inet.af/netaddr" ) func Test(t *testing.T) { @@ -36,7 +37,9 @@ func (s *Suite) ResetDB(c *check.C) { if err != nil { c.Fatal(err) } - cfg := Config{} + cfg := Config{ + IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"), + } h = Headscale{ cfg: cfg, diff --git a/cli_test.go b/cli_test.go index 9616b4a2..528a115e 100644 --- a/cli_test.go +++ b/cli_test.go @@ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) { DiscoKey: "faa", Name: "testmachine", NamespaceID: n.ID, + IPAddress: "10.0.0.1", } h.db.Save(&m) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 5e47d157..1c259c74 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -14,6 +14,7 @@ import ( "github.com/juanfont/headscale" "github.com/spf13/viper" "gopkg.in/yaml.v2" + "inet.af/netaddr" "tailscale.com/tailcfg" ) @@ -36,6 +37,8 @@ func LoadConfig(path string) error { viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") + viper.SetDefault("ip_prefix", "100.64.0.0/10") + err := viper.ReadInConfig() if err != nil { return fmt.Errorf("Fatal error reading config file: %s \n", err) @@ -97,6 +100,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { Addr: viper.GetString("listen_addr"), PrivateKeyPath: absPath(viper.GetString("private_key_path")), DerpMap: derpMap, + IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")), EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"), diff --git a/utils.go b/utils.go index f21063b0..03dc673f 100644 --- a/utils.go +++ b/utils.go @@ -7,18 +7,12 @@ package headscale import ( "crypto/rand" - "encoding/binary" "encoding/json" - "errors" "fmt" "io" - "net" - "time" - - mathrand "math/rand" "golang.org/x/crypto/nacl/box" - "gorm.io/gorm" + "inet.af/netaddr" "tailscale.com/types/wgkey" ) @@ -77,47 +71,71 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err return msg, nil } -func (h *Headscale) getAvailableIP() (*net.IP, error) { - i := 0 +func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { + ipPrefix := h.cfg.IPPrefix + + usedIps, err := h.getUsedIPs() + if err != nil { + return nil, err + } + + // Get the first IP in our prefix + ip := ipPrefix.IP() + for { - ip, err := getRandomIP() - if err != nil { - return nil, err + if !ipPrefix.Contains(ip) { + return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix) } - m := Machine{} - if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) { - return ip, nil + + // Some OS (including Linux) does not like when IPs ends with 0 or 255, which + // is typically called network or broadcast. Lets avoid them and continue + // to look when we get one of those traditionally reserved IPs. + ipRaw := ip.As4() + if ipRaw[3] == 0 || ipRaw[3] == 255 { + ip = ip.Next() + continue } - i++ - if i == 100 { // really random number - break + + if ip.IsZero() && + ip.IsLoopback() { + + ip = ip.Next() + continue } + + if !containsIPs(usedIps, ip) { + return &ip, nil + } + + ip = ip.Next() } - return nil, errors.New("Could not find an available IP address in 100.64.0.0/10") } -func getRandomIP() (*net.IP, error) { - mathrand.Seed(time.Now().Unix()) - ipo, ipnet, err := net.ParseCIDR("100.64.0.0/10") - if err == nil { - ip := ipo.To4() - // fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet) - // fmt.Println("Final address is ", ip) - // fmt.Println("Broadcast address is ", ipb) - // fmt.Println("Network address is ", ipn) - r := mathrand.Uint32() - ipRaw := make([]byte, 4) - binary.LittleEndian.PutUint32(ipRaw, r) - // ipRaw[3] = 254 - // fmt.Println("ipRaw is ", ipRaw) - for i, v := range ipRaw { - // fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i]) - ip[i] = ip[i] + (v &^ ipnet.Mask[i]) - // fmt.Println("IP After: ", ip[i]) +func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { + var addresses []string + h.db.Model(&Machine{}).Pluck("ip_address", &addresses) + + ips := make([]netaddr.IP, len(addresses)) + for index, addr := range addresses { + if addr != "" { + ip, err := netaddr.ParseIP(addr) + if err != nil { + return nil, fmt.Errorf("failed to parse ip from database, %w", err) + } + + ips[index] = ip } - // fmt.Println("FINAL IP: ", ip.String()) - return &ip, nil } - return nil, err + return ips, nil +} + +func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool { + for _, v := range ips { + if v == ip { + return true + } + } + + return false } diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 00000000..f50cd117 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,155 @@ +package headscale + +import ( + "gopkg.in/check.v1" + "inet.af/netaddr" +) + +func (s *Suite) TestGetAvailableIp(c *check.C) { + ip, err := h.getAvailableIP() + + c.Assert(err, check.IsNil) + + expected := netaddr.MustParseIP("10.27.0.1") + + c.Assert(ip.String(), check.Equals, expected.String()) +} + +func (s *Suite) TestGetUsedIps(c *check.C) { + ip, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + n, err := h.CreateNamespace("test_ip") + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = h.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + m := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testmachine", + NamespaceID: n.ID, + Registered: true, + RegisterMethod: "authKey", + AuthKeyID: uint(pak.ID), + IPAddress: ip.String(), + } + h.db.Save(&m) + + ips, err := h.getUsedIPs() + + c.Assert(err, check.IsNil) + + expected := netaddr.MustParseIP("10.27.0.1") + + c.Assert(ips[0], check.Equals, expected) + + m1, err := h.GetMachineByID(0) + c.Assert(err, check.IsNil) + + c.Assert(m1.IPAddress, check.Equals, expected.String()) +} + +func (s *Suite) TestGetMultiIp(c *check.C) { + n, err := h.CreateNamespace("test-ip-multi") + c.Assert(err, check.IsNil) + + for i := 1; i <= 350; i++ { + ip, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = h.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + m := Machine{ + ID: uint64(i), + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testmachine", + NamespaceID: n.ID, + Registered: true, + RegisterMethod: "authKey", + AuthKeyID: uint(pak.ID), + IPAddress: ip.String(), + } + h.db.Save(&m) + } + + ips, err := h.getUsedIPs() + + c.Assert(err, check.IsNil) + + c.Assert(len(ips), check.Equals, 350) + + c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.1")) + c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.10")) + c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47")) + + // Check that we can read back the IPs + m1, err := h.GetMachineByID(1) + c.Assert(err, check.IsNil) + c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String()) + + m50, err := h.GetMachineByID(50) + c.Assert(err, check.IsNil) + c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String()) + + expectedNextIP := netaddr.MustParseIP("10.27.1.97") + nextIP, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + c.Assert(nextIP.String(), check.Equals, expectedNextIP.String()) + + // If we call get Available again, we should receive + // the same IP, as it has not been reserved. + nextIP2, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String()) +} + +func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { + ip, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + expected := netaddr.MustParseIP("10.27.0.1") + + c.Assert(ip.String(), check.Equals, expected.String()) + + n, err := h.CreateNamespace("test_ip") + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = h.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + m := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testmachine", + NamespaceID: n.ID, + Registered: true, + RegisterMethod: "authKey", + AuthKeyID: uint(pak.ID), + } + h.db.Save(&m) + + ip2, err := h.getAvailableIP() + c.Assert(err, check.IsNil) + + c.Assert(ip2.String(), check.Equals, expected.String()) +}