From 486a55f0a9bffc45eb34350f24fea5a76be5169a Mon Sep 17 00:00:00 2001
From: Fran Bull
Date: Wed, 16 Apr 2025 10:21:50 -0700
Subject: [PATCH] cmd/natc: add optional consensus backend

Enable the nat connector to be run on a cluster of machines for high
availability.

Updates #14667

Signed-off-by: Fran Bull
---
 cmd/natc/ippool/consensusippool.go          | 434 ++++++++++++++++++++
 cmd/natc/ippool/consensusippool_test.go     | 383 +++++++++++++++++
 cmd/natc/ippool/consensusippoolserialize.go | 164 ++++++++
 cmd/natc/ippool/ippool.go                   |  21 +-
 cmd/natc/ippool/ippool_test.go              |   7 +-
 cmd/natc/natc.go                            |  28 +-
 cmd/natc/natc_test.go                       |   2 +-
 7 files changed, 1029 insertions(+), 10 deletions(-)
 create mode 100644 cmd/natc/ippool/consensusippool.go
 create mode 100644 cmd/natc/ippool/consensusippool_test.go
 create mode 100644 cmd/natc/ippool/consensusippoolserialize.go

diff --git a/cmd/natc/ippool/consensusippool.go b/cmd/natc/ippool/consensusippool.go
new file mode 100644
index 000000000..4783209b2
--- /dev/null
+++ b/cmd/natc/ippool/consensusippool.go
@@ -0,0 +1,434 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ippool
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log"
+	"net/netip"
+	"time"
+
+	"github.com/hashicorp/raft"
+	"go4.org/netipx"
+	"tailscale.com/syncs"
+	"tailscale.com/tailcfg"
+	"tailscale.com/tsconsensus"
+	"tailscale.com/tsnet"
+	"tailscale.com/util/mak"
+)
+
+// ConsensusIPPool implements an [IPPool] that is distributed among members of a cluster for high availability.
+// Writes are directed to the cluster leader and are slower than reads; reads are performed locally
+// using information replicated from the leader.
+// The cluster maintains consistency; reads can be stale, and writes become unavailable if too many
+// cluster peers are unreachable.
+type ConsensusIPPool struct {
+	IPSet                 *netipx.IPSet
+	perPeerMap            *syncs.Map[tailcfg.NodeID, *consensusPerPeerState]
+	consensus             commandExecutor
+	unusedAddressLifetime time.Duration
+}
+
+func NewConsensusIPPool(ipSet *netipx.IPSet) *ConsensusIPPool {
+	return &ConsensusIPPool{
+		unusedAddressLifetime: 48 * time.Hour, // TODO (fran) is this appropriate? should it be configurable?
+		IPSet:                 ipSet,
+		perPeerMap:            &syncs.Map[tailcfg.NodeID, *consensusPerPeerState]{},
+	}
+}
+
+// IPForDomain looks up or creates an IP address allocation for the tailcfg.NodeID and domain pair.
+// If no address association is found, one is allocated from the range of free addresses for this tailcfg.NodeID.
+// If no more addresses are available, an error is returned.
+func (ipp *ConsensusIPPool) IPForDomain(nid tailcfg.NodeID, domain string) (netip.Addr, error) {
+	now := time.Now()
+	// Check local state; local state may be stale. If we have an IP for this domain, and we are not
+	// close to the expiry time for the domain, it's safe to return what we have.
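+	// Otherwise fall through to a raft round below, so that the whole cluster
+	// agrees on the allocation.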
+	ps, psFound := ipp.perPeerMap.Load(nid)
+	if psFound {
+		if addr, addrFound := ps.domainToAddr[domain]; addrFound {
+			if ww, wwFound := ps.addrToDomain.Load(addr); wwFound {
+				if !isCloseToExpiry(ww.LastUsed, now, ipp.unusedAddressLifetime) {
+					ipp.fireAndForgetMarkLastUsed(nid, addr, ww, now)
+					return addr, nil
+				}
+			}
+		}
+	}
+
+	// go via consensus
+	args := checkoutAddrArgs{
+		NodeID:        nid,
+		Domain:        domain,
+		ReuseDeadline: now.Add(-1 * ipp.unusedAddressLifetime),
+		UpdatedAt:     now,
+	}
+	bs, err := json.Marshal(args)
+	if err != nil {
+		return netip.Addr{}, err
+	}
+	c := tsconsensus.Command{
+		Name: "checkoutAddr",
+		Args: bs,
+	}
+	result, err := ipp.consensus.ExecuteCommand(c)
+	if err != nil {
+		log.Printf("IPForDomain: raft error executing command: %v", err)
+		return netip.Addr{}, err
+	}
+	if result.Err != nil {
+		log.Printf("IPForDomain: error returned from state machine: %v", result.Err)
+		return netip.Addr{}, result.Err
+	}
+	var addr netip.Addr
+	err = json.Unmarshal(result.Result, &addr)
+	return addr, err
+}
+
+// DomainForIP looks up the domain associated with a tailcfg.NodeID and netip.Addr pair.
+// If there is no association, the result is empty and ok is false.
+func (ipp *ConsensusIPPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr, updatedAt time.Time) (string, bool) {
+	// Look in local state, to save a consensus round trip; local state may be stale.
+	//
+	// The only time we expect ordering of commands to matter to clients is on first
+	// connection to a domain. In that case it may be that although we don't find the
+	// domain in our local state, it is in fact in the state of the state machine (ie
+	// the client did a DNS lookup, and we responded with an IP and _should_ know that
+	// domain when the TCP connection for that IP arrives.)
+	//
+	// So it's ok to return local state, unless local state doesn't recognize the domain,
+	// in which case we should check the consensus state machine to know for sure.
+	var domain string
+	ww, ok := ipp.domainLookup(from, addr)
+	if ok {
+		domain = ww.Domain
+	} else {
+		d, err := ipp.readDomainForIP(from, addr)
+		if err != nil {
+			log.Printf("error reading domain from consensus: %v", err)
+			return "", false
+		}
+		domain = d
+	}
+	if domain == "" {
+		log.Printf("did not find domain for node: %v, addr: %s", from, addr)
+		return "", false
+	}
+	ipp.fireAndForgetMarkLastUsed(from, addr, ww, updatedAt)
+	return domain, true
+}
+
+func (ipp *ConsensusIPPool) fireAndForgetMarkLastUsed(from tailcfg.NodeID, addr netip.Addr, ww whereWhen, updatedAt time.Time) {
+	window := 5 * time.Minute
+	if updatedAt.Sub(ww.LastUsed).Abs() < window {
+		return
+	}
+	go func() {
+		err := ipp.markLastUsed(from, addr, ww.Domain, updatedAt)
+		if err != nil {
+			log.Printf("error marking last used: %v", err)
+		}
+	}()
+}
+
+func (ipp *ConsensusIPPool) domainLookup(from tailcfg.NodeID, addr netip.Addr) (whereWhen, bool) {
+	ps, ok := ipp.perPeerMap.Load(from)
+	if !ok {
+		log.Printf("domainLookup: peer state absent for: %d", from)
+		return whereWhen{}, false
+	}
+	ww, ok := ps.addrToDomain.Load(addr)
+	if !ok {
+		log.Printf("domainLookup: peer state doesn't recognize addr: %s", addr)
+		return whereWhen{}, false
+	}
+	return ww, true
+}
+
+// StartConsensus is part of the IPPool interface. It starts the raft background routines that handle consensus.
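+//
+// A minimal sketch of wiring it up (the cluster tag, ctx, ts and ipset here are
+// illustrative, not prescribed):
+//
+//	pool := NewConsensusIPPool(ipset)
+//	if err := pool.StartConsensus(ctx, ts, "tag:natc-cluster"); err != nil {
+//		log.Fatalf("StartConsensus: %v", err)
+//	}
+//	defer pool.StopConsensus(context.Background())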
+func (ipp *ConsensusIPPool) StartConsensus(ctx context.Context, ts *tsnet.Server, clusterTag string) error { + cfg := tsconsensus.DefaultConfig() + cfg.ServeDebugMonitor = true + cns, err := tsconsensus.Start(ctx, ts, ipp, clusterTag, cfg) + if err != nil { + return err + } + ipp.consensus = cns + return nil +} + +type whereWhen struct { + Domain string + LastUsed time.Time +} + +type consensusPerPeerState struct { + domainToAddr map[string]netip.Addr + addrToDomain *syncs.Map[netip.Addr, whereWhen] +} + +// StopConsensus is part of the IPPool interface. It stops the raft background routines that handle consensus. +func (ipp *ConsensusIPPool) StopConsensus(ctx context.Context) error { + return (ipp.consensus).(*tsconsensus.Consensus).Stop(ctx) +} + +// unusedIPV4 finds the next unused or expired IP address in the pool. +// IP addresses in the pool should be reused if they haven't been used for some period of time. +// reuseDeadline is the time before which addresses are considered to be expired. +// So if addresses are being reused after they haven't been used for 24 hours say, reuseDeadline +// would be 24 hours ago. +func (ps *consensusPerPeerState) unusedIPV4(ipset *netipx.IPSet, reuseDeadline time.Time) (netip.Addr, bool, string, error) { + // If we want to have a random IP choice behavior we could make that work with the state machine by doing something like + // passing the randomly chosen IP into the state machine call (so replaying logs would still be deterministic). + for _, r := range ipset.Ranges() { + ip := r.From() + toIP := r.To() + if !ip.IsValid() || !toIP.IsValid() { + continue + } + for toIP.Compare(ip) != -1 { + ww, ok := ps.addrToDomain.Load(ip) + if !ok { + return ip, false, "", nil + } + if ww.LastUsed.Before(reuseDeadline) { + return ip, true, ww.Domain, nil + } + ip = ip.Next() + } + } + return netip.Addr{}, false, "", errors.New("ip pool exhausted") +} + +// isCloseToExpiry returns true if the lastUsed and now times are more than +// half the lifetime apart +func isCloseToExpiry(lastUsed, now time.Time, lifetime time.Duration) bool { + return now.Sub(lastUsed).Abs() > (lifetime / 2) +} + +type readDomainForIPArgs struct { + NodeID tailcfg.NodeID + Addr netip.Addr +} + +// executeReadDomainForIP parses a readDomainForIP log entry and applies it. +func (ipp *ConsensusIPPool) executeReadDomainForIP(bs []byte) tsconsensus.CommandResult { + var args readDomainForIPArgs + err := json.Unmarshal(bs, &args) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + return ipp.applyReadDomainForIP(args.NodeID, args.Addr) +} + +func (ipp *ConsensusIPPool) applyReadDomainForIP(from tailcfg.NodeID, addr netip.Addr) tsconsensus.CommandResult { + domain := func() string { + ps, ok := ipp.perPeerMap.Load(from) + if !ok { + return "" + } + ww, ok := ps.addrToDomain.Load(addr) + if !ok { + return "" + } + return ww.Domain + }() + resultBs, err := json.Marshal(domain) + return tsconsensus.CommandResult{Result: resultBs, Err: err} +} + +// readDomainForIP executes a readDomainForIP command on the leader with raft. 
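+// Routing the read through raft costs a round trip to the leader, but returns the
+// state machine's answer rather than this node's possibly stale local copy.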
+func (ipp *ConsensusIPPool) readDomainForIP(nid tailcfg.NodeID, addr netip.Addr) (string, error) {
+	args := readDomainForIPArgs{
+		NodeID: nid,
+		Addr:   addr,
+	}
+	bs, err := json.Marshal(args)
+	if err != nil {
+		return "", err
+	}
+	c := tsconsensus.Command{
+		Name: "readDomainForIP",
+		Args: bs,
+	}
+	result, err := ipp.consensus.ExecuteCommand(c)
+	if err != nil {
+		log.Printf("readDomainForIP: raft error executing command: %v", err)
+		return "", err
+	}
+	if result.Err != nil {
+		log.Printf("readDomainForIP: error returned from state machine: %v", result.Err)
+		return "", result.Err
+	}
+	var domain string
+	err = json.Unmarshal(result.Result, &domain)
+	return domain, err
+}
+
+type markLastUsedArgs struct {
+	NodeID    tailcfg.NodeID
+	Addr      netip.Addr
+	Domain    string
+	UpdatedAt time.Time
+}
+
+// executeMarkLastUsed parses a markLastUsed log entry and applies it.
+func (ipp *ConsensusIPPool) executeMarkLastUsed(bs []byte) tsconsensus.CommandResult {
+	var args markLastUsedArgs
+	err := json.Unmarshal(bs, &args)
+	if err != nil {
+		return tsconsensus.CommandResult{Err: err}
+	}
+	err = ipp.applyMarkLastUsed(args.NodeID, args.Addr, args.Domain, args.UpdatedAt)
+	if err != nil {
+		return tsconsensus.CommandResult{Err: err}
+	}
+	return tsconsensus.CommandResult{}
+}
+
+// applyMarkLastUsed applies the arguments from the log entry to the state. It updates an entry in the AddrToDomain
+// map with a new LastUsed timestamp.
+// applyMarkLastUsed is not safe for concurrent access. It's only called from raft which will
+// not call it concurrently.
+func (ipp *ConsensusIPPool) applyMarkLastUsed(from tailcfg.NodeID, addr netip.Addr, domain string, updatedAt time.Time) error {
+	ps, ok := ipp.perPeerMap.Load(from)
+	if !ok {
+		// There's nothing to mark. But this is unexpected, because we mark last used after we do things with peer state.
+		log.Printf("applyMarkLastUsed: could not find peer state, nodeID: %v", from)
+		return nil
+	}
+	ww, ok := ps.addrToDomain.Load(addr)
+	if !ok {
+		// The peer state didn't have an entry for the IP address (possibly it expired), so there's nothing to mark.
+		return nil
+	}
+	if ww.Domain != domain {
+		// The IP address expired and was reused for a new domain. Don't mark.
+		return nil
+	}
+	if ww.LastUsed.After(updatedAt) {
+		// This has been marked more recently. Don't mark.
+		return nil
+	}
+	ww.LastUsed = updatedAt
+	ps.addrToDomain.Store(addr, ww)
+	return nil
+}
+
+// markLastUsed executes a markLastUsed command on the leader with raft.
+func (ipp *ConsensusIPPool) markLastUsed(nid tailcfg.NodeID, addr netip.Addr, domain string, lastUsed time.Time) error {
+	args := markLastUsedArgs{
+		NodeID:    nid,
+		Addr:      addr,
+		Domain:    domain,
+		UpdatedAt: lastUsed,
+	}
+	bs, err := json.Marshal(args)
+	if err != nil {
+		return err
+	}
+	c := tsconsensus.Command{
+		Name: "markLastUsed",
+		Args: bs,
+	}
+	result, err := ipp.consensus.ExecuteCommand(c)
+	if err != nil {
+		log.Printf("markLastUsed: raft error executing command: %v", err)
+		return err
+	}
+	if result.Err != nil {
+		log.Printf("markLastUsed: error returned from state machine: %v", result.Err)
+		return result.Err
+	}
+	return nil
+}
+
+type checkoutAddrArgs struct {
+	NodeID        tailcfg.NodeID
+	Domain        string
+	ReuseDeadline time.Time
+	UpdatedAt     time.Time
+}
+
+// executeCheckoutAddr parses a checkoutAddr raft log entry and applies it.
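+//
+// The entry's Args field is the JSON encoding of checkoutAddrArgs, e.g. (illustrative
+// values):
+//
+//	{"NodeID":1,"Domain":"example.com","ReuseDeadline":"2025-04-14T10:00:00Z","UpdatedAt":"2025-04-16T10:00:00Z"}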
+func (ipp *ConsensusIPPool) executeCheckoutAddr(bs []byte) tsconsensus.CommandResult { + var args checkoutAddrArgs + err := json.Unmarshal(bs, &args) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + addr, err := ipp.applyCheckoutAddr(args.NodeID, args.Domain, args.ReuseDeadline, args.UpdatedAt) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + resultBs, err := json.Marshal(addr) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + return tsconsensus.CommandResult{Result: resultBs} +} + +// applyCheckoutAddr finds the IP address for a nid+domain +// Each nid can use all of the addresses in the pool. +// updatedAt is the current time, the time at which we are wanting to get a new IP address. +// reuseDeadline is the time before which addresses are considered to be expired. +// So if addresses are being reused after they haven't been used for 24 hours say updatedAt would be now +// and reuseDeadline would be 24 hours ago. +// It is not safe for concurrent access (it's only called from raft, which will not call concurrently +// so that's fine). +func (ipp *ConsensusIPPool) applyCheckoutAddr(nid tailcfg.NodeID, domain string, reuseDeadline, updatedAt time.Time) (netip.Addr, error) { + ps, ok := ipp.perPeerMap.Load(nid) + if !ok { + ps = &consensusPerPeerState{ + addrToDomain: &syncs.Map[netip.Addr, whereWhen]{}, + } + ipp.perPeerMap.Store(nid, ps) + } + if existing, ok := ps.domainToAddr[domain]; ok { + ww, ok := ps.addrToDomain.Load(existing) + if ok { + ww.LastUsed = updatedAt + ps.addrToDomain.Store(existing, ww) + return existing, nil + } + log.Printf("applyCheckoutAddr: data out of sync, allocating new IP") + } + addr, wasInUse, previousDomain, err := ps.unusedIPV4(ipp.IPSet, reuseDeadline) + if err != nil { + return netip.Addr{}, err + } + mak.Set(&ps.domainToAddr, domain, addr) + if wasInUse { + delete(ps.domainToAddr, previousDomain) + } + ps.addrToDomain.Store(addr, whereWhen{Domain: domain, LastUsed: updatedAt}) + return addr, nil +} + +// Apply is part of the raft.FSM interface. It takes an incoming log entry and applies it to the state. 
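+// Entries with an unknown command name or an undecodable payload panic, since they
+// indicate a raft log this version of the code cannot interpret.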
+func (ipp *ConsensusIPPool) Apply(l *raft.Log) any { + var c tsconsensus.Command + if err := json.Unmarshal(l.Data, &c); err != nil { + panic(fmt.Sprintf("failed to unmarshal command: %s", err.Error())) + } + switch c.Name { + case "checkoutAddr": + return ipp.executeCheckoutAddr(c.Args) + case "markLastUsed": + return ipp.executeMarkLastUsed(c.Args) + case "readDomainForIP": + return ipp.executeReadDomainForIP(c.Args) + default: + panic(fmt.Sprintf("unrecognized command: %s", c.Name)) + } +} + +// commandExecutor is an interface covering the routing parts of consensus +// used to allow a fake in the tests +type commandExecutor interface { + ExecuteCommand(tsconsensus.Command) (tsconsensus.CommandResult, error) +} diff --git a/cmd/natc/ippool/consensusippool_test.go b/cmd/natc/ippool/consensusippool_test.go new file mode 100644 index 000000000..242cdffaf --- /dev/null +++ b/cmd/natc/ippool/consensusippool_test.go @@ -0,0 +1,383 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/netip" + "testing" + "time" + + "github.com/hashicorp/raft" + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/tsconsensus" + "tailscale.com/util/must" +) + +func makeSetFromPrefix(pfx netip.Prefix) *netipx.IPSet { + var ipsb netipx.IPSetBuilder + ipsb.AddPrefix(pfx) + return must.Get(ipsb.IPSet()) +} + +type FakeConsensus struct { + ipp *ConsensusIPPool +} + +func (c *FakeConsensus) ExecuteCommand(cmd tsconsensus.Command) (tsconsensus.CommandResult, error) { + b, err := json.Marshal(cmd) + if err != nil { + return tsconsensus.CommandResult{}, err + } + result := c.ipp.Apply(&raft.Log{Data: b}) + return result.(tsconsensus.CommandResult), nil +} + +func makePool(pfx netip.Prefix) *ConsensusIPPool { + ipp := NewConsensusIPPool(makeSetFromPrefix(pfx)) + ipp.consensus = &FakeConsensus{ipp: ipp} + return ipp +} + +func TestConsensusIPForDomain(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.0.0/16") + ipp := makePool(pfx) + from := tailcfg.NodeID(1) + + a, err := ipp.IPForDomain(from, "example.com") + if err != nil { + t.Fatal(err) + } + if !pfx.Contains(a) { + t.Fatalf("expected %v to be in the prefix %v", a, pfx) + } + + b, err := ipp.IPForDomain(from, "a.example.com") + if err != nil { + t.Fatal(err) + } + if !pfx.Contains(b) { + t.Fatalf("expected %v to be in the prefix %v", b, pfx) + } + if b == a { + t.Fatalf("same address issued twice %v, %v", a, b) + } + + c, err := ipp.IPForDomain(from, "example.com") + if err != nil { + t.Fatal(err) + } + if c != a { + t.Fatalf("expected %v to be remembered as the addr for example.com, but got %v", a, c) + } +} + +func TestConsensusPoolExhaustion(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/31")) + from := tailcfg.NodeID(1) + + subdomains := []string{"a", "b", "c"} + for i, sd := range subdomains { + _, err := ipp.IPForDomain(from, fmt.Sprintf("%s.example.com", sd)) + if i < 2 && err != nil { + t.Fatal(err) + } + expected := "ip pool exhausted" + if i == 2 && err.Error() != expected { + t.Fatalf("expected error to be '%s', got '%s'", expected, err.Error()) + } + } +} + +func TestConsensusPoolExpiry(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/31")) + firstIP := netip.MustParseAddr("100.64.0.0") + secondIP := netip.MustParseAddr("100.64.0.1") + timeOfUse := time.Now() + beforeTimeOfUse := timeOfUse.Add(-1 * time.Hour) + afterTimeOfUse := timeOfUse.Add(1 * time.Hour) + from := tailcfg.NodeID(1) + + // 
the pool is unused, we get an address, and it's marked as being used at timeOfUse + aAddr, err := ipp.applyCheckoutAddr(from, "a.example.com", time.Time{}, timeOfUse) + if err != nil { + t.Fatal(err) + } + if aAddr.Compare(firstIP) != 0 { + t.Fatalf("expected %s, got %s", firstIP, aAddr) + } + ww, ok := ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.Domain != "a.example.com" { + t.Fatalf("expected aAddr to look up to a.example.com, got: %s", ww.Domain) + } + + // the time before which we will reuse addresses is prior to timeOfUse, so no reuse + bAddr, err := ipp.applyCheckoutAddr(from, "b.example.com", beforeTimeOfUse, timeOfUse) + if err != nil { + t.Fatal(err) + } + if bAddr.Compare(secondIP) != 0 { + t.Fatalf("expected %s, got %s", secondIP, bAddr) + } + + // the time before which we will reuse addresses is after timeOfUse, so reuse addresses that were marked as used at timeOfUse. + cAddr, err := ipp.applyCheckoutAddr(from, "c.example.com", afterTimeOfUse, timeOfUse) + if err != nil { + t.Fatal(err) + } + if cAddr.Compare(firstIP) != 0 { + t.Fatalf("expected %s, got %s", firstIP, cAddr) + } + ww, ok = ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.Domain != "c.example.com" { + t.Fatalf("expected firstIP to look up to c.example.com, got: %s", ww.Domain) + } + + // the addr remains associated with c.example.com + cAddrAgain, err := ipp.applyCheckoutAddr(from, "c.example.com", afterTimeOfUse, timeOfUse) + if err != nil { + t.Fatal(err) + } + if cAddrAgain.Compare(cAddr) != 0 { + t.Fatalf("expected cAddrAgain to be cAddr, but they are different. cAddrAgain=%s cAddr=%s", cAddrAgain, cAddr) + } + ww, ok = ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.Domain != "c.example.com" { + t.Fatalf("expected firstIP to look up to c.example.com, got: %s", ww.Domain) + } +} + +func TestConsensusPoolApplyMarkLastUsed(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/31")) + firstIP := netip.MustParseAddr("100.64.0.0") + time1 := time.Now() + time2 := time1.Add(1 * time.Hour) + from := tailcfg.NodeID(1) + domain := "example.com" + + aAddr, err := ipp.applyCheckoutAddr(from, domain, time.Time{}, time1) + if err != nil { + t.Fatal(err) + } + if aAddr.Compare(firstIP) != 0 { + t.Fatalf("expected %s, got %s", firstIP, aAddr) + } + // example.com LastUsed is now time1 + ww, ok := ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.LastUsed != time1 { + t.Fatalf("expected %s, got %s", time1, ww.LastUsed) + } + if ww.Domain != domain { + t.Fatalf("expected %s, got %s", domain, ww.Domain) + } + + err = ipp.applyMarkLastUsed(from, firstIP, domain, time2) + if err != nil { + t.Fatal(err) + } + + // example.com LastUsed is now time2 + ww, ok = ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.LastUsed != time2 { + t.Fatalf("expected %s, got %s", time2, ww.LastUsed) + } + if ww.Domain != domain { + t.Fatalf("expected %s, got %s", domain, ww.Domain) + } +} + +func TestConsensusDomainForIP(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/16")) + from := tailcfg.NodeID(1) + domain := "example.com" + now := time.Now() + + d, ok := ipp.DomainForIP(from, netip.MustParseAddr("100.64.0.1"), now) + if d != "" { + t.Fatalf("expected an empty string if the addr is not found but got %s", d) + } + if ok { + t.Fatalf("expected domain to not be 
found for IP, as it has never been looked up") + } + a, err := ipp.IPForDomain(from, domain) + if err != nil { + t.Fatal(err) + } + d2, ok := ipp.DomainForIP(from, a, now) + if d2 != domain { + t.Fatalf("expected %s but got %s", domain, d2) + } + if !ok { + t.Fatalf("expected domain to be found for IP that was handed out for it") + } +} + +func TestConsensusReadDomainForIP(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/16")) + from := tailcfg.NodeID(1) + domain := "example.com" + + d, err := ipp.readDomainForIP(from, netip.MustParseAddr("100.64.0.1")) + if err != nil { + t.Fatal(err) + } + if d != "" { + t.Fatalf("expected an empty string if the addr is not found but got %s", d) + } + a, err := ipp.IPForDomain(from, domain) + if err != nil { + t.Fatal(err) + } + d2, err := ipp.readDomainForIP(from, a) + if err != nil { + t.Fatal(err) + } + if d2 != domain { + t.Fatalf("expected %s but got %s", domain, d2) + } +} + +func TestConsensusSnapshot(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.0.0/16") + ipp := makePool(pfx) + domain := "example.com" + expectedAddr := netip.MustParseAddr("100.64.0.0") + expectedFrom := expectedAddr + expectedTo := netip.MustParseAddr("100.64.255.255") + from := tailcfg.NodeID(1) + + // pool allocates first addr for from + if _, err := ipp.IPForDomain(from, domain); err != nil { + t.Fatal(err) + } + // take a snapshot + fsmSnap, err := ipp.Snapshot() + if err != nil { + t.Fatal(err) + } + snap := fsmSnap.(fsmSnapshot) + + // verify snapshot state matches the state we know ipp will have + // ipset matches ipp.IPSet + if len(snap.IPSet.Ranges) != 1 { + t.Fatalf("expected 1, got %d", len(snap.IPSet.Ranges)) + } + if snap.IPSet.Ranges[0].From != expectedFrom { + t.Fatalf("want %s, got %s", expectedFrom, snap.IPSet.Ranges[0].From) + } + if snap.IPSet.Ranges[0].To != expectedTo { + t.Fatalf("want %s, got %s", expectedTo, snap.IPSet.Ranges[0].To) + } + + // perPeerMap has one entry, for from + if len(snap.PerPeerMap) != 1 { + t.Fatalf("expected 1, got %d", len(snap.PerPeerMap)) + } + ps := snap.PerPeerMap[from] + + // the one peer state has allocated one address, the first in the prefix + if len(ps.DomainToAddr) != 1 { + t.Fatalf("expected 1, got %d", len(ps.DomainToAddr)) + } + addr := ps.DomainToAddr[domain] + if addr != expectedAddr { + t.Fatalf("want %s, got %s", expectedAddr.String(), addr.String()) + } + if len(ps.AddrToDomain) != 1 { + t.Fatalf("expected 1, got %d", len(ps.AddrToDomain)) + } + ww := ps.AddrToDomain[addr] + if ww.Domain != domain { + t.Fatalf("want %s, got %s", domain, ww.Domain) + } +} + +func TestConsensusRestore(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.0.0/16") + ipp := makePool(pfx) + domain := "example.com" + expectedAddr := netip.MustParseAddr("100.64.0.0") + from := tailcfg.NodeID(1) + + if _, err := ipp.IPForDomain(from, domain); err != nil { + t.Fatal(err) + } + // take the snapshot after only 1 addr allocated + fsmSnap, err := ipp.Snapshot() + if err != nil { + t.Fatal(err) + } + snap := fsmSnap.(fsmSnapshot) + + if _, err := ipp.IPForDomain(from, "b.example.com"); err != nil { + t.Fatal(err) + } + if _, err := ipp.IPForDomain(from, "c.example.com"); err != nil { + t.Fatal(err) + } + if _, err := ipp.IPForDomain(from, "d.example.com"); err != nil { + t.Fatal(err) + } + // ipp now has 4 entries in domainToAddr + ps, _ := ipp.perPeerMap.Load(from) + if len(ps.domainToAddr) != 4 { + t.Fatalf("want 4, got %d", len(ps.domainToAddr)) + } + + // restore the snapshot + bs, err := json.Marshal(snap) + if 
err != nil { + t.Fatal(err) + } + err = ipp.Restore(io.NopCloser(bytes.NewBuffer(bs))) + if err != nil { + t.Fatal(err) + } + + // everything should be as it was when the snapshot was taken + if ipp.perPeerMap.Len() != 1 { + t.Fatalf("want 1, got %d", ipp.perPeerMap.Len()) + } + psAfter, _ := ipp.perPeerMap.Load(from) + if len(psAfter.domainToAddr) != 1 { + t.Fatalf("want 1, got %d", len(psAfter.domainToAddr)) + } + if psAfter.domainToAddr[domain] != expectedAddr { + t.Fatalf("want %s, got %s", expectedAddr, psAfter.domainToAddr[domain]) + } + ww, _ := psAfter.addrToDomain.Load(expectedAddr) + if ww.Domain != domain { + t.Fatalf("want %s, got %s", domain, ww.Domain) + } +} + +func TestConsensusIsCloseToExpiry(t *testing.T) { + a := time.Now() + b := a.Add(5 * time.Second) + if !isCloseToExpiry(a, b, 8*time.Second) { + t.Fatal("times are not within half the lifetime, expected true") + } + if isCloseToExpiry(a, b, 12*time.Second) { + t.Fatal("times are within half the lifetime, expected false") + } +} diff --git a/cmd/natc/ippool/consensusippoolserialize.go b/cmd/natc/ippool/consensusippoolserialize.go new file mode 100644 index 000000000..97dc02f2c --- /dev/null +++ b/cmd/natc/ippool/consensusippoolserialize.go @@ -0,0 +1,164 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "encoding/json" + "io" + "log" + "maps" + "net/netip" + + "github.com/hashicorp/raft" + "go4.org/netipx" + "tailscale.com/syncs" + "tailscale.com/tailcfg" +) + +// Snapshot and Restore enable the raft lib to do log compaction. +// https://pkg.go.dev/github.com/hashicorp/raft#FSM + +// Snapshot is part of the raft.FSM interface. +// According to the docs it: +// - should return quickly +// - will not be called concurrently with Apply +// - the snapshot returned will have Persist called on it concurrently with Apply +// (so it should not contain pointers to the original data that's being mutated) +func (ipp *ConsensusIPPool) Snapshot() (raft.FSMSnapshot, error) { + // everything is safe for concurrent reads and this is not called concurrently with Apply which is + // the only thing that writes, so we do not need to lock + return ipp.getPersistable(), nil +} + +type persistableIPSet struct { + Ranges []persistableIPRange +} + +func getPersistableIPSet(i *netipx.IPSet) persistableIPSet { + rs := []persistableIPRange{} + for _, r := range i.Ranges() { + rs = append(rs, getPersistableIPRange(r)) + } + return persistableIPSet{Ranges: rs} +} + +func (mips *persistableIPSet) toIPSet() (*netipx.IPSet, error) { + b := netipx.IPSetBuilder{} + for _, r := range mips.Ranges { + b.AddRange(r.toIPRange()) + } + return b.IPSet() +} + +type persistableIPRange struct { + From netip.Addr + To netip.Addr +} + +func getPersistableIPRange(r netipx.IPRange) persistableIPRange { + return persistableIPRange{ + From: r.From(), + To: r.To(), + } +} + +func (mipr *persistableIPRange) toIPRange() netipx.IPRange { + return netipx.IPRangeFrom(mipr.From, mipr.To) +} + +// Restore is part of the raft.FSM interface. 
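+// It replaces the pool's state with the state decoded from the snapshot.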
+// According to the docs it: +// - will not be called concurrently with any other command +// - the FSM must discard all previous state before restoring +func (ipp *ConsensusIPPool) Restore(rc io.ReadCloser) error { + var snap fsmSnapshot + if err := json.NewDecoder(rc).Decode(&snap); err != nil { + return err + } + ipset, ppm, err := snap.getData() + if err != nil { + return err + } + ipp.IPSet = ipset + ipp.perPeerMap = ppm + return nil +} + +type fsmSnapshot struct { + IPSet persistableIPSet + PerPeerMap map[tailcfg.NodeID]persistablePPS +} + +// Persist is part of the raft.FSMSnapshot interface +// According to the docs Persist may be called concurrently with Apply +func (f fsmSnapshot) Persist(sink raft.SnapshotSink) error { + if err := json.NewEncoder(sink).Encode(f); err != nil { + log.Printf("Error encoding snapshot as JSON: %v", err) + return sink.Cancel() + } + return sink.Close() +} + +// Release is part of the raft.FSMSnapshot interface +func (f fsmSnapshot) Release() {} + +// getPersistable returns an object that: +// - contains all the data in ConsensusIPPool +// - doesn't share any pointers with it +// - can be marshalled to JSON +// +// part of the raft snapshotting, getPersistable will be called during Snapshot +// and the results used during persist (concurrently with Apply) +func (ipp *ConsensusIPPool) getPersistable() fsmSnapshot { + ppm := map[tailcfg.NodeID]persistablePPS{} + for k, v := range ipp.perPeerMap.All() { + ppm[k] = v.getPersistable() + } + return fsmSnapshot{ + IPSet: getPersistableIPSet(ipp.IPSet), + PerPeerMap: ppm, + } +} + +func (f fsmSnapshot) getData() (*netipx.IPSet, *syncs.Map[tailcfg.NodeID, *consensusPerPeerState], error) { + ppm := syncs.Map[tailcfg.NodeID, *consensusPerPeerState]{} + for k, v := range f.PerPeerMap { + ppm.Store(k, v.toPerPeerState()) + } + ipset, err := f.IPSet.toIPSet() + if err != nil { + return nil, nil, err + } + return ipset, &ppm, nil +} + +// getPersistable returns an object that: +// - contains all the data in consensusPerPeerState +// - doesn't share any pointers with it +// - can be marshalled to JSON +// +// part of the raft snapshotting, getPersistable will be called during Snapshot +// and the results used during persist (concurrently with Apply) +func (ps *consensusPerPeerState) getPersistable() persistablePPS { + return persistablePPS{ + AddrToDomain: maps.Collect(ps.addrToDomain.All()), + DomainToAddr: maps.Clone(ps.domainToAddr), + } +} + +type persistablePPS struct { + DomainToAddr map[string]netip.Addr + AddrToDomain map[netip.Addr]whereWhen +} + +func (p persistablePPS) toPerPeerState() *consensusPerPeerState { + atd := &syncs.Map[netip.Addr, whereWhen]{} + for k, v := range p.AddrToDomain { + atd.Store(k, v) + } + return &consensusPerPeerState{ + domainToAddr: p.DomainToAddr, + addrToDomain: atd, + } +} diff --git a/cmd/natc/ippool/ippool.go b/cmd/natc/ippool/ippool.go index 3a46a6e7a..5a2dcbec9 100644 --- a/cmd/natc/ippool/ippool.go +++ b/cmd/natc/ippool/ippool.go @@ -10,6 +10,7 @@ import ( "math/big" "net/netip" "sync" + "time" "github.com/gaissmai/bart" "go4.org/netipx" @@ -21,12 +22,26 @@ import ( var ErrNoIPsAvailable = errors.New("no IPs available") -type IPPool struct { +// IPPool allocates IPv4 addresses from a pool to DNS domains, on a per tailcfg.NodeID basis. +// For each tailcfg.NodeID, IPv4 addresses are associated with at most one DNS domain. +// Addresses may be reused across other tailcfg.NodeID's for the same or other domains. 
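+//
+// Two implementations are provided: SingleMachineIPPool keeps all state in memory on
+// a single node, and ConsensusIPPool replicates state across a cluster of nodes with
+// raft.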
+type IPPool interface { + // DomainForIP looks up the domain associated with a tailcfg.NodeID and netip.Addr pair. + // If there is no association, the result is empty and ok is false. + DomainForIP(tailcfg.NodeID, netip.Addr, time.Time) (string, bool) + + // IPForDomain looks up or creates an IP address allocation for the tailcfg.NodeID and domain pair. + // If no address association is found, one is allocated from the range of free addresses for this tailcfg.NodeID. + // If no more address are available, an error is returned. + IPForDomain(tailcfg.NodeID, string) (netip.Addr, error) +} + +type SingleMachineIPPool struct { perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] IPSet *netipx.IPSet } -func (ipp *IPPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr) (string, bool) { +func (ipp *SingleMachineIPPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr, _ time.Time) (string, bool) { ps, ok := ipp.perPeerMap.Load(from) if !ok { log.Printf("handleTCPFlow: no perPeerState for %v", from) @@ -40,7 +55,7 @@ func (ipp *IPPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr) (string, bo return domain, ok } -func (ipp *IPPool) IPForDomain(from tailcfg.NodeID, domain string) (netip.Addr, error) { +func (ipp *SingleMachineIPPool) IPForDomain(from tailcfg.NodeID, domain string) (netip.Addr, error) { npps := &perPeerState{ ipset: ipp.IPSet, } diff --git a/cmd/natc/ippool/ippool_test.go b/cmd/natc/ippool/ippool_test.go index 2919d7757..8d474f86a 100644 --- a/cmd/natc/ippool/ippool_test.go +++ b/cmd/natc/ippool/ippool_test.go @@ -8,6 +8,7 @@ import ( "fmt" "net/netip" "testing" + "time" "go4.org/netipx" "tailscale.com/tailcfg" @@ -19,7 +20,7 @@ func TestIPPoolExhaustion(t *testing.T) { var ipsb netipx.IPSetBuilder ipsb.AddPrefix(smallPrefix) addrPool := must.Get(ipsb.IPSet()) - pool := IPPool{IPSet: addrPool} + pool := SingleMachineIPPool{IPSet: addrPool} assignedIPs := make(map[netip.Addr]string) @@ -68,7 +69,7 @@ func TestIPPool(t *testing.T) { var ipsb netipx.IPSetBuilder ipsb.AddPrefix(netip.MustParsePrefix("100.64.1.0/24")) addrPool := must.Get(ipsb.IPSet()) - pool := IPPool{ + pool := SingleMachineIPPool{ IPSet: addrPool, } from := tailcfg.NodeID(12345) @@ -89,7 +90,7 @@ func TestIPPool(t *testing.T) { t.Errorf("IPv4 address %s not in range %s", addr, addrPool) } - domain, ok := pool.DomainForIP(from, addr) + domain, ok := pool.DomainForIP(from, addr, time.Now()) if !ok { t.Errorf("domainForIP(%s) not found", addr) } else if domain != "example.com" { diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index b327f55bd..2dcdc551f 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -57,6 +57,8 @@ func main() { printULA = fs.Bool("print-ula", false, "print the ULA prefix and exit") ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore") wgPort = fs.Uint("wg-port", 0, "udp port for wireguard and peer to peer traffic") + clusterTag = fs.String("cluster-tag", "", "optionally run in a consensus cluster with other nodes with this tag") + server = fs.String("login-server", ipn.DefaultControlURL, "the base URL of control server") ) ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_NATC")) @@ -94,6 +96,7 @@ func main() { ts := &tsnet.Server{ Hostname: *hostname, } + ts.ControlURL = *server if *wgPort != 0 { if *wgPort >= 1<<16 { log.Fatalf("wg-port must be in the range [0, 65535]") @@ -148,12 +151,31 @@ func main() { routes, dnsAddr, addrPool := calculateAddresses(prefixes) v6ULA := ula(uint16(*siteID)) + + var ipp ippool.IPPool + if *clusterTag != "" 
{ + cipp := ippool.NewConsensusIPPool(addrPool) + err = cipp.StartConsensus(ctx, ts, *clusterTag) + if err != nil { + log.Fatalf("StartConsensus: %v", err) + } + defer func() { + err := cipp.StopConsensus(ctx) + if err != nil { + log.Printf("Error stopping consensus: %v", err) + } + }() + ipp = cipp + } else { + ipp = &ippool.SingleMachineIPPool{IPSet: addrPool} + } + c := &connector{ ts: ts, whois: lc, v6ULA: v6ULA, ignoreDsts: ignoreDstTable, - ipPool: &ippool.IPPool{IPSet: addrPool}, + ipPool: ipp, routes: routes, dnsAddr: dnsAddr, resolver: net.DefaultResolver, @@ -209,7 +231,7 @@ type connector struct { ignoreDsts *bart.Table[bool] // ipPool contains the per-peer IPv4 address assignments. - ipPool *ippool.IPPool + ipPool ippool.IPPool // resolver is used to lookup IP addresses for DNS queries. resolver lookupNetIPer @@ -453,7 +475,7 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con if dstAddr.Is6() { dstAddr = v4ForV6(dstAddr) } - domain, ok := c.ipPool.DomainForIP(who.Node.ID, dstAddr) + domain, ok := c.ipPool.DomainForIP(who.Node.ID, dstAddr, time.Now()) if !ok { return nil, false } diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go index 0320db8a4..78dec86fd 100644 --- a/cmd/natc/natc_test.go +++ b/cmd/natc/natc_test.go @@ -270,7 +270,7 @@ func TestDNSResponse(t *testing.T) { ignoreDsts: &bart.Table[bool]{}, routes: routes, v6ULA: v6ULA, - ipPool: &ippool.IPPool{IPSet: addrPool}, + ipPool: &ippool.SingleMachineIPPool{IPSet: addrPool}, dnsAddr: dnsAddr, } c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true)