diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index 73ba116ff..31d6a5d26 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -94,18 +94,24 @@ func main() { } ignoreDstTable.Insert(pfx, true) } - var v4Prefixes []netip.Prefix + var ( + v4Prefixes []netip.Prefix + numV4DNSAddrs int + ) for _, s := range strings.Split(*v4PfxStr, ",") { p := netip.MustParsePrefix(strings.TrimSpace(s)) if p.Masked() != p { log.Fatalf("v4 prefix %v is not a masked prefix", p) } v4Prefixes = append(v4Prefixes, p) + numIPs := 1 << (32 - p.Bits()) + numV4DNSAddrs += numIPs } if len(v4Prefixes) == 0 { log.Fatalf("no v4 prefixes specified") } dnsAddr := v4Prefixes[0].Addr() + numV4DNSAddrs -= 1 // Subtract the dnsAddr allocated above. ts := &tsnet.Server{ Hostname: *hostname, } @@ -153,12 +159,13 @@ func main() { } c := &connector{ - ts: ts, - lc: lc, - dnsAddr: dnsAddr, - v4Ranges: v4Prefixes, - v6ULA: ula(uint16(*siteID)), - ignoreDsts: ignoreDstTable, + ts: ts, + lc: lc, + dnsAddr: dnsAddr, + v4Ranges: v4Prefixes, + numV4DNSAddrs: numV4DNSAddrs, + v6ULA: ula(uint16(*siteID)), + ignoreDsts: ignoreDstTable, } c.run(ctx) } @@ -177,6 +184,11 @@ type connector struct { // v4Ranges is the list of IPv4 ranges to advertise and assign addresses from. // These are masked prefixes. v4Ranges []netip.Prefix + + // numV4DNSAddrs is the total size of the IPv4 ranges in addresses, minus the + // dnsAddr allocation. + numV4DNSAddrs int + // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. v6ULA netip.Prefix @@ -502,6 +514,7 @@ type perPeerState struct { mu sync.Mutex domainToAddr map[string][]netip.Addr addrToDomain *bart.Table[string] + numV4Allocs int } // domainForIP returns the domain name assigned to the given IP address and @@ -547,17 +560,25 @@ func (ps *perPeerState) isIPUsedLocked(ip netip.Addr) bool { // unusedIPv4Locked returns an unused IPv4 address from the available ranges. func (ps *perPeerState) unusedIPv4Locked() netip.Addr { + // All addresses have been allocated. + if ps.numV4Allocs >= ps.c.numV4DNSAddrs { + return netip.Addr{} + } + // TODO: skip ranges that have been exhausted - for _, r := range ps.c.v4Ranges { - ip := randV4(r) - for r.Contains(ip) { + // TODO: implement a much more efficient algorithm for finding unused IPs, + // this is fairly crazy. + for { + for _, r := range ps.c.v4Ranges { + ip := randV4(r) + if !r.Contains(ip) { + panic("error: randV4 returned invalid address") + } if !ps.isIPUsedLocked(ip) && ip != ps.c.dnsAddr { return ip } - ip = ip.Next() } } - return netip.Addr{} } // randV4 returns a random IPv4 address within the given prefix. @@ -583,6 +604,7 @@ func (ps *perPeerState) assignAddrsLocked(domain string) []netip.Addr { if !v4.IsValid() { return nil } + ps.numV4Allocs++ as16 := ps.c.v6ULA.Addr().As16() as4 := v4.As4() copy(as16[12:], as4[:]) diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go index 1b6d7af7c..e42fa7e89 100644 --- a/cmd/natc/natc_test.go +++ b/cmd/natc/natc_test.go @@ -4,6 +4,8 @@ package main import ( + "errors" + "fmt" "net/netip" "slices" "testing" @@ -225,9 +227,10 @@ func TestDNSResponse(t *testing.T) { func TestPerPeerState(t *testing.T) { c := &connector{ - v4Ranges: []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}, - v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), - dnsAddr: netip.MustParseAddr("100.64.1.1"), + v4Ranges: []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}, + v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), + dnsAddr: netip.MustParseAddr("100.64.1.0"), + numV4DNSAddrs: (1<<(32-24) - 1), } ps := &perPeerState{c: c} @@ -328,9 +331,10 @@ func TestIgnoreDestination(t *testing.T) { func TestConnectorGenerateDNSResponse(t *testing.T) { c := &connector{ - v4Ranges: []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}, - v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), - dnsAddr: netip.MustParseAddr("100.64.1.1"), + v4Ranges: []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}, + v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), + dnsAddr: netip.MustParseAddr("100.64.1.0"), + numV4DNSAddrs: (1<<(32-24) - 1), } req := &dnsmessage.Message{ @@ -363,3 +367,63 @@ func TestConnectorGenerateDNSResponse(t *testing.T) { t.Errorf("generateDNSResponse() responses differ between calls") } } + +func TestIPPoolExhaustion(t *testing.T) { + smallPrefix := netip.MustParsePrefix("100.64.1.0/30") // Only 4 IPs: .0, .1, .2, .3 + c := &connector{ + v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), + v4Ranges: []netip.Prefix{smallPrefix}, + dnsAddr: netip.MustParseAddr("100.64.1.0"), + numV4DNSAddrs: 3, + } + + ps := &perPeerState{c: c} + + assignedIPs := make(map[netip.Addr]string) + + domains := []string{"a.example.com", "b.example.com", "c.example.com", "d.example.com"} + + var errs []error + + for i := 0; i < 5; i++ { + for _, domain := range domains { + addrs, err := ps.ipForDomain(domain) + if err != nil { + errs = append(errs, fmt.Errorf("failed to get IP for domain %q: %w", domain, err)) + continue + } + + for _, addr := range addrs { + if d, ok := assignedIPs[addr]; ok { + if d != domain { + t.Errorf("IP %s reused for domain %q, previously assigned to %q", addr, domain, d) + } + } else { + assignedIPs[addr] = domain + } + } + } + } + + for addr, domain := range assignedIPs { + if addr.Is4() && !smallPrefix.Contains(addr) { + t.Errorf("IP %s for domain %q not in expected range %s", addr, domain, smallPrefix) + } + if addr.Is6() && !c.v6ULA.Contains(addr) { + t.Errorf("IP %s for domain %q not in expected range %s", addr, domain, c.v6ULA) + } + if addr == c.dnsAddr { + t.Errorf("IP %s for domain %q is the reserved DNS address", addr, domain) + } + } + + // expect one error for each iteration with the 4th domain + if len(errs) != 5 { + t.Errorf("Expected 5 errors, got %d: %v", len(errs), errs) + } + for _, err := range errs { + if !errors.Is(err, ErrNoIPsAvailable) { + t.Errorf("generateDNSResponse() error = %v, want ErrNoIPsAvailable", err) + } + } +}