mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-04 02:05:46 +00:00 
			
		
		
		
	Add support for multiple IP prefixes
This commit is contained in:
		
							
								
								
									
										6
									
								
								acls.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								acls.go
									
									
									
									
									
								
							@@ -185,7 +185,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
 | 
			
		||||
				return nil, errInvalidNamespace
 | 
			
		||||
			}
 | 
			
		||||
			for _, node := range nodes {
 | 
			
		||||
				ips = append(ips, node.IPAddress)
 | 
			
		||||
				ips = append(ips, node.IPAddresses.ToStringSlice()...)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -219,7 +219,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
 | 
			
		||||
				// FIXME: Check TagOwners allows this
 | 
			
		||||
				for _, t := range hostinfo.RequestTags {
 | 
			
		||||
					if alias[4:] == t {
 | 
			
		||||
						ips = append(ips, machine.IPAddress)
 | 
			
		||||
						ips = append(ips, machine.IPAddresses.ToStringSlice()...)
 | 
			
		||||
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
@@ -238,7 +238,7 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
 | 
			
		||||
		}
 | 
			
		||||
		ips := []string{}
 | 
			
		||||
		for _, n := range nodes {
 | 
			
		||||
			ips = append(ips, n.IPAddress)
 | 
			
		||||
			ips = append(ips, n.IPAddresses.ToStringSlice()...)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return ips, nil
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										50
									
								
								acls_test.go
									
									
									
									
									
								
							
							
						
						
									
										50
									
								
								acls_test.go
									
									
									
									
									
								
							@@ -61,9 +61,9 @@ func (s *Suite) TestPortRange(c *check.C) {
 | 
			
		||||
	c.Assert(rules, check.NotNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(rules, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].DstPorts, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(5400))
 | 
			
		||||
	c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(5500))
 | 
			
		||||
	c.Assert(rules[0].DstPorts, check.HasLen, 1)
 | 
			
		||||
	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(5400))
 | 
			
		||||
	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(5500))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestPortWildcard(c *check.C) {
 | 
			
		||||
@@ -75,11 +75,11 @@ func (s *Suite) TestPortWildcard(c *check.C) {
 | 
			
		||||
	c.Assert(rules, check.NotNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(rules, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].DstPorts, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
 | 
			
		||||
	c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
 | 
			
		||||
	c.Assert((rules)[0].SrcIPs, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].SrcIPs[0], check.Equals, "*")
 | 
			
		||||
	c.Assert(rules[0].DstPorts, check.HasLen, 1)
 | 
			
		||||
	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
 | 
			
		||||
	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
 | 
			
		||||
	c.Assert(rules[0].SrcIPs, check.HasLen, 1)
 | 
			
		||||
	c.Assert(rules[0].SrcIPs[0], check.Equals, "*")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestPortNamespace(c *check.C) {
 | 
			
		||||
@@ -91,7 +91,7 @@ func (s *Suite) TestPortNamespace(c *check.C) {
 | 
			
		||||
 | 
			
		||||
	_, err = app.GetMachine("testnamespace", "testmachine")
 | 
			
		||||
	c.Assert(err, check.NotNil)
 | 
			
		||||
	ip, _ := app.getAvailableIP()
 | 
			
		||||
	ips, _ := app.getAvailableIPs()
 | 
			
		||||
	machine := Machine{
 | 
			
		||||
		ID:             0,
 | 
			
		||||
		MachineKey:     "foo",
 | 
			
		||||
@@ -101,7 +101,7 @@ func (s *Suite) TestPortNamespace(c *check.C) {
 | 
			
		||||
		NamespaceID:    namespace.ID,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      ip.String(),
 | 
			
		||||
		IPAddresses:    ips,
 | 
			
		||||
		AuthKeyID:      uint(pak.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(&machine)
 | 
			
		||||
@@ -116,12 +116,13 @@ func (s *Suite) TestPortNamespace(c *check.C) {
 | 
			
		||||
	c.Assert(rules, check.NotNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(rules, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].DstPorts, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
 | 
			
		||||
	c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
 | 
			
		||||
	c.Assert((rules)[0].SrcIPs, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
 | 
			
		||||
	c.Assert((rules)[0].SrcIPs[0], check.Equals, ip.String())
 | 
			
		||||
	c.Assert(rules[0].DstPorts, check.HasLen, 1)
 | 
			
		||||
	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
 | 
			
		||||
	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
 | 
			
		||||
	c.Assert(rules[0].SrcIPs, check.HasLen, 1)
 | 
			
		||||
	c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
 | 
			
		||||
	c.Assert(len(ips), check.Equals, 1)
 | 
			
		||||
	c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestPortGroup(c *check.C) {
 | 
			
		||||
@@ -133,7 +134,7 @@ func (s *Suite) TestPortGroup(c *check.C) {
 | 
			
		||||
 | 
			
		||||
	_, err = app.GetMachine("testnamespace", "testmachine")
 | 
			
		||||
	c.Assert(err, check.NotNil)
 | 
			
		||||
	ip, _ := app.getAvailableIP()
 | 
			
		||||
	ips, _ := app.getAvailableIPs()
 | 
			
		||||
	machine := Machine{
 | 
			
		||||
		ID:             0,
 | 
			
		||||
		MachineKey:     "foo",
 | 
			
		||||
@@ -143,7 +144,7 @@ func (s *Suite) TestPortGroup(c *check.C) {
 | 
			
		||||
		NamespaceID:    namespace.ID,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      ip.String(),
 | 
			
		||||
		IPAddresses:    ips,
 | 
			
		||||
		AuthKeyID:      uint(pak.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(&machine)
 | 
			
		||||
@@ -156,10 +157,11 @@ func (s *Suite) TestPortGroup(c *check.C) {
 | 
			
		||||
	c.Assert(rules, check.NotNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(rules, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].DstPorts, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
 | 
			
		||||
	c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
 | 
			
		||||
	c.Assert((rules)[0].SrcIPs, check.HasLen, 1)
 | 
			
		||||
	c.Assert((rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
 | 
			
		||||
	c.Assert((rules)[0].SrcIPs[0], check.Equals, ip.String())
 | 
			
		||||
	c.Assert(rules[0].DstPorts, check.HasLen, 1)
 | 
			
		||||
	c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
 | 
			
		||||
	c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
 | 
			
		||||
	c.Assert(rules[0].SrcIPs, check.HasLen, 1)
 | 
			
		||||
	c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
 | 
			
		||||
	c.Assert(len(ips), check.Equals, 1)
 | 
			
		||||
	c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										15
									
								
								api.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								api.go
									
									
									
									
									
								
							@@ -497,6 +497,7 @@ func (h *Headscale) handleMachineRegistrationNew(
 | 
			
		||||
	ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: check if any locks are needed around IP allocation.
 | 
			
		||||
func (h *Headscale) handleAuthKey(
 | 
			
		||||
	ctx *gin.Context,
 | 
			
		||||
	machineKey key.MachinePublic,
 | 
			
		||||
@@ -554,14 +555,14 @@ func (h *Headscale) handleAuthKey(
 | 
			
		||||
		log.Debug().
 | 
			
		||||
			Str("func", "handleAuthKey").
 | 
			
		||||
			Str("machine", machine.Name).
 | 
			
		||||
			Msg("Authentication key was valid, proceeding to acquire an IP address")
 | 
			
		||||
		ip, err := h.getAvailableIP()
 | 
			
		||||
			Msg("Authentication key was valid, proceeding to acquire IP addresses")
 | 
			
		||||
		ips, err := h.getAvailableIPs()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Error().
 | 
			
		||||
				Caller().
 | 
			
		||||
				Str("func", "handleAuthKey").
 | 
			
		||||
				Str("machine", machine.Name).
 | 
			
		||||
				Msg("Failed to find an available IP")
 | 
			
		||||
				Msg("Failed to find an available IP address")
 | 
			
		||||
			machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
 | 
			
		||||
				Inc()
 | 
			
		||||
 | 
			
		||||
@@ -570,12 +571,12 @@ func (h *Headscale) handleAuthKey(
 | 
			
		||||
		log.Info().
 | 
			
		||||
			Str("func", "handleAuthKey").
 | 
			
		||||
			Str("machine", machine.Name).
 | 
			
		||||
			Str("ip", ip.String()).
 | 
			
		||||
			Msgf("Assigning %s to %s", ip, machine.Name)
 | 
			
		||||
			Str("ips", strings.Join(ips.ToStringSlice(), ",")).
 | 
			
		||||
			Msgf("Assigning %s to %s", strings.Join(ips.ToStringSlice(), ","), machine.Name)
 | 
			
		||||
 | 
			
		||||
		machine.Expiry = ®isterRequest.Expiry
 | 
			
		||||
		machine.AuthKeyID = uint(pak.ID)
 | 
			
		||||
		machine.IPAddress = ip.String()
 | 
			
		||||
		machine.IPAddresses = ips
 | 
			
		||||
		machine.NamespaceID = pak.NamespaceID
 | 
			
		||||
 | 
			
		||||
		machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
 | 
			
		||||
@@ -610,6 +611,6 @@ func (h *Headscale) handleAuthKey(
 | 
			
		||||
	log.Info().
 | 
			
		||||
		Str("func", "handleAuthKey").
 | 
			
		||||
		Str("machine", machine.Name).
 | 
			
		||||
		Str("ip", machine.IPAddress).
 | 
			
		||||
		Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
 | 
			
		||||
		Msg("Successfully authenticated via AuthKey")
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								app.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								app.go
									
									
									
									
									
								
							@@ -68,7 +68,7 @@ type Config struct {
 | 
			
		||||
	ServerURL                      string
 | 
			
		||||
	Addr                           string
 | 
			
		||||
	EphemeralNodeInactivityTimeout time.Duration
 | 
			
		||||
	IPPrefix                       netaddr.IPPrefix
 | 
			
		||||
	IPPrefixes                     []netaddr.IPPrefix
 | 
			
		||||
	PrivateKeyPath                 string
 | 
			
		||||
	BaseDomain                     string
 | 
			
		||||
 | 
			
		||||
@@ -197,9 +197,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
 | 
			
		||||
		magicDNSDomains := generateMagicDNSRootDomains(
 | 
			
		||||
			app.cfg.IPPrefix,
 | 
			
		||||
		)
 | 
			
		||||
		magicDNSDomains := generateMagicDNSRootDomains(app.cfg.IPPrefixes)
 | 
			
		||||
		// we might have routes already from Split DNS
 | 
			
		||||
		if app.cfg.DNSConfig.Routes == nil {
 | 
			
		||||
			app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
 | 
			
		||||
 
 | 
			
		||||
@@ -41,7 +41,9 @@ func (s *Suite) ResetDB(c *check.C) {
 | 
			
		||||
		c.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	cfg := Config{
 | 
			
		||||
		IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
 | 
			
		||||
		IPPrefixes: []netaddr.IPPrefix{
 | 
			
		||||
			netaddr.MustParseIPPrefix("10.27.0.0/23"),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	app = Headscale{
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										10
									
								
								cli_test.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								cli_test.go
									
									
									
									
									
								
							@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gopkg.in/check.v1"
 | 
			
		||||
	"inet.af/netaddr"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestRegisterMachine(c *check.C) {
 | 
			
		||||
@@ -19,16 +20,17 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
 | 
			
		||||
		DiscoKey:    "faa",
 | 
			
		||||
		Name:        "testmachine",
 | 
			
		||||
		NamespaceID: namespace.ID,
 | 
			
		||||
		IPAddress:   "10.0.0.1",
 | 
			
		||||
		IPAddresses: []netaddr.IP{netaddr.MustParseIP("10.0.0.1")},
 | 
			
		||||
		Expiry:      &now,
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(&machine)
 | 
			
		||||
	err = app.db.Save(&machine).Error
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	_, err = app.GetMachine("test", "testmachine")
 | 
			
		||||
	_, err = app.GetMachine(namespace.Name, machine.Name)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	machineAfterRegistering, err := app.RegisterMachine(
 | 
			
		||||
		"8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
 | 
			
		||||
		machine.MachineKey,
 | 
			
		||||
		namespace.Name,
 | 
			
		||||
	)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	survey "github.com/AlecAivazis/survey/v2"
 | 
			
		||||
@@ -459,7 +460,7 @@ func nodesToPtables(
 | 
			
		||||
			"Name",
 | 
			
		||||
			"NodeKey",
 | 
			
		||||
			"Namespace",
 | 
			
		||||
			"IP address",
 | 
			
		||||
			"IP addresses",
 | 
			
		||||
			"Ephemeral",
 | 
			
		||||
			"Last seen",
 | 
			
		||||
			"Online",
 | 
			
		||||
@@ -523,7 +524,7 @@ func nodesToPtables(
 | 
			
		||||
				machine.Name,
 | 
			
		||||
				nodeKey.ShortString(),
 | 
			
		||||
				namespace,
 | 
			
		||||
				machine.IpAddress,
 | 
			
		||||
				strings.Join(machine.IpAddresses, ", "),
 | 
			
		||||
				strconv.FormatBool(ephemeral),
 | 
			
		||||
				lastSeenTime,
 | 
			
		||||
				online,
 | 
			
		||||
 
 | 
			
		||||
@@ -41,7 +41,7 @@ func LoadConfig(path string) error {
 | 
			
		||||
	viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
 | 
			
		||||
	viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
 | 
			
		||||
 | 
			
		||||
	viper.SetDefault("ip_prefix", "100.64.0.0/10")
 | 
			
		||||
	viper.SetDefault("ip_prefixes", []string{"100.64.0.0/10"})
 | 
			
		||||
 | 
			
		||||
	viper.SetDefault("log_level", "info")
 | 
			
		||||
 | 
			
		||||
@@ -221,10 +221,20 @@ func getHeadscaleConfig() headscale.Config {
 | 
			
		||||
	dnsConfig, baseDomain := GetDNSConfig()
 | 
			
		||||
	derpConfig := GetDERPConfig()
 | 
			
		||||
 | 
			
		||||
	configuredPrefixes := viper.GetStringSlice("ip_prefixes")
 | 
			
		||||
	prefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes))
 | 
			
		||||
	for i, prefixInConfig := range configuredPrefixes {
 | 
			
		||||
		prefix, err := netaddr.ParseIPPrefix(prefixInConfig)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
 | 
			
		||||
		}
 | 
			
		||||
		prefixes = append(prefixes, prefix)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return headscale.Config{
 | 
			
		||||
		ServerURL:      viper.GetString("server_url"),
 | 
			
		||||
		Addr:           viper.GetString("listen_addr"),
 | 
			
		||||
		IPPrefix:       netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
 | 
			
		||||
		IPPrefixes:     prefixes,
 | 
			
		||||
		PrivateKeyPath: absPath(viper.GetString("private_key_path")),
 | 
			
		||||
		BaseDomain:     baseDomain,
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										27
									
								
								dns.go
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								dns.go
									
									
									
									
									
								
							@@ -34,14 +34,25 @@ const (
 | 
			
		||||
 | 
			
		||||
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
 | 
			
		||||
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
 | 
			
		||||
func generateMagicDNSRootDomains(
 | 
			
		||||
	ipPrefix netaddr.IPPrefix,
 | 
			
		||||
) []dnsname.FQDN {
 | 
			
		||||
	// TODO(juanfont): we are not handing out IPv6 addresses yet
 | 
			
		||||
	// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
 | 
			
		||||
	ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
 | 
			
		||||
	fqdns := []dnsname.FQDN{ipv6base}
 | 
			
		||||
func generateMagicDNSRootDomains(ipPrefixes []netaddr.IPPrefix) []dnsname.FQDN {
 | 
			
		||||
	fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes))
 | 
			
		||||
	for _, ipPrefix := range ipPrefixes {
 | 
			
		||||
		var generateDnsRoot func(netaddr.IPPrefix) []dnsname.FQDN
 | 
			
		||||
		switch ipPrefix.IP().BitLen() {
 | 
			
		||||
		case 32:
 | 
			
		||||
			generateDnsRoot = generateIPv4DNSRootDomain
 | 
			
		||||
 | 
			
		||||
		default:
 | 
			
		||||
			panic(fmt.Sprintf("unsupported IP version with address length %d", ipPrefix.IP().BitLen()))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		fqdns = append(fqdns, generateDnsRoot(ipPrefix)...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fqdns
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func generateIPv4DNSRootDomain(ipPrefix netaddr.IPPrefix) (fqdns []dnsname.FQDN) {
 | 
			
		||||
	// Conversion to the std lib net.IPnet, a bit easier to operate
 | 
			
		||||
	netRange := ipPrefix.IPNet()
 | 
			
		||||
	maskBits, _ := netRange.Mask.Size()
 | 
			
		||||
@@ -73,7 +84,7 @@ func generateMagicDNSRootDomains(
 | 
			
		||||
		fqdns = append(fqdns, fqdn)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fqdns
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getMapResponseDNSConfig(
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										16
									
								
								dns_test.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								dns_test.go
									
									
									
									
									
								
							@@ -124,7 +124,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared1,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.1",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyInShared1.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared1)
 | 
			
		||||
@@ -142,7 +142,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared2,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.2",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyInShared2.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared2)
 | 
			
		||||
@@ -160,7 +160,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared3,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.3",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyInShared3.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared3)
 | 
			
		||||
@@ -178,7 +178,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared1,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.4",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
 | 
			
		||||
		AuthKeyID:      uint(PreAuthKey2InShared1.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machine2InShared1)
 | 
			
		||||
@@ -273,7 +273,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared1,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.1",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyInShared1.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared1)
 | 
			
		||||
@@ -291,7 +291,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared2,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.2",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyInShared2.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared2)
 | 
			
		||||
@@ -309,7 +309,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared3,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.3",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyInShared3.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared3)
 | 
			
		||||
@@ -327,7 +327,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared1,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.4",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKey2InShared1.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machine2InShared1)
 | 
			
		||||
 
 | 
			
		||||
@@ -372,70 +372,74 @@ func (s *IntegrationTestSuite) TestListNodes() {
 | 
			
		||||
 | 
			
		||||
func (s *IntegrationTestSuite) TestGetIpAddresses() {
 | 
			
		||||
	for _, scales := range s.namespaces {
 | 
			
		||||
		ipPrefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
 | 
			
		||||
		ips, err := getIPs(scales.tailscales)
 | 
			
		||||
		assert.Nil(s.T(), err)
 | 
			
		||||
 | 
			
		||||
		for hostname := range scales.tailscales {
 | 
			
		||||
			s.T().Run(hostname, func(t *testing.T) {
 | 
			
		||||
				ip, ok := ips[hostname]
 | 
			
		||||
		for hostname, _ := range scales.tailscales {
 | 
			
		||||
			ips := ips[hostname]
 | 
			
		||||
			for _, ip := range ips {
 | 
			
		||||
				s.T().Run(hostname, func(t *testing.T) {
 | 
			
		||||
					assert.NotNil(t, ip)
 | 
			
		||||
 | 
			
		||||
				assert.True(t, ok)
 | 
			
		||||
				assert.NotNil(t, ip)
 | 
			
		||||
					fmt.Printf("IP for %s: %s\n", hostname, ip)
 | 
			
		||||
 | 
			
		||||
				fmt.Printf("IP for %s: %s\n", hostname, ip)
 | 
			
		||||
 | 
			
		||||
				// c.Assert(ip.Valid(), check.IsTrue)
 | 
			
		||||
				assert.True(t, ip.Is4())
 | 
			
		||||
				assert.True(t, ipPrefix.Contains(ip))
 | 
			
		||||
			})
 | 
			
		||||
					// c.Assert(ip.Valid(), check.IsTrue)
 | 
			
		||||
					assert.True(t, ip.Is4() || ip.Is6())
 | 
			
		||||
					switch {
 | 
			
		||||
					case ip.Is4():
 | 
			
		||||
						assert.True(t, IpPrefix4.Contains(ip))
 | 
			
		||||
					case ip.Is6():
 | 
			
		||||
						assert.True(t, IpPrefix6.Contains(ip))
 | 
			
		||||
					}
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(kradalby): fix this test
 | 
			
		||||
// We need some way to impot ipnstate.Status from multiple go packages.
 | 
			
		||||
// We need some way to import ipnstate.Status from multiple go packages.
 | 
			
		||||
// Currently it will only work with 1.18.x since that is the last
 | 
			
		||||
// version we have in go.mod
 | 
			
		||||
// func (s *IntegrationTestSuite) TestStatus() {
 | 
			
		||||
// 	for _, scales := range s.namespaces {
 | 
			
		||||
// 		ips, err := getIPs(scales.tailscales)
 | 
			
		||||
// 		assert.Nil(s.T(), err)
 | 
			
		||||
//	for _, scales := range s.namespaces {
 | 
			
		||||
//		ips, err := getIPs(scales.tailscales)
 | 
			
		||||
//		assert.Nil(s.T(), err)
 | 
			
		||||
//
 | 
			
		||||
// 		for hostname, tailscale := range scales.tailscales {
 | 
			
		||||
// 			s.T().Run(hostname, func(t *testing.T) {
 | 
			
		||||
// 				command := []string{"tailscale", "status", "--json"}
 | 
			
		||||
//		for hostname, tailscale := range scales.tailscales {
 | 
			
		||||
//			s.T().Run(hostname, func(t *testing.T) {
 | 
			
		||||
//				command := []string{"tailscale", "status", "--json"}
 | 
			
		||||
//
 | 
			
		||||
// 				fmt.Printf("Getting status for %s\n", hostname)
 | 
			
		||||
// 				result, err := ExecuteCommand(
 | 
			
		||||
// 					&tailscale,
 | 
			
		||||
// 					command,
 | 
			
		||||
// 					[]string{},
 | 
			
		||||
// 				)
 | 
			
		||||
// 				assert.Nil(t, err)
 | 
			
		||||
//				fmt.Printf("Getting status for %s\n", hostname)
 | 
			
		||||
//				result, err := ExecuteCommand(
 | 
			
		||||
//					&tailscale,
 | 
			
		||||
//					command,
 | 
			
		||||
//					[]string{},
 | 
			
		||||
//				)
 | 
			
		||||
//				assert.Nil(t, err)
 | 
			
		||||
//
 | 
			
		||||
// 				var status ipnstate.Status
 | 
			
		||||
// 				err = json.Unmarshal([]byte(result), &status)
 | 
			
		||||
// 				assert.Nil(s.T(), err)
 | 
			
		||||
//				var status ipnstate.Status
 | 
			
		||||
//				err = json.Unmarshal([]byte(result), &status)
 | 
			
		||||
//				assert.Nil(s.T(), err)
 | 
			
		||||
//
 | 
			
		||||
// 				// TODO(kradalby): Replace this check with peer length of SAME namespace
 | 
			
		||||
// 				// Check if we have as many nodes in status
 | 
			
		||||
// 				// as we have IPs/tailscales
 | 
			
		||||
// 				// lines := strings.Split(result, "\n")
 | 
			
		||||
// 				// assert.Equal(t, len(ips), len(lines)-1)
 | 
			
		||||
// 				// assert.Equal(t, len(scales.tailscales), len(lines)-1)
 | 
			
		||||
//				// TODO(kradalby): Replace this check with peer length of SAME namespace
 | 
			
		||||
//				// Check if we have as many nodes in status
 | 
			
		||||
//				// as we have IPs/tailscales
 | 
			
		||||
//				// lines := strings.Split(result, "\n")
 | 
			
		||||
//				// assert.Equal(t, len(ips), len(lines)-1)
 | 
			
		||||
//				// assert.Equal(t, len(scales.tailscales), len(lines)-1)
 | 
			
		||||
//
 | 
			
		||||
// 				peerIps := getIPsfromIPNstate(status)
 | 
			
		||||
//				peerIps := getIPsfromIPNstate(status)
 | 
			
		||||
//
 | 
			
		||||
// 				// Check that all hosts is present in all hosts status
 | 
			
		||||
// 				for ipHostname, ip := range ips {
 | 
			
		||||
// 					if hostname != ipHostname {
 | 
			
		||||
// 						assert.Contains(t, peerIps, ip)
 | 
			
		||||
// 					}
 | 
			
		||||
// 				}
 | 
			
		||||
// 			})
 | 
			
		||||
// 		}
 | 
			
		||||
// 	}
 | 
			
		||||
//				// Check that all hosts is present in all hosts status
 | 
			
		||||
//				for ipHostname, ip := range ips {
 | 
			
		||||
//					if hostname != ipHostname {
 | 
			
		||||
//						assert.Contains(t, peerIps, ip)
 | 
			
		||||
//					}
 | 
			
		||||
//				}
 | 
			
		||||
//			})
 | 
			
		||||
//		}
 | 
			
		||||
//	}
 | 
			
		||||
// }
 | 
			
		||||
 | 
			
		||||
func getIPsfromIPNstate(status ipnstate.Status) []netaddr.IP {
 | 
			
		||||
@@ -448,16 +452,19 @@ func getIPsfromIPNstate(status ipnstate.Status) []netaddr.IP {
 | 
			
		||||
	return ips
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *IntegrationTestSuite) TestPingAllPeers() {
 | 
			
		||||
func (s *IntegrationTestSuite) TestPingAllPeersByAddress() {
 | 
			
		||||
	for _, scales := range s.namespaces {
 | 
			
		||||
		ips, err := getIPs(scales.tailscales)
 | 
			
		||||
		assert.Nil(s.T(), err)
 | 
			
		||||
 | 
			
		||||
		for hostname, tailscale := range scales.tailscales {
 | 
			
		||||
			for peername, ip := range ips {
 | 
			
		||||
				s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
 | 
			
		||||
			for peername, peerIPs := range ips {
 | 
			
		||||
				for i, ip := range peerIPs {
 | 
			
		||||
					// We currently cant ping ourselves, so skip that.
 | 
			
		||||
					if peername != hostname {
 | 
			
		||||
					if peername == hostname {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					s.T().Run(fmt.Sprintf("%s-%s-%d", hostname, peername, i), func(t *testing.T) {
 | 
			
		||||
						// We are only interested in "direct ping" which means what we
 | 
			
		||||
						// might need a couple of more attempts before reaching the node.
 | 
			
		||||
						command := []string{
 | 
			
		||||
@@ -469,9 +476,8 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						fmt.Printf(
 | 
			
		||||
							"Pinging from %s (%s) to %s (%s)\n",
 | 
			
		||||
							"Pinging from %s to %s (%s)\n",
 | 
			
		||||
							hostname,
 | 
			
		||||
							ips[hostname],
 | 
			
		||||
							peername,
 | 
			
		||||
							ip,
 | 
			
		||||
						)
 | 
			
		||||
@@ -483,8 +489,8 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
 | 
			
		||||
						assert.Nil(t, err)
 | 
			
		||||
						fmt.Printf("Result for %s: %s\n", hostname, result)
 | 
			
		||||
						assert.Contains(t, result, "pong")
 | 
			
		||||
					}
 | 
			
		||||
				})
 | 
			
		||||
					})
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -553,17 +559,17 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
 | 
			
		||||
	// TODO(juanfont): We have to find out why do we need to wait
 | 
			
		||||
	time.Sleep(100 * time.Second) // Wait for the nodes to receive updates
 | 
			
		||||
 | 
			
		||||
	mainIps, err := getIPs(main.tailscales)
 | 
			
		||||
	assert.Nil(s.T(), err)
 | 
			
		||||
 | 
			
		||||
	sharedIps, err := getIPs(shared.tailscales)
 | 
			
		||||
	assert.Nil(s.T(), err)
 | 
			
		||||
 | 
			
		||||
	for hostname, tailscale := range main.tailscales {
 | 
			
		||||
		for peername, ip := range sharedIps {
 | 
			
		||||
			s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
 | 
			
		||||
		for peername, peerIPs := range sharedIps {
 | 
			
		||||
			for i, ip := range peerIPs {
 | 
			
		||||
				// We currently cant ping ourselves, so skip that.
 | 
			
		||||
				if peername != hostname {
 | 
			
		||||
				if peername == hostname {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				s.T().Run(fmt.Sprintf("%s-%s-%d", hostname, peername, i), func(t *testing.T) {
 | 
			
		||||
					// We are only interested in "direct ping" which means what we
 | 
			
		||||
					// might need a couple of more attempts before reaching the node.
 | 
			
		||||
					command := []string{
 | 
			
		||||
@@ -575,9 +581,8 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					fmt.Printf(
 | 
			
		||||
						"Pinging from %s (%s) to %s (%s)\n",
 | 
			
		||||
						"Pinging from %s to %s (%s)\n",
 | 
			
		||||
						hostname,
 | 
			
		||||
						mainIps[hostname],
 | 
			
		||||
						peername,
 | 
			
		||||
						ip,
 | 
			
		||||
					)
 | 
			
		||||
@@ -589,8 +594,8 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
 | 
			
		||||
					assert.Nil(t, err)
 | 
			
		||||
					fmt.Printf("Result for %s: %s\n", hostname, result)
 | 
			
		||||
					assert.Contains(t, result, "pong")
 | 
			
		||||
				}
 | 
			
		||||
			})
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -607,7 +612,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
 | 
			
		||||
			_, err := ExecuteCommand(
 | 
			
		||||
				&tailscale,
 | 
			
		||||
				command,
 | 
			
		||||
				[]string{},
 | 
			
		||||
				[]string{"GOMAXPROCS=32"},
 | 
			
		||||
			)
 | 
			
		||||
			assert.Nil(s.T(), err)
 | 
			
		||||
			for peername, ip := range ips {
 | 
			
		||||
@@ -653,7 +658,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
 | 
			
		||||
							_, err = ExecuteCommand(
 | 
			
		||||
								&tailscale,
 | 
			
		||||
								command,
 | 
			
		||||
								[]string{"ALL_PROXY=socks5://localhost:1055"},
 | 
			
		||||
								[]string{"ALL_PROXY=socks5://localhost:1055", "GOMAXPROCS=32"},
 | 
			
		||||
							)
 | 
			
		||||
							if err == nil {
 | 
			
		||||
								break
 | 
			
		||||
@@ -684,78 +689,125 @@ func (s *IntegrationTestSuite) TestTailDrop() {
 | 
			
		||||
			)
 | 
			
		||||
			assert.Nil(s.T(), err)
 | 
			
		||||
			for peername, ip := range ips {
 | 
			
		||||
				if peername == hostname {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
 | 
			
		||||
					if peername != hostname {
 | 
			
		||||
						command := []string{
 | 
			
		||||
							"ls",
 | 
			
		||||
							fmt.Sprintf("/tmp/file_from_%s", peername),
 | 
			
		||||
						}
 | 
			
		||||
						fmt.Printf(
 | 
			
		||||
							"Checking file in %s (%s) from %s (%s)\n",
 | 
			
		||||
							hostname,
 | 
			
		||||
							ips[hostname],
 | 
			
		||||
							peername,
 | 
			
		||||
							ip,
 | 
			
		||||
						)
 | 
			
		||||
						result, err := ExecuteCommand(
 | 
			
		||||
							&tailscale,
 | 
			
		||||
							command,
 | 
			
		||||
							[]string{},
 | 
			
		||||
						)
 | 
			
		||||
						assert.Nil(t, err)
 | 
			
		||||
						fmt.Printf("Result for %s: %s\n", peername, result)
 | 
			
		||||
						assert.Equal(
 | 
			
		||||
							t,
 | 
			
		||||
							result,
 | 
			
		||||
							fmt.Sprintf("/tmp/file_from_%s\n", peername),
 | 
			
		||||
						)
 | 
			
		||||
					command := []string{
 | 
			
		||||
						"ls",
 | 
			
		||||
						fmt.Sprintf("/tmp/file_from_%s", peername),
 | 
			
		||||
					}
 | 
			
		||||
					fmt.Printf(
 | 
			
		||||
						"Checking file in %s (%s) from %s (%s)\n",
 | 
			
		||||
						hostname,
 | 
			
		||||
						ips[hostname],
 | 
			
		||||
						peername,
 | 
			
		||||
						ip,
 | 
			
		||||
					)
 | 
			
		||||
					result, err := ExecuteCommand(
 | 
			
		||||
						&tailscale,
 | 
			
		||||
						command,
 | 
			
		||||
						[]string{},
 | 
			
		||||
					)
 | 
			
		||||
					assert.Nil(t, err)
 | 
			
		||||
					fmt.Printf("Result for %s: %s\n", peername, result)
 | 
			
		||||
					assert.Equal(
 | 
			
		||||
						t,
 | 
			
		||||
						fmt.Sprintf("/tmp/file_from_%s\n", peername),
 | 
			
		||||
						result,
 | 
			
		||||
					)
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *IntegrationTestSuite) TestMagicDNS() {
 | 
			
		||||
func (s *IntegrationTestSuite) TestPingAllPeersByHostname() {
 | 
			
		||||
	for namespace, scales := range s.namespaces {
 | 
			
		||||
		ips, err := getIPs(scales.tailscales)
 | 
			
		||||
		assert.Nil(s.T(), err)
 | 
			
		||||
		for hostname, tailscale := range scales.tailscales {
 | 
			
		||||
			for peername, ip := range ips {
 | 
			
		||||
			for peername, _ := range ips {
 | 
			
		||||
				if peername == hostname {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
 | 
			
		||||
					if peername != hostname {
 | 
			
		||||
						command := []string{
 | 
			
		||||
							"tailscale", "ping",
 | 
			
		||||
							"--timeout=10s",
 | 
			
		||||
							"--c=20",
 | 
			
		||||
							"--until-direct=true",
 | 
			
		||||
							fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						fmt.Printf(
 | 
			
		||||
							"Pinging using Hostname (magicdns) from %s (%s) to %s (%s)\n",
 | 
			
		||||
							hostname,
 | 
			
		||||
							ips[hostname],
 | 
			
		||||
							peername,
 | 
			
		||||
							ip,
 | 
			
		||||
						)
 | 
			
		||||
						result, err := ExecuteCommand(
 | 
			
		||||
							&tailscale,
 | 
			
		||||
							command,
 | 
			
		||||
							[]string{},
 | 
			
		||||
						)
 | 
			
		||||
						assert.Nil(t, err)
 | 
			
		||||
						fmt.Printf("Result for %s: %s\n", hostname, result)
 | 
			
		||||
						assert.Contains(t, result, "pong")
 | 
			
		||||
					command := []string{
 | 
			
		||||
						"tailscale", "ping",
 | 
			
		||||
						"--timeout=10s",
 | 
			
		||||
						"--c=20",
 | 
			
		||||
						"--until-direct=true",
 | 
			
		||||
						fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					fmt.Printf(
 | 
			
		||||
						"Pinging using Hostname from %s to %s\n",
 | 
			
		||||
						hostname,
 | 
			
		||||
						peername,
 | 
			
		||||
					)
 | 
			
		||||
					result, err := ExecuteCommand(
 | 
			
		||||
						&tailscale,
 | 
			
		||||
						command,
 | 
			
		||||
						[]string{},
 | 
			
		||||
					)
 | 
			
		||||
					assert.Nil(t, err)
 | 
			
		||||
					fmt.Printf("Result for %s: %s\n", hostname, result)
 | 
			
		||||
					assert.Contains(t, result, "pong")
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, error) {
 | 
			
		||||
	ips := make(map[string]netaddr.IP)
 | 
			
		||||
// TODO:
 | 
			
		||||
// * With manual testing, MagicDNS does not respond to AAAA queries. Why?
 | 
			
		||||
// * Tailscaled only adds a route to the IPv4 (100.100.100.100) address of the MagicDNS service,
 | 
			
		||||
//   event though there is an IPv6 one (fd7a:115c:a1e0::53) as well.
 | 
			
		||||
func (s *IntegrationTestSuite) TestMagicDNSv4() {
 | 
			
		||||
	for namespace, scales := range s.namespaces {
 | 
			
		||||
		ips, err := getIPs(scales.tailscales)
 | 
			
		||||
		assert.Nil(s.T(), err)
 | 
			
		||||
		for hostname, tailscale := range scales.tailscales {
 | 
			
		||||
			for peername, ips := range ips {
 | 
			
		||||
				if peername == hostname {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				s.T().Run(fmt.Sprintf("%s-%s-ipv4", hostname, peername), func(t *testing.T) {
 | 
			
		||||
					command := []string{
 | 
			
		||||
						"host", "-4", "-t", "A",
 | 
			
		||||
						fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
 | 
			
		||||
						"100.100.100.100",
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					fmt.Printf(
 | 
			
		||||
						"Resolving name %s (IPv4) from %s over IPv4\n",
 | 
			
		||||
						peername,
 | 
			
		||||
						hostname,
 | 
			
		||||
					)
 | 
			
		||||
					result, err := ExecuteCommand(
 | 
			
		||||
						&tailscale,
 | 
			
		||||
						command,
 | 
			
		||||
						[]string{},
 | 
			
		||||
					)
 | 
			
		||||
					assert.Nil(t, err)
 | 
			
		||||
					fmt.Printf("Result for %s: %s\n", hostname, result)
 | 
			
		||||
 | 
			
		||||
					resolved := false
 | 
			
		||||
					for _, ip := range ips {
 | 
			
		||||
						if strings.Contains(result, fmt.Sprintf("has address %s", ip.String())) {
 | 
			
		||||
							resolved = true
 | 
			
		||||
							break
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					assert.Equal(t, true, resolved)
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getIPs(tailscales map[string]dockertest.Resource) (map[string][]netaddr.IP, error) {
 | 
			
		||||
	ips := make(map[string][]netaddr.IP)
 | 
			
		||||
	for hostname, tailscale := range tailscales {
 | 
			
		||||
		command := []string{"tailscale", "ip"}
 | 
			
		||||
 | 
			
		||||
@@ -768,12 +820,17 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ip, err := netaddr.ParseIP(strings.TrimSuffix(result, "\n"))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		for _, address := range strings.Split(result, "\n") {
 | 
			
		||||
			address = strings.TrimSuffix(address, "\n")
 | 
			
		||||
			if len(address) < 1 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			ip, err := netaddr.ParseIP(address)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			ips[hostname] = append(ips[hostname], ip)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ips[hostname] = ip
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ips, nil
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										81
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										81
									
								
								machine.go
									
									
									
									
									
								
							@@ -1,6 +1,7 @@
 | 
			
		||||
package headscale
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -23,6 +24,7 @@ const (
 | 
			
		||||
	errMachineNotFound            = Error("machine not found")
 | 
			
		||||
	errMachineAlreadyRegistered   = Error("machine already registered")
 | 
			
		||||
	errMachineRouteIsNotAvailable = Error("route is not available on machine")
 | 
			
		||||
	errMachineAddressesInvalid    = Error("failed to parse machine addresses")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Machine is a Headscale client.
 | 
			
		||||
@@ -31,7 +33,7 @@ type Machine struct {
 | 
			
		||||
	MachineKey  string `gorm:"type:varchar(64);unique_index"`
 | 
			
		||||
	NodeKey     string
 | 
			
		||||
	DiscoKey    string
 | 
			
		||||
	IPAddress   string
 | 
			
		||||
	IPAddresses MachineAddresses
 | 
			
		||||
	Name        string
 | 
			
		||||
	NamespaceID uint
 | 
			
		||||
	Namespace   Namespace `gorm:"foreignKey:NamespaceID"`
 | 
			
		||||
@@ -64,6 +66,47 @@ func (machine Machine) isRegistered() bool {
 | 
			
		||||
	return machine.Registered
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MachineAddresses []netaddr.IP
 | 
			
		||||
 | 
			
		||||
func (ma MachineAddresses) ToStringSlice() []string {
 | 
			
		||||
	strSlice := make([]string, 0, len(ma))
 | 
			
		||||
	for _, addr := range ma {
 | 
			
		||||
		strSlice = append(strSlice, addr.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return strSlice
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ma *MachineAddresses) Scan(destination interface{}) error {
 | 
			
		||||
	switch value := destination.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		addresses := strings.Split(value, ",")
 | 
			
		||||
		*ma = (*ma)[:0]
 | 
			
		||||
		for _, addr := range addresses {
 | 
			
		||||
			if len(addr) < 1 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			parsed, err := netaddr.ParseIP(addr)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			*ma = append(*ma, parsed)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return fmt.Errorf("%w: unexpected data type %T", errMachineAddressesInvalid, destination)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Value return json value, implement driver.Valuer interface.
 | 
			
		||||
func (ma MachineAddresses) Value() (driver.Value, error) {
 | 
			
		||||
	addresses := strings.Join(ma.ToStringSlice(), ",")
 | 
			
		||||
 | 
			
		||||
	return addresses, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// isExpired returns whether the machine registration has expired.
 | 
			
		||||
func (machine Machine) isExpired() bool {
 | 
			
		||||
	// If Expiry is not set, the client has not indicated that
 | 
			
		||||
@@ -470,22 +513,12 @@ func (machine Machine) toNode(
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addrs := []netaddr.IPPrefix{}
 | 
			
		||||
	nodeAddr, err := netaddr.ParseIP(m.IPAddresses)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Trace().
 | 
			
		||||
			Caller().
 | 
			
		||||
			Str("ip", machine.IPAddresses).
 | 
			
		||||
			Msgf("Failed to parse machine IP: %s", machine.IPAddresses)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	for _, machineAddress := range machine.IPAddresses {
 | 
			
		||||
		ip := netaddr.IPPrefixFrom(machineAddress, machineAddress.BitLen())
 | 
			
		||||
		addrs = append(addrs, ip)
 | 
			
		||||
	}
 | 
			
		||||
	ip := netaddr.IPPrefixFrom(nodeAddr, nodeAddr.BitLen())
 | 
			
		||||
	addrs = append(addrs, ip)
 | 
			
		||||
 | 
			
		||||
	allowedIPs := []netaddr.IPPrefix{}
 | 
			
		||||
	allowedIPs = append(
 | 
			
		||||
		allowedIPs,
 | 
			
		||||
		ip,
 | 
			
		||||
	) // we append the node own IP, as it is required by the clients
 | 
			
		||||
	allowedIPs := append([]netaddr.IPPrefix{}, addrs...) // we append the node own IP, as it is required by the clients
 | 
			
		||||
 | 
			
		||||
	if includeRoutes {
 | 
			
		||||
		routesStr := []string{}
 | 
			
		||||
@@ -592,11 +625,11 @@ func (machine *Machine) toProto() *v1.Machine {
 | 
			
		||||
		Id:         machine.ID,
 | 
			
		||||
		MachineKey: machine.MachineKey,
 | 
			
		||||
 | 
			
		||||
		NodeKey:   machine.NodeKey,
 | 
			
		||||
		DiscoKey:  machine.DiscoKey,
 | 
			
		||||
		IpAddress: machine.IPAddress,
 | 
			
		||||
		Name:      machine.Name,
 | 
			
		||||
		Namespace: machine.Namespace.toProto(),
 | 
			
		||||
		NodeKey:     machine.NodeKey,
 | 
			
		||||
		DiscoKey:    machine.DiscoKey,
 | 
			
		||||
		IpAddresses: machine.IPAddresses.ToStringSlice(),
 | 
			
		||||
		Name:        machine.Name,
 | 
			
		||||
		Namespace:   machine.Namespace.toProto(),
 | 
			
		||||
 | 
			
		||||
		Registered: machine.Registered,
 | 
			
		||||
 | 
			
		||||
@@ -695,7 +728,7 @@ func (h *Headscale) RegisterMachine(
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ip, err := h.getAvailableIP()
 | 
			
		||||
	ips, err := h.getAvailableIPs()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Error().
 | 
			
		||||
			Caller().
 | 
			
		||||
@@ -709,10 +742,10 @@ func (h *Headscale) RegisterMachine(
 | 
			
		||||
	log.Trace().
 | 
			
		||||
		Caller().
 | 
			
		||||
		Str("machine", machine.Name).
 | 
			
		||||
		Str("ip", ip.String()).
 | 
			
		||||
		Str("ip", strings.Join(ips.ToStringSlice(), ",")).
 | 
			
		||||
		Msg("Found IP for host")
 | 
			
		||||
 | 
			
		||||
	machine.IPAddress = ip.String()
 | 
			
		||||
	machine.IPAddresses = ips
 | 
			
		||||
	machine.NamespaceID = namespace.ID
 | 
			
		||||
	machine.Registered = true
 | 
			
		||||
	machine.RegisterMethod = RegisterMethodCLI
 | 
			
		||||
@@ -722,7 +755,7 @@ func (h *Headscale) RegisterMachine(
 | 
			
		||||
	log.Trace().
 | 
			
		||||
		Caller().
 | 
			
		||||
		Str("machine", machine.Name).
 | 
			
		||||
		Str("ip", ip.String()).
 | 
			
		||||
		Str("ip", strings.Join(ips.ToStringSlice(), ",")).
 | 
			
		||||
		Msg("Machine registered with the database")
 | 
			
		||||
 | 
			
		||||
	return machine, nil
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gopkg.in/check.v1"
 | 
			
		||||
	"inet.af/netaddr"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestGetMachine(c *check.C) {
 | 
			
		||||
@@ -199,3 +200,22 @@ func (s *Suite) TestExpireMachine(c *check.C) {
 | 
			
		||||
 | 
			
		||||
	c.Assert(machineFromDB.isExpired(), check.Equals, true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
 | 
			
		||||
	input := MachineAddresses([]netaddr.IP{
 | 
			
		||||
		netaddr.MustParseIP("192.0.2.1"),
 | 
			
		||||
		netaddr.MustParseIP("2001:db8::1"),
 | 
			
		||||
	})
 | 
			
		||||
	serialized, err := input.Value()
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
	c.Assert(serialized.(string), check.Equals, "192.0.2.1,2001:db8::1")
 | 
			
		||||
 | 
			
		||||
	var deserialized MachineAddresses
 | 
			
		||||
	err = deserialized.Scan(serialized)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(len(deserialized), check.Equals, len(input))
 | 
			
		||||
	for i := range deserialized {
 | 
			
		||||
		c.Assert(deserialized[i], check.Equals, input[i])
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"github.com/rs/zerolog/log"
 | 
			
		||||
	"gopkg.in/check.v1"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"inet.af/netaddr"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
 | 
			
		||||
@@ -146,7 +147,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared1,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.1",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyShared1.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared1)
 | 
			
		||||
@@ -164,7 +165,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared2,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.2",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyShared2.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared2)
 | 
			
		||||
@@ -182,7 +183,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared3,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.3",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKeyShared3.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machineInShared3)
 | 
			
		||||
@@ -200,7 +201,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
 | 
			
		||||
		Namespace:      *namespaceShared1,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.4",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
 | 
			
		||||
		AuthKeyID:      uint(preAuthKey2Shared1.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machine2InShared1)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										5
									
								
								oidc.go
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								oidc.go
									
									
									
									
									
								
							@@ -126,6 +126,7 @@ var oidcCallbackTemplate = template.Must(
 | 
			
		||||
	</html>`),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// TODO: Why is the entire machine registration logic duplicated here?
 | 
			
		||||
// OIDCCallback handles the callback from the OIDC endpoint
 | 
			
		||||
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
 | 
			
		||||
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
 | 
			
		||||
@@ -316,7 +317,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			ip, err := h.getAvailableIP()
 | 
			
		||||
			ips, err := h.getAvailableIPs()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Error().
 | 
			
		||||
					Caller().
 | 
			
		||||
@@ -330,7 +331,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			machine.IPAddress = ip.String()
 | 
			
		||||
			machine.IPAddresses = ips
 | 
			
		||||
			machine.NamespaceID = namespace.ID
 | 
			
		||||
			machine.Registered = true
 | 
			
		||||
			machine.RegisterMethod = RegisterMethodOIDC
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package headscale
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"gopkg.in/check.v1"
 | 
			
		||||
	"inet.af/netaddr"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CreateNodeNamespace(
 | 
			
		||||
@@ -26,7 +27,7 @@ func CreateNodeNamespace(
 | 
			
		||||
		NamespaceID:    namespace.ID,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      ip,
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
 | 
			
		||||
		AuthKeyID:      uint(pak1.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machine)
 | 
			
		||||
@@ -214,7 +215,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
 | 
			
		||||
		NamespaceID:    namespace1.ID,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.4",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
 | 
			
		||||
		AuthKeyID:      uint(pak4.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machine4)
 | 
			
		||||
@@ -294,7 +295,7 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) {
 | 
			
		||||
		NamespaceID:    namespace1.ID,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		IPAddress:      "100.64.0.4",
 | 
			
		||||
		IPAddresses:    []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
 | 
			
		||||
		AuthKeyID:      uint(pak4n1.ID),
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(machine4)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										43
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								utils.go
									
									
									
									
									
								
							@@ -133,9 +133,24 @@ func encode(
 | 
			
		||||
	return privKey.SealTo(*pubKey, b), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
 | 
			
		||||
	ipPrefix := h.cfg.IPPrefix
 | 
			
		||||
func (h *Headscale) getAvailableIPs() (ips MachineAddresses, err error) {
 | 
			
		||||
	ipPrefixes := h.cfg.IPPrefixes
 | 
			
		||||
	for _, ipPrefix := range ipPrefixes {
 | 
			
		||||
		var ip *netaddr.IP
 | 
			
		||||
		ip, err = h.getAvailableIP(ipPrefix)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		ips = append(ips, *ip)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO: Is this concurrency safe?
 | 
			
		||||
// What would happen if multiple hosts were to register at the same time?
 | 
			
		||||
// Would we attempt to assign the same addresses to multiple nodes?
 | 
			
		||||
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
 | 
			
		||||
	usedIps, err := h.getUsedIPs()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -143,6 +158,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
 | 
			
		||||
 | 
			
		||||
	ipPrefixNetworkAddress, ipPrefixBroadcastAddress := func() (netaddr.IP, netaddr.IP) {
 | 
			
		||||
		ipRange := ipPrefix.Range()
 | 
			
		||||
 | 
			
		||||
		return ipRange.From(), ipRange.To()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
@@ -171,19 +187,20 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
 | 
			
		||||
	var addresses []string
 | 
			
		||||
	h.db.Model(&Machine{}).Pluck("ip_address", &addresses)
 | 
			
		||||
	// FIXME: This really deserves a better data model,
 | 
			
		||||
	// but this was quick to get running and it should be enough
 | 
			
		||||
	// to begin experimenting with a dual stack tailnet.
 | 
			
		||||
	var addressesSlices []string
 | 
			
		||||
	h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
 | 
			
		||||
 | 
			
		||||
	ips := make([]netaddr.IP, len(addresses))
 | 
			
		||||
	for index, addr := range addresses {
 | 
			
		||||
		if addr != "" {
 | 
			
		||||
			ip, err := netaddr.ParseIP(addr)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("failed to parse ip from database: %w", err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			ips[index] = ip
 | 
			
		||||
	ips := make([]netaddr.IP, 0, len(h.cfg.IPPrefixes)*len(addressesSlices))
 | 
			
		||||
	for _, slice := range addressesSlices {
 | 
			
		||||
		var a MachineAddresses
 | 
			
		||||
		err := a.Scan(slice)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("failed to read ip from database: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		ips = append(ips, a...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ips, nil
 | 
			
		||||
 
 | 
			
		||||
@@ -6,17 +6,18 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
 | 
			
		||||
	ip, err := app.getAvailableIP()
 | 
			
		||||
	ips, err := app.getAvailableIPs()
 | 
			
		||||
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	expected := netaddr.MustParseIP("10.27.0.1")
 | 
			
		||||
 | 
			
		||||
	c.Assert(ip.String(), check.Equals, expected.String())
 | 
			
		||||
	c.Assert(len(ips), check.Equals, 1)
 | 
			
		||||
	c.Assert(ips[0].String(), check.Equals, expected.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestGetUsedIps(c *check.C) {
 | 
			
		||||
	ip, err := app.getAvailableIP()
 | 
			
		||||
	ips, err := app.getAvailableIPs()
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	namespace, err := app.CreateNamespace("test_ip")
 | 
			
		||||
@@ -38,22 +39,24 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
		AuthKeyID:      uint(pak.ID),
 | 
			
		||||
		IPAddress:      ip.String(),
 | 
			
		||||
		IPAddresses:    ips,
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(&machine)
 | 
			
		||||
 | 
			
		||||
	ips, err := app.getUsedIPs()
 | 
			
		||||
	usedIps, err := app.getUsedIPs()
 | 
			
		||||
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	expected := netaddr.MustParseIP("10.27.0.1")
 | 
			
		||||
 | 
			
		||||
	c.Assert(ips[0], check.Equals, expected)
 | 
			
		||||
	c.Assert(len(usedIps), check.Equals, 1)
 | 
			
		||||
	c.Assert(usedIps[0], check.Equals, expected)
 | 
			
		||||
 | 
			
		||||
	machine1, err := app.GetMachineByID(0)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(machine1.IPAddress, check.Equals, expected.String())
 | 
			
		||||
	c.Assert(len(machine1.IPAddresses), check.Equals, 1)
 | 
			
		||||
	c.Assert(machine1.IPAddresses[0], check.Equals, expected)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestGetMultiIp(c *check.C) {
 | 
			
		||||
@@ -61,7 +64,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	for index := 1; index <= 350; index++ {
 | 
			
		||||
		ip, err := app.getAvailableIP()
 | 
			
		||||
		ips, err := app.getAvailableIPs()
 | 
			
		||||
		c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
		pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
 | 
			
		||||
@@ -80,59 +83,64 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
 | 
			
		||||
			Registered:     true,
 | 
			
		||||
			RegisterMethod: RegisterMethodAuthKey,
 | 
			
		||||
			AuthKeyID:      uint(pak.ID),
 | 
			
		||||
			IPAddress:      ip.String(),
 | 
			
		||||
			IPAddresses:    ips,
 | 
			
		||||
		}
 | 
			
		||||
		app.db.Save(&machine)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ips, err := app.getUsedIPs()
 | 
			
		||||
	usedIps, err := app.getUsedIPs()
 | 
			
		||||
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(len(ips), check.Equals, 350)
 | 
			
		||||
	c.Assert(len(usedIps), check.Equals, 350)
 | 
			
		||||
 | 
			
		||||
	c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.1"))
 | 
			
		||||
	c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.10"))
 | 
			
		||||
	c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.45"))
 | 
			
		||||
	c.Assert(usedIps[0], check.Equals, netaddr.MustParseIP("10.27.0.1"))
 | 
			
		||||
	c.Assert(usedIps[9], check.Equals, netaddr.MustParseIP("10.27.0.10"))
 | 
			
		||||
	c.Assert(usedIps[300], check.Equals, netaddr.MustParseIP("10.27.1.45"))
 | 
			
		||||
 | 
			
		||||
	// Check that we can read back the IPs
 | 
			
		||||
	machine1, err := app.GetMachineByID(1)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
	c.Assert(len(machine1.IPAddresses), check.Equals, 1)
 | 
			
		||||
	c.Assert(
 | 
			
		||||
		machine1.IPAddress,
 | 
			
		||||
		machine1.IPAddresses[0],
 | 
			
		||||
		check.Equals,
 | 
			
		||||
		netaddr.MustParseIP("10.27.0.1").String(),
 | 
			
		||||
		netaddr.MustParseIP("10.27.0.1"),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	machine50, err := app.GetMachineByID(50)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
	c.Assert(len(machine50.IPAddresses), check.Equals, 1)
 | 
			
		||||
	c.Assert(
 | 
			
		||||
		machine50.IPAddress,
 | 
			
		||||
		machine50.IPAddresses[0],
 | 
			
		||||
		check.Equals,
 | 
			
		||||
		netaddr.MustParseIP("10.27.0.50").String(),
 | 
			
		||||
		netaddr.MustParseIP("10.27.0.50"),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	expectedNextIP := netaddr.MustParseIP("10.27.1.95")
 | 
			
		||||
	nextIP, err := app.getAvailableIP()
 | 
			
		||||
	nextIP, err := app.getAvailableIPs()
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
 | 
			
		||||
	c.Assert(len(nextIP), check.Equals, 1)
 | 
			
		||||
	c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String())
 | 
			
		||||
 | 
			
		||||
	// If we call get Available again, we should receive
 | 
			
		||||
	// the same IP, as it has not been reserved.
 | 
			
		||||
	nextIP2, err := app.getAvailableIP()
 | 
			
		||||
	nextIP2, err := app.getAvailableIPs()
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
 | 
			
		||||
	c.Assert(len(nextIP2), check.Equals, 1)
 | 
			
		||||
	c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
 | 
			
		||||
	ip, err := app.getAvailableIP()
 | 
			
		||||
	ips, err := app.getAvailableIPs()
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	expected := netaddr.MustParseIP("10.27.0.1")
 | 
			
		||||
 | 
			
		||||
	c.Assert(ip.String(), check.Equals, expected.String())
 | 
			
		||||
	c.Assert(len(ips), check.Equals, 1)
 | 
			
		||||
	c.Assert(ips[0].String(), check.Equals, expected.String())
 | 
			
		||||
 | 
			
		||||
	namespace, err := app.CreateNamespace("test_ip")
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
@@ -156,8 +164,9 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
 | 
			
		||||
	}
 | 
			
		||||
	app.db.Save(&machine)
 | 
			
		||||
 | 
			
		||||
	ip2, err := app.getAvailableIP()
 | 
			
		||||
	ips2, err := app.getAvailableIPs()
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(ip2.String(), check.Equals, expected.String())
 | 
			
		||||
	c.Assert(len(ips2), check.Equals, 1)
 | 
			
		||||
	c.Assert(ips2[0].String(), check.Equals, expected.String())
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user