From 95034e15a79888b1a4afe48e6812fd47ea138fd5 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Tue, 25 Mar 2025 12:59:07 -0700 Subject: [PATCH] cmd/natc: fix ip allocation runtime Avoid the unbounded runtime during random allocation, if random allocation fails after a first pass at random through the provided ranges, pick the next free address by walking through the allocated set. The new ipx utilities provide a bitset based allocation pool, good for small to moderate ranges of IPv4 addresses as used in natc. Updates #15367 Signed-off-by: James Tucker --- cmd/natc/ipx.go | 130 ++++++++++++++++++++++++++++++++++++ cmd/natc/ipx_test.go | 150 ++++++++++++++++++++++++++++++++++++++++++ cmd/natc/natc.go | 109 +++++++++++------------------- cmd/natc/natc_test.go | 33 +++------- 4 files changed, 325 insertions(+), 97 deletions(-) create mode 100644 cmd/natc/ipx.go create mode 100644 cmd/natc/ipx_test.go diff --git a/cmd/natc/ipx.go b/cmd/natc/ipx.go new file mode 100644 index 000000000..06bf7be79 --- /dev/null +++ b/cmd/natc/ipx.go @@ -0,0 +1,130 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "math/big" + "math/bits" + "math/rand/v2" + "net/netip" + + "go4.org/netipx" +) + +func addrLessOrEqual(a, b netip.Addr) bool { + if a.Less(b) { + return true + } + if a == b { + return true + } + return false +} + +// indexOfAddr returns the index of addr in ipset, or -1 if not found. +func indexOfAddr(addr netip.Addr, ipset *netipx.IPSet) int { + var base int // offset of the current range + for _, r := range ipset.Ranges() { + if addr.Less(r.From()) { + return -1 + } + numFrom := v4ToNum(r.From()) + if addrLessOrEqual(addr, r.To()) { + numInRange := int(v4ToNum(addr) - numFrom) + return base + numInRange + } + numTo := v4ToNum(r.To()) + base += int(numTo-numFrom) + 1 + } + return -1 +} + +// addrAtIndex returns the address at the given index in ipset, or an empty +// address if index is out of range. +func addrAtIndex(index int, ipset *netipx.IPSet) netip.Addr { + if index < 0 { + return netip.Addr{} + } + var base int // offset of the current range + for _, r := range ipset.Ranges() { + numFrom := v4ToNum(r.From()) + numTo := v4ToNum(r.To()) + if index <= base+int(numTo-numFrom) { + return numToV4(uint32(int(numFrom) + index - base)) + } + base += int(numTo-numFrom) + 1 + } + return netip.Addr{} +} + +// TODO(golang/go#9455): once we have uint128 we can easily implement for all addrs. + +// v4ToNum returns a uint32 representation of the IPv4 address. If addr is not +// an IPv4 address, this function will panic. +func v4ToNum(addr netip.Addr) uint32 { + addr = addr.Unmap() + if !addr.Is4() { + panic("only IPv4 addresses are supported by v4ToNum") + } + b := addr.As4() + var o uint32 + o = o<<8 | uint32(b[0]) + o = o<<8 | uint32(b[1]) + o = o<<8 | uint32(b[2]) + o = o<<8 | uint32(b[3]) + return o +} + +func numToV4(i uint32) netip.Addr { + var addr [4]byte + addr[0] = byte((i >> 24) & 0xff) + addr[1] = byte((i >> 16) & 0xff) + addr[2] = byte((i >> 8) & 0xff) + addr[3] = byte(i & 0xff) + return netip.AddrFrom4(addr) +} + +// allocAddr returns an address in ipset that is not already marked allocated in allocated. +func allocAddr(ipset *netipx.IPSet, allocated *big.Int) netip.Addr { + // first try to allocate a random IP from each range, if we land on one. + var base uint32 // index offset of the current range + for _, r := range ipset.Ranges() { + numFrom := v4ToNum(r.From()) + numTo := v4ToNum(r.To()) + randInRange := rand.N(numTo - numFrom) + randIndex := base + randInRange + if allocated.Bit(int(randIndex)) == 0 { + allocated.SetBit(allocated, int(randIndex), 1) + return numToV4(numFrom + randInRange) + } + base += numTo - numFrom + 1 + } + + // fall back to seeking a free bit in the allocated set + index := -1 + for i, word := range allocated.Bits() { + zbi := leastZeroBit(uint(word)) + if zbi == -1 { + continue + } + index = i*bits.UintSize + zbi + allocated.SetBit(allocated, index, 1) + break + } + if index == -1 { + return netip.Addr{} + } + return addrAtIndex(index, ipset) +} + +// leastZeroBit returns the index of the least significant zero bit in the given uint, or -1 +// if all bits are set. +func leastZeroBit(n uint) int { + notN := ^n + rightmostBit := notN & -notN + if rightmostBit == 0 { + return -1 + } + return bits.TrailingZeros(rightmostBit) +} diff --git a/cmd/natc/ipx_test.go b/cmd/natc/ipx_test.go new file mode 100644 index 000000000..b60a5d981 --- /dev/null +++ b/cmd/natc/ipx_test.go @@ -0,0 +1,150 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "math" + "math/big" + "net/netip" + "testing" + + "go4.org/netipx" + "tailscale.com/util/must" +) + +func TestV4ToNum(t *testing.T) { + cases := []struct { + addr netip.Addr + num uint32 + }{ + {netip.MustParseAddr("0.0.0.0"), 0}, + {netip.MustParseAddr("255.255.255.255"), 0xffffffff}, + {netip.MustParseAddr("8.8.8.8"), 0x08080808}, + {netip.MustParseAddr("192.168.0.1"), 0xc0a80001}, + {netip.MustParseAddr("10.0.0.1"), 0x0a000001}, + {netip.MustParseAddr("172.16.0.1"), 0xac100001}, + {netip.MustParseAddr("100.64.0.1"), 0x64400001}, + } + + for _, tc := range cases { + num := v4ToNum(tc.addr) + if num != tc.num { + t.Errorf("addrNum(%v) = %d, want %d", tc.addr, num, tc.num) + } + if numToV4(num) != tc.addr { + t.Errorf("numToV4(%d) = %v, want %v", num, numToV4(num), tc.addr) + } + } + + func() { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic") + } + }() + + v4ToNum(netip.MustParseAddr("::1")) + }() +} + +func TestAddrIndex(t *testing.T) { + builder := netipx.IPSetBuilder{} + builder.AddRange(netipx.MustParseIPRange("10.0.0.1-10.0.0.5")) + builder.AddRange(netipx.MustParseIPRange("192.168.0.1-192.168.0.10")) + ipset := must.Get(builder.IPSet()) + + indexCases := []struct { + addr netip.Addr + index int + }{ + {netip.MustParseAddr("10.0.0.1"), 0}, + {netip.MustParseAddr("10.0.0.2"), 1}, + {netip.MustParseAddr("10.0.0.3"), 2}, + {netip.MustParseAddr("10.0.0.4"), 3}, + {netip.MustParseAddr("10.0.0.5"), 4}, + {netip.MustParseAddr("192.168.0.1"), 5}, + {netip.MustParseAddr("192.168.0.5"), 9}, + {netip.MustParseAddr("192.168.0.10"), 14}, + {netip.MustParseAddr("172.16.0.1"), -1}, // Not in set + } + + for _, tc := range indexCases { + index := indexOfAddr(tc.addr, ipset) + if index != tc.index { + t.Errorf("indexOfAddr(%v) = %d, want %d", tc.addr, index, tc.index) + } + if tc.index == -1 { + continue + } + addr := addrAtIndex(tc.index, ipset) + if addr != tc.addr { + t.Errorf("addrAtIndex(%d) = %v, want %v", tc.index, addr, tc.addr) + } + } +} + +func TestAllocAddr(t *testing.T) { + builder := netipx.IPSetBuilder{} + builder.AddRange(netipx.MustParseIPRange("10.0.0.1-10.0.0.5")) + builder.AddRange(netipx.MustParseIPRange("192.168.0.1-192.168.0.10")) + ipset := must.Get(builder.IPSet()) + + allocated := new(big.Int) + for range 15 { + addr := allocAddr(ipset, allocated) + if !addr.IsValid() { + t.Errorf("allocAddr() = invalid, want valid") + } + if !ipset.Contains(addr) { + t.Errorf("allocAddr() = %v, not in set", addr) + } + } + addr := allocAddr(ipset, allocated) + if addr.IsValid() { + t.Errorf("allocAddr() = %v, want invalid", addr) + } + wantAddr := netip.MustParseAddr("10.0.0.2") + allocated.SetBit(allocated, indexOfAddr(wantAddr, ipset), 0) + addr = allocAddr(ipset, allocated) + if addr != wantAddr { + t.Errorf("allocAddr() = %v, want %v", addr, wantAddr) + } +} + +func TestLeastZeroBit(t *testing.T) { + cases := []struct { + num uint + want int + }{ + {math.MaxUint, -1}, + {0, 0}, + {0b01, 1}, + {0b11, 2}, + {0b111, 3}, + {math.MaxUint, -1}, + {math.MaxUint - 1, 0}, + } + if math.MaxUint == math.MaxUint64 { + cases = append(cases, []struct { + num uint + want int + }{ + {math.MaxUint >> 1, 63}, + }...) + } else { + cases = append(cases, []struct { + num uint + want int + }{ + {math.MaxUint >> 1, 31}, + }...) + } + + for _, tc := range cases { + got := leastZeroBit(tc.num) + if got != tc.want { + t.Errorf("leastZeroBit(%b) = %d, want %d", tc.num, got, tc.want) + } + } +} diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index 31d6a5d26..a8168ce6d 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -8,13 +8,12 @@ package main import ( "context" - "encoding/binary" "errors" "expvar" "flag" "fmt" "log" - "math/rand/v2" + "math/big" "net" "net/http" "net/netip" @@ -26,6 +25,7 @@ import ( "github.com/gaissmai/bart" "github.com/inetaf/tcpproxy" "github.com/peterbourgon/ff/v3" + "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/client/local" "tailscale.com/envknob" @@ -38,6 +38,7 @@ import ( "tailscale.com/tsweb" "tailscale.com/util/dnsname" "tailscale.com/util/mak" + "tailscale.com/util/must" "tailscale.com/wgengine/netstack" ) @@ -94,24 +95,6 @@ func main() { } ignoreDstTable.Insert(pfx, true) } - var ( - v4Prefixes []netip.Prefix - numV4DNSAddrs int - ) - for _, s := range strings.Split(*v4PfxStr, ",") { - p := netip.MustParsePrefix(strings.TrimSpace(s)) - if p.Masked() != p { - log.Fatalf("v4 prefix %v is not a masked prefix", p) - } - v4Prefixes = append(v4Prefixes, p) - numIPs := 1 << (32 - p.Bits()) - numV4DNSAddrs += numIPs - } - if len(v4Prefixes) == 0 { - log.Fatalf("no v4 prefixes specified") - } - dnsAddr := v4Prefixes[0].Addr() - numV4DNSAddrs -= 1 // Subtract the dnsAddr allocated above. ts := &tsnet.Server{ Hostname: *hostname, } @@ -159,17 +142,34 @@ func main() { } c := &connector{ - ts: ts, - lc: lc, - dnsAddr: dnsAddr, - v4Ranges: v4Prefixes, - numV4DNSAddrs: numV4DNSAddrs, - v6ULA: ula(uint16(*siteID)), - ignoreDsts: ignoreDstTable, + ts: ts, + lc: lc, + v6ULA: ula(uint16(*siteID)), + ignoreDsts: ignoreDstTable, } + var prefixes []netip.Prefix + for _, s := range strings.Split(*v4PfxStr, ",") { + p := netip.MustParsePrefix(strings.TrimSpace(s)) + if p.Masked() != p { + log.Fatalf("v4 prefix %v is not a masked prefix", p) + } + prefixes = append(prefixes, p) + } + c.setPrefixes(prefixes) c.run(ctx) } +func (c *connector) setPrefixes(prefixes []netip.Prefix) { + var ipsb netipx.IPSetBuilder + for _, p := range prefixes { + ipsb.AddPrefix(p) + } + c.routes = must.Get(ipsb.IPSet()) + c.dnsAddr = c.routes.Ranges()[0].From() + ipsb.Remove(c.dnsAddr) + c.ipset = must.Get(ipsb.IPSet()) +} + type connector struct { // ts is the tsnet.Server used to host the connector. ts *tsnet.Server @@ -181,13 +181,13 @@ type connector struct { // prevent the app connector from assigning it to a domain. dnsAddr netip.Addr - // v4Ranges is the list of IPv4 ranges to advertise and assign addresses from. + // ipset is the set of IPv4 ranges to advertise and assign addresses from. // These are masked prefixes. - v4Ranges []netip.Prefix + ipset *netipx.IPSet - // numV4DNSAddrs is the total size of the IPv4 ranges in addresses, minus the - // dnsAddr allocation. - numV4DNSAddrs int + // routes is the set of IPv4 ranges advertised to the tailnet, or ipset with + // the dnsAddr removed. + routes *netipx.IPSet // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. v6ULA netip.Prefix @@ -225,7 +225,7 @@ func (c *connector) run(ctx context.Context) { if _, err := c.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ AdvertiseRoutesSet: true, Prefs: ipn.Prefs{ - AdvertiseRoutes: append(c.v4Ranges, c.v6ULA), + AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA), }, }); err != nil { log.Fatalf("failed to advertise routes: %v", err) @@ -512,9 +512,9 @@ type perPeerState struct { c *connector mu sync.Mutex + addrInUse *big.Int domainToAddr map[string][]netip.Addr addrToDomain *bart.Table[string] - numV4Allocs int } // domainForIP returns the domain name assigned to the given IP address and @@ -550,46 +550,12 @@ func (ps *perPeerState) ipForDomain(domain string) ([]netip.Addr, error) { return addrs, nil } -// isIPUsedLocked reports whether the given IP address is already assigned to a -// domain. -// ps.mu must be held. -func (ps *perPeerState) isIPUsedLocked(ip netip.Addr) bool { - _, ok := ps.addrToDomain.Lookup(ip) - return ok -} - // unusedIPv4Locked returns an unused IPv4 address from the available ranges. func (ps *perPeerState) unusedIPv4Locked() netip.Addr { - // All addresses have been allocated. - if ps.numV4Allocs >= ps.c.numV4DNSAddrs { - return netip.Addr{} + if ps.addrInUse == nil { + ps.addrInUse = big.NewInt(0) } - - // TODO: skip ranges that have been exhausted - // TODO: implement a much more efficient algorithm for finding unused IPs, - // this is fairly crazy. - for { - for _, r := range ps.c.v4Ranges { - ip := randV4(r) - if !r.Contains(ip) { - panic("error: randV4 returned invalid address") - } - if !ps.isIPUsedLocked(ip) && ip != ps.c.dnsAddr { - return ip - } - } - } -} - -// randV4 returns a random IPv4 address within the given prefix. -func randV4(maskedPfx netip.Prefix) netip.Addr { - bits := 32 - maskedPfx.Bits() - randBits := rand.Uint32N(1 << uint(bits)) - - ip4 := maskedPfx.Addr().As4() - pn := binary.BigEndian.Uint32(ip4[:]) - binary.BigEndian.PutUint32(ip4[:], randBits|pn) - return netip.AddrFrom4(ip4) + return allocAddr(ps.c.ipset, ps.addrInUse) } // assignAddrsLocked assigns a pair of unique IP addresses for the given domain @@ -604,7 +570,6 @@ func (ps *perPeerState) assignAddrsLocked(domain string) []netip.Addr { if !v4.IsValid() { return nil } - ps.numV4Allocs++ as16 := ps.c.v6ULA.Addr().As16() as4 := v4.As4() copy(as16[12:], as4[:]) diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go index e42fa7e89..ddd2d1894 100644 --- a/cmd/natc/natc_test.go +++ b/cmd/natc/natc_test.go @@ -43,17 +43,6 @@ func TestULA(t *testing.T) { } } -func TestRandV4(t *testing.T) { - pfx := netip.MustParsePrefix("100.64.1.0/24") - - for i := 0; i < 512; i++ { - ip := randV4(pfx) - if !pfx.Contains(ip) { - t.Errorf("randV4(%s) = %s; not contained in prefix", pfx, ip) - } - } -} - func TestDNSResponse(t *testing.T) { tests := []struct { name string @@ -227,11 +216,9 @@ func TestDNSResponse(t *testing.T) { func TestPerPeerState(t *testing.T) { c := &connector{ - v4Ranges: []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}, - v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), - dnsAddr: netip.MustParseAddr("100.64.1.0"), - numV4DNSAddrs: (1<<(32-24) - 1), + v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), } + c.setPrefixes([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}) ps := &perPeerState{c: c} @@ -255,8 +242,8 @@ func TestPerPeerState(t *testing.T) { t.Errorf("Second address is not IPv6: %s", v6) } - if !c.v4Ranges[0].Contains(v4) { - t.Errorf("IPv4 address %s not in range %s", v4, c.v4Ranges[0]) + if !c.ipset.Contains(v4) { + t.Errorf("IPv4 address %s not in range %s", v4, c.ipset) } domain, ok := ps.domainForIP(v4) @@ -331,11 +318,9 @@ func TestIgnoreDestination(t *testing.T) { func TestConnectorGenerateDNSResponse(t *testing.T) { c := &connector{ - v4Ranges: []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}, - v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), - dnsAddr: netip.MustParseAddr("100.64.1.0"), - numV4DNSAddrs: (1<<(32-24) - 1), + v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), } + c.setPrefixes([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}) req := &dnsmessage.Message{ Header: dnsmessage.Header{ID: 1234}, @@ -371,11 +356,9 @@ func TestConnectorGenerateDNSResponse(t *testing.T) { func TestIPPoolExhaustion(t *testing.T) { smallPrefix := netip.MustParsePrefix("100.64.1.0/30") // Only 4 IPs: .0, .1, .2, .3 c := &connector{ - v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), - v4Ranges: []netip.Prefix{smallPrefix}, - dnsAddr: netip.MustParseAddr("100.64.1.0"), - numV4DNSAddrs: 3, + v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), } + c.setPrefixes([]netip.Prefix{smallPrefix}) ps := &perPeerState{c: c}