Rework getAvailableIp

This commit reworks getAvailableIp with a "simpler" version that will
look for the first available IP address in our IP Prefix.

There is a couple of ideas behind this:

* Make the host IPs reasonably predictable and in within similar
  subnets, which should simplify ACLs for subnets
* The code is not random, but deterministic so we can have tests
* The code is a bit more understandable (no bit shift magic)
This commit is contained in:
Kristoffer Dalby 2021-08-02 21:57:45 +01:00
parent 309f868a21
commit b5841c8a8b
4 changed files with 170 additions and 45 deletions

View File

@ -38,7 +38,7 @@ func (s *Suite) ResetDB(c *check.C) {
c.Fatal(err) c.Fatal(err)
} }
cfg := Config{ cfg := Config{
IPPrefix: netaddr.MustParseIPPrefix("127.0.0.1/32"), IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
} }
h = Headscale{ h = Headscale{

View File

@ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: n.ID,
IPAddress: "10.0.0.1",
} }
h.db.Save(&m) h.db.Save(&m)

103
utils.go
View File

@ -7,18 +7,11 @@ package headscale
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"time"
mathrand "math/rand"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"gorm.io/gorm"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
@ -78,47 +71,73 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err
return msg, nil return msg, nil
} }
func (h *Headscale) getAvailableIP() (*net.IP, error) { func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
i := 0 ipPrefix := h.cfg.IPPrefix
for {
ip, err := getRandomIP(h.cfg.IPPrefix) usedIps, err := h.getUsedIPs()
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := Machine{}
if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
return ip, nil
}
i++
if i == 100 { // really random number
break
}
}
return nil, errors.New(fmt.Sprintf("Could not find an available IP address in %s", h.cfg.IPPrefix.String()))
}
func getRandomIP(ipPrefix netaddr.IPPrefix) (*net.IP, error) { // for _, ip := range usedIps {
mathrand.Seed(time.Now().Unix()) // nextIP := ip.Next()
ipo, ipnet, err := net.ParseCIDR(ipPrefix.String())
if err == nil { // if !containsIPs(usedIps, nextIP) && ipPrefix.Contains(nextIP) {
ip := ipo.To4() // return &nextIP, nil
// fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet) // }
// fmt.Println("Final address is ", ip) // }
// fmt.Println("Broadcast address is ", ipb)
// fmt.Println("Network address is ", ipn) // // If there are no IPs in use, we are starting fresh and
r := mathrand.Uint32() // // can issue IPs from the beginning of the prefix.
ipRaw := make([]byte, 4) // ip := ipPrefix.IP()
binary.LittleEndian.PutUint32(ipRaw, r) // return &ip, nil
// ipRaw[3] = 254
// fmt.Println("ipRaw is ", ipRaw) // return nil, fmt.Errorf("failed to find any available IP in %s", ipPrefix)
for i, v := range ipRaw {
// fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i]) // Get the first IP in our prefix
ip[i] = ip[i] + (v &^ ipnet.Mask[i]) ip := ipPrefix.IP()
// fmt.Println("IP After: ", ip[i])
for {
if !ipPrefix.Contains(ip) {
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
} }
// fmt.Println("FINAL IP: ", ip.String())
if ip.IsZero() &&
ip.IsLoopback() {
continue
}
if !containsIPs(usedIps, ip) {
return &ip, nil return &ip, nil
} }
return nil, err ip = ip.Next()
}
}
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
var addresses []string
h.db.Model(&Machine{}).Pluck("ip_address", &addresses)
ips := make([]netaddr.IP, len(addresses))
for index, addr := range addresses {
ip, err := netaddr.ParseIP(addr)
if err != nil {
return nil, fmt.Errorf("failed to parse ip from database, %w", err)
}
ips[index] = ip
}
return ips, nil
}
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
}
}
return false
} }

105
utils_test.go Normal file
View File

@ -0,0 +1,105 @@
package headscale
import (
"gopkg.in/check.v1"
"inet.af/netaddr"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.0")
c.Assert(ip.String(), check.Equals, expected.String())
}
func (s *Suite) TestGetUsedIps(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
n, err := h.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
ips, err := h.getUsedIPs()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.0")
c.Assert(ips[0], check.Equals, expected)
}
func (s *Suite) TestGetMultiIp(c *check.C) {
n, err := h.CreateNamespace("test-ip-multi")
c.Assert(err, check.IsNil)
for i := 1; i <= 350; i++ {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
}
ips, err := h.getUsedIPs()
c.Assert(err, check.IsNil)
c.Assert(len(ips), check.Equals, 350)
c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.0"))
c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.9"))
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.44"))
expectedNextIP := netaddr.MustParseIP("10.27.1.94")
nextIP, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
}