From c1218ad3c229dd1314e09e845e67d1353142478e Mon Sep 17 00:00:00 2001
From: Kristoffer Dalby <kristoffer@tailscale.com>
Date: Tue, 6 Jun 2023 11:28:52 +0200
Subject: [PATCH] move reminder of dns funcs to util

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
---
 hscontrol/app.go           |   2 +-
 hscontrol/dns.go           | 151 -------------------------------------
 hscontrol/dns_test.go      | 109 --------------------------
 hscontrol/util/dns.go      | 140 ++++++++++++++++++++++++++++++++++
 hscontrol/util/dns_test.go | 105 +++++++++++++++++++++++++-
 5 files changed, 245 insertions(+), 262 deletions(-)
 delete mode 100644 hscontrol/dns.go
 delete mode 100644 hscontrol/dns_test.go

diff --git a/hscontrol/app.go b/hscontrol/app.go
index dec14a38..90628136 100644
--- a/hscontrol/app.go
+++ b/hscontrol/app.go
@@ -192,7 +192,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 	}
 
 	if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
-		magicDNSDomains := generateMagicDNSRootDomains(app.cfg.IPPrefixes)
+		magicDNSDomains := util.GenerateMagicDNSRootDomains(app.cfg.IPPrefixes)
 		// we might have routes already from Split DNS
 		if app.cfg.DNSConfig.Routes == nil {
 			app.cfg.DNSConfig.Routes = make(map[string][]*dnstype.Resolver)
diff --git a/hscontrol/dns.go b/hscontrol/dns.go
deleted file mode 100644
index dcab04da..00000000
--- a/hscontrol/dns.go
+++ /dev/null
@@ -1,151 +0,0 @@
-package hscontrol
-
-import (
-	"fmt"
-	"net/netip"
-	"strings"
-
-	"go4.org/netipx"
-	"tailscale.com/util/dnsname"
-)
-
-const (
-	ByteSize = 8
-)
-
-const (
-	ipv4AddressLength = 32
-	ipv6AddressLength = 128
-)
-
-// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
-// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
-// server (listening in 100.100.100.100 udp/53) should be used for.
-//
-// Tailscale.com includes in the list:
-// - the `BaseDomain` of the user
-// - the reverse DNS entry for IPv6 (0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa., see below more on IPv6)
-// - the reverse DNS entries for the IPv4 subnets covered by the user's `IPPrefix`.
-//   In the public SaaS this is [64-127].100.in-addr.arpa.
-//
-// The main purpose of this function is then generating the list of IPv4 entries. For the 100.64.0.0/10, this
-// is clear, and could be hardcoded. But we are allowing any range as `IPPrefix`, so we need to find out the
-// subnets when we have 172.16.0.0/16 (i.e., [0-255].16.172.in-addr.arpa.), or any other subnet.
-//
-// How IN-ADDR.ARPA domains work is defined in RFC1035 (section 3.5). Tailscale.com seems to adhere to this,
-// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next
-// class block only.
-
-// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
-// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
-func generateMagicDNSRootDomains(ipPrefixes []netip.Prefix) []dnsname.FQDN {
-	fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes))
-	for _, ipPrefix := range ipPrefixes {
-		var generateDNSRoot func(netip.Prefix) []dnsname.FQDN
-		switch ipPrefix.Addr().BitLen() {
-		case ipv4AddressLength:
-			generateDNSRoot = generateIPv4DNSRootDomain
-
-		case ipv6AddressLength:
-			generateDNSRoot = generateIPv6DNSRootDomain
-
-		default:
-			panic(
-				fmt.Sprintf(
-					"unsupported IP version with address length %d",
-					ipPrefix.Addr().BitLen(),
-				),
-			)
-		}
-
-		fqdns = append(fqdns, generateDNSRoot(ipPrefix)...)
-	}
-
-	return fqdns
-}
-
-func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
-	// Conversion to the std lib net.IPnet, a bit easier to operate
-	netRange := netipx.PrefixIPNet(ipPrefix)
-	maskBits, _ := netRange.Mask.Size()
-
-	// lastOctet is the last IP byte covered by the mask
-	lastOctet := maskBits / ByteSize
-
-	// wildcardBits is the number of bits not under the mask in the lastOctet
-	wildcardBits := ByteSize - maskBits%ByteSize
-
-	// min is the value in the lastOctet byte of the IP
-	// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
-	min := uint(netRange.IP[lastOctet])
-	max := (min + 1<<uint(wildcardBits)) - 1
-
-	// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
-	rdnsSlice := []string{}
-	for i := lastOctet - 1; i >= 0; i-- {
-		rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
-	}
-	rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
-	rdnsBase := strings.Join(rdnsSlice, ".")
-
-	fqdns := make([]dnsname.FQDN, 0, max-min+1)
-	for i := min; i <= max; i++ {
-		fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%d.%s", i, rdnsBase))
-		if err != nil {
-			continue
-		}
-		fqdns = append(fqdns, fqdn)
-	}
-
-	return fqdns
-}
-
-func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
-	const nibbleLen = 4
-
-	maskBits, _ := netipx.PrefixIPNet(ipPrefix).Mask.Size()
-	expanded := ipPrefix.Addr().StringExpanded()
-	nibbleStr := strings.Map(func(r rune) rune {
-		if r == ':' {
-			return -1
-		}
-
-		return r
-	}, expanded)
-
-	// TODO?: that does not look the most efficient implementation,
-	// but the inputs are not so long as to cause problems,
-	// and from what I can see, the generateMagicDNSRootDomains
-	// function is called only once over the lifetime of a server process.
-	prefixConstantParts := []string{}
-	for i := 0; i < maskBits/nibbleLen; i++ {
-		prefixConstantParts = append(
-			[]string{string(nibbleStr[i])},
-			prefixConstantParts...)
-	}
-
-	makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
-		prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")
-
-		return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix))
-	}
-
-	var fqdns []dnsname.FQDN
-	if maskBits%4 == 0 {
-		dom, _ := makeDomain()
-		fqdns = append(fqdns, dom)
-	} else {
-		domCount := 1 << (maskBits % nibbleLen)
-		fqdns = make([]dnsname.FQDN, 0, domCount)
-		for i := 0; i < domCount; i++ {
-			varNibble := fmt.Sprintf("%x", i)
-			dom, err := makeDomain(varNibble)
-			if err != nil {
-				continue
-			}
-			fqdns = append(fqdns, dom)
-		}
-	}
-
-	return fqdns
-}
diff --git a/hscontrol/dns_test.go b/hscontrol/dns_test.go
deleted file mode 100644
index aae243c2..00000000
--- a/hscontrol/dns_test.go
+++ /dev/null
@@ -1,109 +0,0 @@
-package hscontrol
-
-import (
-	"net/netip"
-
-	"gopkg.in/check.v1"
-)
-
-func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
-	prefixes := []netip.Prefix{
-		netip.MustParsePrefix("100.64.0.0/10"),
-	}
-	domains := generateMagicDNSRootDomains(prefixes)
-
-	found := false
-	for _, domain := range domains {
-		if domain == "64.100.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	c.Assert(found, check.Equals, true)
-
-	found = false
-	for _, domain := range domains {
-		if domain == "100.100.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	c.Assert(found, check.Equals, true)
-
-	found = false
-	for _, domain := range domains {
-		if domain == "127.100.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	c.Assert(found, check.Equals, true)
-}
-
-func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
-	prefixes := []netip.Prefix{
-		netip.MustParsePrefix("172.16.0.0/16"),
-	}
-	domains := generateMagicDNSRootDomains(prefixes)
-
-	found := false
-	for _, domain := range domains {
-		if domain == "0.16.172.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	c.Assert(found, check.Equals, true)
-
-	found = false
-	for _, domain := range domains {
-		if domain == "255.16.172.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	c.Assert(found, check.Equals, true)
-}
-
-// Happens when netmask is a multiple of 4 bits (sounds likely).
-func (s *Suite) TestMagicDNSRootDomainsIPv6Single(c *check.C) {
-	prefixes := []netip.Prefix{
-		netip.MustParsePrefix("fd7a:115c:a1e0::/48"),
-	}
-	domains := generateMagicDNSRootDomains(prefixes)
-
-	c.Assert(len(domains), check.Equals, 1)
-	c.Assert(
-		domains[0].WithTrailingDot(),
-		check.Equals,
-		"0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.",
-	)
-}
-
-func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) {
-	prefixes := []netip.Prefix{
-		netip.MustParsePrefix("fd7a:115c:a1e0::/50"),
-	}
-	domains := generateMagicDNSRootDomains(prefixes)
-
-	yieldsRoot := func(dom string) bool {
-		for _, candidate := range domains {
-			if candidate.WithTrailingDot() == dom {
-				return true
-			}
-		}
-
-		return false
-	}
-
-	c.Assert(len(domains), check.Equals, 4)
-	c.Assert(yieldsRoot("0.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
-	c.Assert(yieldsRoot("1.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
-	c.Assert(yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
-	c.Assert(yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
-}
diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go
index 72af8f83..5c666436 100644
--- a/hscontrol/util/dns.go
+++ b/hscontrol/util/dns.go
@@ -3,11 +3,19 @@ package util
 import (
 	"errors"
 	"fmt"
+	"net/netip"
 	"regexp"
 	"strings"
+
+	"go4.org/netipx"
+	"tailscale.com/util/dnsname"
 )
 
 const (
+	ByteSize          = 8
+	ipv4AddressLength = 32
+	ipv6AddressLength = 128
+
 	// value related to RFC 1123 and 952.
 	LabelHostnameLength = 63
 )
@@ -67,3 +75,135 @@ func CheckForFQDNRules(name string) error {
 
 	return nil
 }
+
+// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
+// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
+// server (listening in 100.100.100.100 udp/53) should be used for.
+//
+// Tailscale.com includes in the list:
+// - the `BaseDomain` of the user
+// - the reverse DNS entry for IPv6 (0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa., see below more on IPv6)
+// - the reverse DNS entries for the IPv4 subnets covered by the user's `IPPrefix`.
+//   In the public SaaS this is [64-127].100.in-addr.arpa.
+//
+// The main purpose of this function is then generating the list of IPv4 entries. For the 100.64.0.0/10, this
+// is clear, and could be hardcoded. But we are allowing any range as `IPPrefix`, so we need to find out the
+// subnets when we have 172.16.0.0/16 (i.e., [0-255].16.172.in-addr.arpa.), or any other subnet.
+//
+// How IN-ADDR.ARPA domains work is defined in RFC1035 (section 3.5). Tailscale.com seems to adhere to this,
+// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next
+// class block only.
+
+// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
+// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
+func GenerateMagicDNSRootDomains(ipPrefixes []netip.Prefix) []dnsname.FQDN {
+	fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes))
+	for _, ipPrefix := range ipPrefixes {
+		var generateDNSRoot func(netip.Prefix) []dnsname.FQDN
+		switch ipPrefix.Addr().BitLen() {
+		case ipv4AddressLength:
+			generateDNSRoot = generateIPv4DNSRootDomain
+
+		case ipv6AddressLength:
+			generateDNSRoot = generateIPv6DNSRootDomain
+
+		default:
+			panic(
+				fmt.Sprintf(
+					"unsupported IP version with address length %d",
+					ipPrefix.Addr().BitLen(),
+				),
+			)
+		}
+
+		fqdns = append(fqdns, generateDNSRoot(ipPrefix)...)
+	}
+
+	return fqdns
+}
+
+func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
+	// Conversion to the std lib net.IPnet, a bit easier to operate
+	netRange := netipx.PrefixIPNet(ipPrefix)
+	maskBits, _ := netRange.Mask.Size()
+
+	// lastOctet is the last IP byte covered by the mask
+	lastOctet := maskBits / ByteSize
+
+	// wildcardBits is the number of bits not under the mask in the lastOctet
+	wildcardBits := ByteSize - maskBits%ByteSize
+
+	// min is the value in the lastOctet byte of the IP
+	// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
+	min := uint(netRange.IP[lastOctet])
+	max := (min + 1<<uint(wildcardBits)) - 1
+
+	// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
+	rdnsSlice := []string{}
+	for i := lastOctet - 1; i >= 0; i-- {
+		rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
+	}
+	rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
+	rdnsBase := strings.Join(rdnsSlice, ".")
+
+	fqdns := make([]dnsname.FQDN, 0, max-min+1)
+	for i := min; i <= max; i++ {
+		fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%d.%s", i, rdnsBase))
+		if err != nil {
+			continue
+		}
+		fqdns = append(fqdns, fqdn)
+	}
+
+	return fqdns
+}
+
+func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
+	const nibbleLen = 4
+
+	maskBits, _ := netipx.PrefixIPNet(ipPrefix).Mask.Size()
+	expanded := ipPrefix.Addr().StringExpanded()
+	nibbleStr := strings.Map(func(r rune) rune {
+		if r == ':' {
+			return -1
+		}
+
+		return r
+	}, expanded)
+
+	// TODO?: that does not look the most efficient implementation,
+	// but the inputs are not so long as to cause problems,
+	// and from what I can see, the generateMagicDNSRootDomains
+	// function is called only once over the lifetime of a server process.
+	prefixConstantParts := []string{}
+	for i := 0; i < maskBits/nibbleLen; i++ {
+		prefixConstantParts = append(
+			[]string{string(nibbleStr[i])},
+			prefixConstantParts...)
+	}
+
+	makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
+		prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")
+
+		return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix))
+	}
+
+	var fqdns []dnsname.FQDN
+	if maskBits%4 == 0 {
+		dom, _ := makeDomain()
+		fqdns = append(fqdns, dom)
+	} else {
+		domCount := 1 << (maskBits % nibbleLen)
+		fqdns = make([]dnsname.FQDN, 0, domCount)
+		for i := 0; i < domCount; i++ {
+			varNibble := fmt.Sprintf("%x", i)
+			dom, err := makeDomain(varNibble)
+			if err != nil {
+				continue
+			}
+			fqdns = append(fqdns, dom)
+		}
+	}
+
+	return fqdns
+}
diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go
index ab66a130..9d9b08b3 100644
--- a/hscontrol/util/dns_test.go
+++ b/hscontrol/util/dns_test.go
@@ -1,6 +1,11 @@
 package util
 
-import "testing"
+import (
+	"net/netip"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
 
 func TestNormalizeToFQDNRules(t *testing.T) {
 	type args struct {
@@ -141,3 +146,101 @@ func TestCheckForFQDNRules(t *testing.T) {
 		})
 	}
 }
+
+func TestMagicDNSRootDomains100(t *testing.T) {
+	prefixes := []netip.Prefix{
+		netip.MustParsePrefix("100.64.0.0/10"),
+	}
+	domains := GenerateMagicDNSRootDomains(prefixes)
+
+	found := false
+	for _, domain := range domains {
+		if domain == "64.100.in-addr.arpa." {
+			found = true
+
+			break
+		}
+	}
+	assert.True(t, found)
+
+	found = false
+	for _, domain := range domains {
+		if domain == "100.100.in-addr.arpa." {
+			found = true
+
+			break
+		}
+	}
+	assert.True(t, found)
+
+	found = false
+	for _, domain := range domains {
+		if domain == "127.100.in-addr.arpa." {
+			found = true
+
+			break
+		}
+	}
+	assert.True(t, found)
+}
+
+func TestMagicDNSRootDomains172(t *testing.T) {
+	prefixes := []netip.Prefix{
+		netip.MustParsePrefix("172.16.0.0/16"),
+	}
+	domains := GenerateMagicDNSRootDomains(prefixes)
+
+	found := false
+	for _, domain := range domains {
+		if domain == "0.16.172.in-addr.arpa." {
+			found = true
+
+			break
+		}
+	}
+	assert.True(t, found)
+
+	found = false
+	for _, domain := range domains {
+		if domain == "255.16.172.in-addr.arpa." {
+			found = true
+
+			break
+		}
+	}
+	assert.True(t, found)
+}
+
+// Happens when netmask is a multiple of 4 bits (sounds likely).
+func TestMagicDNSRootDomainsIPv6Single(t *testing.T) {
+	prefixes := []netip.Prefix{
+		netip.MustParsePrefix("fd7a:115c:a1e0::/48"),
+	}
+	domains := GenerateMagicDNSRootDomains(prefixes)
+
+	assert.Len(t, domains, 1)
+	assert.Equal(t, "0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.", domains[0].WithTrailingDot())
+}
+
+func TestMagicDNSRootDomainsIPv6SingleMultiple(t *testing.T) {
+	prefixes := []netip.Prefix{
+		netip.MustParsePrefix("fd7a:115c:a1e0::/50"),
+	}
+	domains := GenerateMagicDNSRootDomains(prefixes)
+
+	yieldsRoot := func(dom string) bool {
+		for _, candidate := range domains {
+			if candidate.WithTrailingDot() == dom {
+				return true
+			}
+		}
+
+		return false
+	}
+
+	assert.Len(t, domains, 4)
+	assert.True(t, yieldsRoot("0.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."))
+	assert.True(t, yieldsRoot("1.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."))
+	assert.True(t, yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."))
+	assert.True(t, yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."))
+}