diff --git a/cmd/natc/ippool/ippool.go b/cmd/natc/ippool/ippool.go new file mode 100644 index 000000000..6f6ad1d83 --- /dev/null +++ b/cmd/natc/ippool/ippool.go @@ -0,0 +1,127 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// ippool implements IP address storage, creation, and retrieval for cmd/natc +package ippool + +import ( + "errors" + "log" + "math/big" + "net/netip" + "sync" + + "github.com/gaissmai/bart" + "go4.org/netipx" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" +) + +var ErrNoIPsAvailable = errors.New("no IPs available") + +type IPPool struct { + perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] + IPSet *netipx.IPSet + V6ULA netip.Prefix +} + +func (ipp *IPPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr) (string, bool) { + ps, ok := ipp.perPeerMap.Load(from) + if !ok { + log.Printf("handleTCPFlow: no perPeerState for %v", from) + return "", false + } + domain, ok := ps.domainForIP(addr) + if !ok { + log.Printf("handleTCPFlow: no domain for IP %v\n", addr) + return "", false + } + return domain, ok +} + +func (ipp *IPPool) IPForDomain(from tailcfg.NodeID, domain string) ([]netip.Addr, error) { + npps := &perPeerState{ + ipset: ipp.IPSet, + v6ULA: ipp.V6ULA, + } + ps, _ := ipp.perPeerMap.LoadOrStore(from, npps) + return ps.ipForDomain(domain) +} + +// perPeerState holds the state for a single peer. +type perPeerState struct { + v6ULA netip.Prefix + ipset *netipx.IPSet + + mu sync.Mutex + addrInUse *big.Int + domainToAddr map[string][]netip.Addr + addrToDomain *bart.Table[string] +} + +// domainForIP returns the domain name assigned to the given IP address and +// whether it was found. +func (ps *perPeerState) domainForIP(ip netip.Addr) (_ string, ok bool) { + ps.mu.Lock() + defer ps.mu.Unlock() + if ps.addrToDomain == nil { + return "", false + } + return ps.addrToDomain.Lookup(ip) +} + +// ipForDomain assigns a pair of unique IP addresses for the given domain and +// returns them. The first address is an IPv4 address and the second is an IPv6 +// address. If the domain already has assigned addresses, it returns them. +func (ps *perPeerState) ipForDomain(domain string) ([]netip.Addr, error) { + fqdn, err := dnsname.ToFQDN(domain) + if err != nil { + return nil, err + } + domain = fqdn.WithoutTrailingDot() + + ps.mu.Lock() + defer ps.mu.Unlock() + if addrs, ok := ps.domainToAddr[domain]; ok { + return addrs, nil + } + addrs := ps.assignAddrsLocked(domain) + if addrs == nil { + return nil, ErrNoIPsAvailable + } + return addrs, nil +} + +// unusedIPv4Locked returns an unused IPv4 address from the available ranges. +func (ps *perPeerState) unusedIPv4Locked() netip.Addr { + if ps.addrInUse == nil { + ps.addrInUse = big.NewInt(0) + } + return allocAddr(ps.ipset, ps.addrInUse) +} + +// assignAddrsLocked assigns a pair of unique IP addresses for the given domain +// and returns them. The first address is an IPv4 address and the second is an +// IPv6 address. It does not check if the domain already has assigned addresses. +// ps.mu must be held. +func (ps *perPeerState) assignAddrsLocked(domain string) []netip.Addr { + if ps.addrToDomain == nil { + ps.addrToDomain = &bart.Table[string]{} + } + v4 := ps.unusedIPv4Locked() + if !v4.IsValid() { + return nil + } + as16 := ps.v6ULA.Addr().As16() + as4 := v4.As4() + copy(as16[12:], as4[:]) + v6 := netip.AddrFrom16(as16) + addrs := []netip.Addr{v4, v6} + mak.Set(&ps.domainToAddr, domain, addrs) + for _, a := range addrs { + ps.addrToDomain.Insert(netip.PrefixFrom(a, a.BitLen()), domain) + } + return addrs +} diff --git a/cmd/natc/ippool/ippool_test.go b/cmd/natc/ippool/ippool_test.go new file mode 100644 index 000000000..84b3b7a02 --- /dev/null +++ b/cmd/natc/ippool/ippool_test.go @@ -0,0 +1,129 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "errors" + "fmt" + "net/netip" + "slices" + "testing" + + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +func TestIPPoolExhaustion(t *testing.T) { + smallPrefix := netip.MustParsePrefix("100.64.1.0/30") // Only 4 IPs: .0, .1, .2, .3 + var ipsb netipx.IPSetBuilder + ipsb.AddPrefix(smallPrefix) + addrPool := must.Get(ipsb.IPSet()) + v6ULA := netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80") + pool := IPPool{V6ULA: v6ULA, IPSet: addrPool} + + assignedIPs := make(map[netip.Addr]string) + + domains := []string{"a.example.com", "b.example.com", "c.example.com", "d.example.com", "e.example.com"} + + var errs []error + + from := tailcfg.NodeID(12345) + + for i := 0; i < 5; i++ { + for _, domain := range domains { + addrs, err := pool.IPForDomain(from, domain) + if err != nil { + errs = append(errs, fmt.Errorf("failed to get IP for domain %q: %w", domain, err)) + continue + } + + for _, addr := range addrs { + if d, ok := assignedIPs[addr]; ok { + if d != domain { + t.Errorf("IP %s reused for domain %q, previously assigned to %q", addr, domain, d) + } + } else { + assignedIPs[addr] = domain + } + } + } + } + + for addr, domain := range assignedIPs { + if addr.Is4() && !smallPrefix.Contains(addr) { + t.Errorf("IP %s for domain %q not in expected range %s", addr, domain, smallPrefix) + } + if addr.Is6() && !v6ULA.Contains(addr) { + t.Errorf("IP %s for domain %q not in expected range %s", addr, domain, v6ULA) + } + } + + // expect one error for each iteration with the 5th domain + if len(errs) != 5 { + t.Errorf("Expected 5 errors, got %d: %v", len(errs), errs) + } + for _, err := range errs { + if !errors.Is(err, ErrNoIPsAvailable) { + t.Errorf("generateDNSResponse() error = %v, want ErrNoIPsAvailable", err) + } + } +} + +func TestIPPool(t *testing.T) { + var ipsb netipx.IPSetBuilder + ipsb.AddPrefix(netip.MustParsePrefix("100.64.1.0/24")) + addrPool := must.Get(ipsb.IPSet()) + pool := IPPool{ + V6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), + IPSet: addrPool, + } + from := tailcfg.NodeID(12345) + addrs, err := pool.IPForDomain(from, "example.com") + if err != nil { + t.Fatalf("ipForDomain() error = %v", err) + } + + if len(addrs) != 2 { + t.Fatalf("ipForDomain() returned %d addresses, want 2", len(addrs)) + } + + v4 := addrs[0] + v6 := addrs[1] + + if !v4.Is4() { + t.Errorf("First address is not IPv4: %s", v4) + } + + if !v6.Is6() { + t.Errorf("Second address is not IPv6: %s", v6) + } + + if !addrPool.Contains(v4) { + t.Errorf("IPv4 address %s not in range %s", v4, addrPool) + } + + domain, ok := pool.DomainForIP(from, v4) + if !ok { + t.Errorf("domainForIP(%s) not found", v4) + } else if domain != "example.com" { + t.Errorf("domainForIP(%s) = %s, want %s", v4, domain, "example.com") + } + + domain, ok = pool.DomainForIP(from, v6) + if !ok { + t.Errorf("domainForIP(%s) not found", v6) + } else if domain != "example.com" { + t.Errorf("domainForIP(%s) = %s, want %s", v6, domain, "example.com") + } + + addrs2, err := pool.IPForDomain(from, "example.com") + if err != nil { + t.Fatalf("ipForDomain() second call error = %v", err) + } + + if !slices.Equal(addrs, addrs2) { + t.Errorf("ipForDomain() second call = %v, want %v", addrs2, addrs) + } +} diff --git a/cmd/natc/ipx.go b/cmd/natc/ippool/ipx.go similarity index 99% rename from cmd/natc/ipx.go rename to cmd/natc/ippool/ipx.go index 06bf7be79..8259a56db 100644 --- a/cmd/natc/ipx.go +++ b/cmd/natc/ippool/ipx.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package main +package ippool import ( "math/big" diff --git a/cmd/natc/ipx_test.go b/cmd/natc/ippool/ipx_test.go similarity index 99% rename from cmd/natc/ipx_test.go rename to cmd/natc/ippool/ipx_test.go index b60a5d981..2e2b9d3d4 100644 --- a/cmd/natc/ipx_test.go +++ b/cmd/natc/ippool/ipx_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package main +package ippool import ( "math" diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index bff9bce87..270524879 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -13,13 +13,11 @@ import ( "flag" "fmt" "log" - "math/big" "net" "net/http" "net/netip" "os" "strings" - "sync" "time" "github.com/gaissmai/bart" @@ -28,22 +26,18 @@ import ( "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/client/local" + "tailscale.com/cmd/natc/ippool" "tailscale.com/envknob" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/net/netutil" - "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tsweb" - "tailscale.com/util/dnsname" - "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/wgengine/netstack" ) -var ErrNoIPsAvailable = errors.New("no IPs available") - func main() { hostinfo.SetApp("natc") if !envknob.UseWIPCode() { @@ -141,12 +135,6 @@ func main() { log.Fatalf("ts.Up: %v", err) } - c := &connector{ - ts: ts, - lc: lc, - v6ULA: ula(uint16(*siteID)), - ignoreDsts: ignoreDstTable, - } var prefixes []netip.Prefix for _, s := range strings.Split(*v4PfxStr, ",") { p := netip.MustParsePrefix(strings.TrimSpace(s)) @@ -155,19 +143,31 @@ func main() { } prefixes = append(prefixes, p) } - c.setPrefixes(prefixes) + routes, dnsAddr, addrPool := calculateAddresses(prefixes) + + v6ULA := ula(uint16(*siteID)) + c := &connector{ + ts: ts, + lc: lc, + v6ULA: v6ULA, + ignoreDsts: ignoreDstTable, + ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool}, + routes: routes, + dnsAddr: dnsAddr, + } c.run(ctx) } -func (c *connector) setPrefixes(prefixes []netip.Prefix) { +func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) { var ipsb netipx.IPSetBuilder for _, p := range prefixes { ipsb.AddPrefix(p) } - c.routes = must.Get(ipsb.IPSet()) - c.dnsAddr = c.routes.Ranges()[0].From() - ipsb.Remove(c.dnsAddr) - c.ipset = must.Get(ipsb.IPSet()) + routesToAdvertise := must.Get(ipsb.IPSet()) + dnsAddr := routesToAdvertise.Ranges()[0].From() + ipsb.Remove(dnsAddr) + addrPool := must.Get(ipsb.IPSet()) + return routesToAdvertise, dnsAddr, addrPool } type connector struct { @@ -181,10 +181,6 @@ type connector struct { // prevent the app connector from assigning it to a domain. dnsAddr netip.Addr - // ipset is the set of IPv4 ranges to advertise and assign addresses from. - // These are masked prefixes. - ipset *netipx.IPSet - // routes is the set of IPv4 ranges advertised to the tailnet, or ipset with // the dnsAddr removed. routes *netipx.IPSet @@ -192,8 +188,6 @@ type connector struct { // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. v6ULA netip.Prefix - perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] - // ignoreDsts is initialized at start up with the contents of --ignore-destinations (if none it is nil) // It is never mutated, only used for lookups. // Users who want to natc a DNS wildcard but not every address record in that domain can supply the @@ -202,6 +196,8 @@ type connector struct { // return a dns response that contains the ip addresses we discovered with the lookup (ie not the // natc behavior, which would return a dummy ip address pointing at natc). ignoreDsts *bart.Table[bool] + + ipPool *ippool.IPPool } // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. @@ -359,13 +355,12 @@ var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") // generateDNSResponse generates a DNS response for the given request. The from // argument is the NodeID of the node that sent the request. func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) { - pm, _ := c.perPeerMap.LoadOrStore(from, newPerPeerState(c)) var addrs []netip.Addr if len(req.Questions) > 0 { switch req.Questions[0].Type { case dnsmessage.TypeAAAA, dnsmessage.TypeA: var err error - addrs, err = pm.ipForDomain(req.Questions[0].Name.String()) + addrs, err = c.ipPool.IPForDomain(from, req.Questions[0].Name.String()) if err != nil { return nil, err } @@ -454,16 +449,8 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err) return nil, false } - - from := who.Node.ID - ps, ok := c.perPeerMap.Load(from) + domain, ok := c.ipPool.DomainForIP(who.Node.ID, dst.Addr()) if !ok { - log.Printf("handleTCPFlow: no perPeerState for %v", from) - return nil, false - } - domain, ok := ps.domainForIP(dst.Addr()) - if !ok { - log.Printf("handleTCPFlow: no domain for IP %v\n", dst.Addr()) return nil, false } return func(conn net.Conn) { @@ -506,86 +493,3 @@ func proxyTCPConn(c net.Conn, dest string) { }) p.Start() } - -// perPeerState holds the state for a single peer. -type perPeerState struct { - v6ULA netip.Prefix - ipset *netipx.IPSet - - mu sync.Mutex - addrInUse *big.Int - domainToAddr map[string][]netip.Addr - addrToDomain *bart.Table[string] -} - -func newPerPeerState(c *connector) *perPeerState { - return &perPeerState{ - ipset: c.ipset, - v6ULA: c.v6ULA, - } -} - -// domainForIP returns the domain name assigned to the given IP address and -// whether it was found. -func (ps *perPeerState) domainForIP(ip netip.Addr) (_ string, ok bool) { - ps.mu.Lock() - defer ps.mu.Unlock() - if ps.addrToDomain == nil { - return "", false - } - return ps.addrToDomain.Lookup(ip) -} - -// ipForDomain assigns a pair of unique IP addresses for the given domain and -// returns them. The first address is an IPv4 address and the second is an IPv6 -// address. If the domain already has assigned addresses, it returns them. -func (ps *perPeerState) ipForDomain(domain string) ([]netip.Addr, error) { - fqdn, err := dnsname.ToFQDN(domain) - if err != nil { - return nil, err - } - domain = fqdn.WithoutTrailingDot() - - ps.mu.Lock() - defer ps.mu.Unlock() - if addrs, ok := ps.domainToAddr[domain]; ok { - return addrs, nil - } - addrs := ps.assignAddrsLocked(domain) - if addrs == nil { - return nil, ErrNoIPsAvailable - } - return addrs, nil -} - -// unusedIPv4Locked returns an unused IPv4 address from the available ranges. -func (ps *perPeerState) unusedIPv4Locked() netip.Addr { - if ps.addrInUse == nil { - ps.addrInUse = big.NewInt(0) - } - return allocAddr(ps.ipset, ps.addrInUse) -} - -// assignAddrsLocked assigns a pair of unique IP addresses for the given domain -// and returns them. The first address is an IPv4 address and the second is an -// IPv6 address. It does not check if the domain already has assigned addresses. -// ps.mu must be held. -func (ps *perPeerState) assignAddrsLocked(domain string) []netip.Addr { - if ps.addrToDomain == nil { - ps.addrToDomain = &bart.Table[string]{} - } - v4 := ps.unusedIPv4Locked() - if !v4.IsValid() { - return nil - } - as16 := ps.v6ULA.Addr().As16() - as4 := v4.As4() - copy(as16[12:], as4[:]) - v6 := netip.AddrFrom16(as16) - addrs := []netip.Addr{v4, v6} - mak.Set(&ps.domainToAddr, domain, addrs) - for _, a := range addrs { - ps.addrToDomain.Insert(netip.PrefixFrom(a, a.BitLen()), domain) - } - return addrs -} diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go index 66e0141b9..09ade0a98 100644 --- a/cmd/natc/natc_test.go +++ b/cmd/natc/natc_test.go @@ -4,15 +4,13 @@ package main import ( - "errors" - "fmt" "net/netip" - "slices" "testing" "github.com/gaissmai/bart" "github.com/google/go-cmp/cmp" "golang.org/x/net/dns/dnsmessage" + "tailscale.com/cmd/natc/ippool" "tailscale.com/tailcfg" ) @@ -214,62 +212,6 @@ func TestDNSResponse(t *testing.T) { } } -func TestPerPeerState(t *testing.T) { - c := &connector{ - v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), - } - c.setPrefixes([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}) - - ps := newPerPeerState(c) - - addrs, err := ps.ipForDomain("example.com") - if err != nil { - t.Fatalf("ipForDomain() error = %v", err) - } - - if len(addrs) != 2 { - t.Fatalf("ipForDomain() returned %d addresses, want 2", len(addrs)) - } - - v4 := addrs[0] - v6 := addrs[1] - - if !v4.Is4() { - t.Errorf("First address is not IPv4: %s", v4) - } - - if !v6.Is6() { - t.Errorf("Second address is not IPv6: %s", v6) - } - - if !c.ipset.Contains(v4) { - t.Errorf("IPv4 address %s not in range %s", v4, c.ipset) - } - - domain, ok := ps.domainForIP(v4) - if !ok { - t.Errorf("domainForIP(%s) not found", v4) - } else if domain != "example.com" { - t.Errorf("domainForIP(%s) = %s, want %s", v4, domain, "example.com") - } - - domain, ok = ps.domainForIP(v6) - if !ok { - t.Errorf("domainForIP(%s) not found", v6) - } else if domain != "example.com" { - t.Errorf("domainForIP(%s) = %s, want %s", v6, domain, "example.com") - } - - addrs2, err := ps.ipForDomain("example.com") - if err != nil { - t.Fatalf("ipForDomain() second call error = %v", err) - } - - if !slices.Equal(addrs, addrs2) { - t.Errorf("ipForDomain() second call = %v, want %v", addrs2, addrs) - } -} - func TestIgnoreDestination(t *testing.T) { ignoreDstTable := &bart.Table[bool]{} ignoreDstTable.Insert(netip.MustParsePrefix("192.168.1.0/24"), true) @@ -317,10 +259,14 @@ func TestIgnoreDestination(t *testing.T) { } func TestConnectorGenerateDNSResponse(t *testing.T) { + v6ULA := netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80") + routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}) c := &connector{ - v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), + v6ULA: v6ULA, + ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool}, + routes: routes, + dnsAddr: dnsAddr, } - c.setPrefixes([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}) req := &dnsmessage.Message{ Header: dnsmessage.Header{ID: 1234}, @@ -351,62 +297,13 @@ func TestConnectorGenerateDNSResponse(t *testing.T) { if !cmp.Equal(resp1, resp2) { t.Errorf("generateDNSResponse() responses differ between calls") } -} -func TestIPPoolExhaustion(t *testing.T) { - smallPrefix := netip.MustParsePrefix("100.64.1.0/30") // Only 4 IPs: .0, .1, .2, .3 - c := &connector{ - v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), + var msg dnsmessage.Message + err = msg.Unpack(resp1) + if err != nil { + t.Fatalf("dnsmessage Unpack error = %v", err) } - c.setPrefixes([]netip.Prefix{smallPrefix}) - - ps := newPerPeerState(c) - - assignedIPs := make(map[netip.Addr]string) - - domains := []string{"a.example.com", "b.example.com", "c.example.com", "d.example.com"} - - var errs []error - - for i := 0; i < 5; i++ { - for _, domain := range domains { - addrs, err := ps.ipForDomain(domain) - if err != nil { - errs = append(errs, fmt.Errorf("failed to get IP for domain %q: %w", domain, err)) - continue - } - - for _, addr := range addrs { - if d, ok := assignedIPs[addr]; ok { - if d != domain { - t.Errorf("IP %s reused for domain %q, previously assigned to %q", addr, domain, d) - } - } else { - assignedIPs[addr] = domain - } - } - } - } - - for addr, domain := range assignedIPs { - if addr.Is4() && !smallPrefix.Contains(addr) { - t.Errorf("IP %s for domain %q not in expected range %s", addr, domain, smallPrefix) - } - if addr.Is6() && !c.v6ULA.Contains(addr) { - t.Errorf("IP %s for domain %q not in expected range %s", addr, domain, c.v6ULA) - } - if addr == c.dnsAddr { - t.Errorf("IP %s for domain %q is the reserved DNS address", addr, domain) - } - } - - // expect one error for each iteration with the 4th domain - if len(errs) != 5 { - t.Errorf("Expected 5 errors, got %d: %v", len(errs), errs) - } - for _, err := range errs { - if !errors.Is(err, ErrNoIPsAvailable) { - t.Errorf("generateDNSResponse() error = %v, want ErrNoIPsAvailable", err) - } + if len(msg.Answers) != 1 { + t.Fatalf("expected 1 answer, got: %d", len(msg.Answers)) } }