From e548a515cacbe00f349bb30b9a0c74649ba0780a Mon Sep 17 00:00:00 2001
From: Fran Bull
Date: Tue, 29 Apr 2025 11:32:40 -0700
Subject: [PATCH] snapshot and restore

---
 cmd/natc/ippool/consensusippool.go          |  12 --
 cmd/natc/ippool/consensusippool_test.go     | 107 ++++++++++++
 cmd/natc/ippool/consensusippoolserialize.go | 174 ++++++++++++++++++++
 3 files changed, 281 insertions(+), 12 deletions(-)
 create mode 100644 cmd/natc/ippool/consensusippoolserialize.go

diff --git a/cmd/natc/ippool/consensusippool.go b/cmd/natc/ippool/consensusippool.go
index adaadf46c..0b087b970 100644
--- a/cmd/natc/ippool/consensusippool.go
+++ b/cmd/natc/ippool/consensusippool.go
@@ -5,7 +5,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"io"
 	"log"
 	"net/netip"
 	"sync"
@@ -305,17 +304,6 @@ func (ipp *ConsensusIPPool) Apply(l *raft.Log) interface{} {
 	}
 }
 
-// TODO(fran) what exactly would we gain by implementing Snapshot and Restore?
-// Snapshot is part of the raft.FSM interface.
-func (ipp *ConsensusIPPool) Snapshot() (raft.FSMSnapshot, error) {
-	return nil, nil
-}
-
-// Restore is part of the raft.FSM interface.
-func (ipp *ConsensusIPPool) Restore(rc io.ReadCloser) error {
-	return nil
-}
-
 // commandExecutor is an interface covering the routing parts of consensus
 // used to allow a fake in the tests
 type commandExecutor interface {
diff --git a/cmd/natc/ippool/consensusippool_test.go b/cmd/natc/ippool/consensusippool_test.go
index 4beb9855b..fd92536e3 100644
--- a/cmd/natc/ippool/consensusippool_test.go
+++ b/cmd/natc/ippool/consensusippool_test.go
@@ -1,8 +1,10 @@
 package ippool
 
 import (
+	"bytes"
 	"encoding/json"
 	"fmt"
+	"io"
 	"net/netip"
 	"testing"
 	"time"
@@ -183,3 +185,108 @@ func TestConsensusDomainForIP(t *testing.T) {
 		t.Fatalf("expected domain to be found for IP that was handed out for it")
 	}
 }
+
+func TestConsensusSnapshot(t *testing.T) {
+	pfx := netip.MustParsePrefix("100.64.0.0/16")
+	ipp := makePool(pfx)
+	domain := "example.com"
+	expectedAddr := netip.MustParseAddr("100.64.0.0")
+	expectedFrom := expectedAddr
+	expectedTo := netip.MustParseAddr("100.64.255.255")
+	from := tailcfg.NodeID(1)
+
+	// pool allocates first addr for from
+	ipp.IPForDomain(from, domain)
+	// take a snapshot
+	fsmSnap, err := ipp.Snapshot()
+	if err != nil {
+		t.Fatal(err)
+	}
+	snap := fsmSnap.(fsmSnapshot)
+
+	// verify snapshot state matches the state we know ipp will have
+	// ipset matches ipp.IPSet
+	if len(snap.IPSet.Ranges) != 1 {
+		t.Fatalf("expected 1, got %d", len(snap.IPSet.Ranges))
+	}
+	if snap.IPSet.Ranges[0].From != expectedFrom {
+		t.Fatalf("want %s, got %s", expectedFrom, snap.IPSet.Ranges[0].From)
+	}
+	if snap.IPSet.Ranges[0].To != expectedTo {
+		t.Fatalf("want %s, got %s", expectedTo, snap.IPSet.Ranges[0].To)
+	}
+
+	// perPeerMap has one entry, for from
+	if len(snap.PerPeerMap) != 1 {
+		t.Fatalf("expected 1, got %d", len(snap.PerPeerMap))
+	}
+	ps, _ := snap.PerPeerMap[from]
+
+	// the one peer state has allocated one address, the first in the prefix
+	if len(ps.DomainToAddr) != 1 {
+		t.Fatalf("expected 1, got %d", len(ps.DomainToAddr))
+	}
+	addr := ps.DomainToAddr[domain]
+	if addr != expectedAddr {
+		t.Fatalf("want %s, got %s", expectedAddr.String(), addr.String())
+	}
+	if len(ps.AddrToDomain) != 1 {
+		t.Fatalf("expected 1, got %d", len(ps.AddrToDomain))
+	}
+	addrPfx, err := addr.Prefix(32)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ww, _ := ps.AddrToDomain[addrPfx]
+	if ww.Domain != domain {
+		t.Fatalf("want %s, got %s", domain, ww.Domain)
+	}
+}
+
+func TestConsensusRestore(t *testing.T) {
+	pfx := netip.MustParsePrefix("100.64.0.0/16")
+	ipp := makePool(pfx)
+	domain := "example.com"
+	expectedAddr := netip.MustParseAddr("100.64.0.0")
+	from := tailcfg.NodeID(1)
+
+	ipp.IPForDomain(from, domain)
+	// take the snapshot after only 1 addr allocated
+	fsmSnap, err := ipp.Snapshot()
+	if err != nil {
+		t.Fatal(err)
+	}
+	snap := fsmSnap.(fsmSnapshot)
+
+	ipp.IPForDomain(from, "b.example.com")
+	ipp.IPForDomain(from, "c.example.com")
+	ipp.IPForDomain(from, "d.example.com")
+	// ipp now has 4 entries in domainToAddr
+	ps, _ := ipp.perPeerMap.Load(from)
+	if len(ps.domainToAddr) != 4 {
+		t.Fatalf("want 4, got %d", len(ps.domainToAddr))
+	}
+
+	// restore the snapshot
+	bs, err := json.Marshal(snap)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ipp.Restore(io.NopCloser(bytes.NewBuffer(bs)))
+
+	// everything should be as it was when the snapshot was taken
+	if ipp.perPeerMap.Len() != 1 {
+		t.Fatalf("want 1, got %d", ipp.perPeerMap.Len())
+	}
+	psAfter, _ := ipp.perPeerMap.Load(from)
+	if len(psAfter.domainToAddr) != 1 {
+		t.Fatalf("want 1, got %d", len(psAfter.domainToAddr))
+	}
+	if psAfter.domainToAddr[domain] != expectedAddr {
+		t.Fatalf("want %s, got %s", expectedAddr, psAfter.domainToAddr[domain])
+	}
+	ww, _ := psAfter.addrToDomain.Lookup(expectedAddr)
+	if ww.Domain != domain {
+		t.Fatalf("want %s, got %s", domain, ww.Domain)
+	}
+}
diff --git a/cmd/natc/ippool/consensusippoolserialize.go b/cmd/natc/ippool/consensusippoolserialize.go
new file mode 100644
index 000000000..fe9eaecc8
--- /dev/null
+++ b/cmd/natc/ippool/consensusippoolserialize.go
@@ -0,0 +1,174 @@
+package ippool
+
+import (
+	"encoding/json"
+	"io"
+	"net/netip"
+
+	"github.com/gaissmai/bart"
+	"github.com/hashicorp/raft"
+	"go4.org/netipx"
+	"tailscale.com/syncs"
+	"tailscale.com/tailcfg"
+)
+
+// Snapshot and Restore enable the raft lib to do log compaction.
+// https://pkg.go.dev/github.com/hashicorp/raft#FSM
+
+// Snapshot is part of the raft.FSM interface.
+// According to the docs it:
+// - should return quickly
+// - will not be called concurrently with Apply
+// - the snapshot returned will have Persist called on it concurrently with Apply
+//   (so it should not contain pointers to the original data that's being mutated)
+func (ipp *ConsensusIPPool) Snapshot() (raft.FSMSnapshot, error) {
+	// everything is safe for concurrent reads and this is not called concurrently with Apply which is
+	// the only thing that writes, so we do not need to lock
+	return ipp.getPersistable(), nil
+}
+
+type persistableIPSet struct {
+	Ranges []persistableIPRange
+}
+
+func getPersistableIPSet(i *netipx.IPSet) persistableIPSet {
+	rs := []persistableIPRange{}
+	for _, r := range i.Ranges() {
+		rs = append(rs, getPersistableIPRange(r))
+	}
+	return persistableIPSet{Ranges: rs}
+}
+
+func (mips *persistableIPSet) toIPSet() (*netipx.IPSet, error) {
+	b := netipx.IPSetBuilder{}
+	for _, r := range mips.Ranges {
+		b.AddRange(r.toIPRange())
+	}
+	return b.IPSet()
+}
+
+type persistableIPRange struct {
+	From netip.Addr
+	To   netip.Addr
+}
+
+func getPersistableIPRange(r netipx.IPRange) persistableIPRange {
+	return persistableIPRange{
+		From: r.From(),
+		To:   r.To(),
+	}
+}
+
+func (mipr *persistableIPRange) toIPRange() netipx.IPRange {
+	return netipx.IPRangeFrom(mipr.From, mipr.To)
+}
+
+// Restore is part of the raft.FSM interface.
+// According to the docs it:
+// - will not be called concurrently with any other command
+// - the FSM must discard all previous state before restoring
+func (ipp *ConsensusIPPool) Restore(rc io.ReadCloser) error {
+	//IPSet *netipx.IPSet
+	//perPeerMap syncs.Map[tailcfg.NodeID, *consensusPerPeerState]
+	var snap fsmSnapshot
+	if err := json.NewDecoder(rc).Decode(&snap); err != nil {
+		return err
+	}
+	ipset, ppm, err := snap.getData()
+	if err != nil {
+		return err
+	}
+	ipp.IPSet = ipset
+	ipp.perPeerMap = ppm
+	return nil
+}
+
+type fsmSnapshot struct {
+	IPSet      persistableIPSet
+	PerPeerMap map[tailcfg.NodeID]persistablePPS
+}
+
+// Persist is part of the raft.FSMSnapshot interface
+// According to the docs Persist may be called concurrently with Apply
+func (f fsmSnapshot) Persist(sink raft.SnapshotSink) error {
+	b, err := json.Marshal(f)
+	if err != nil {
+		sink.Cancel()
+		return err
+	}
+	if _, err := sink.Write(b); err != nil {
+		sink.Cancel()
+		return err
+	}
+	return sink.Close()
+}
+
+// Release is part of the raft.FSMSnapshot interface
+func (f fsmSnapshot) Release() {}
+
+// getPersistable returns an object that:
+// * contains all the data in ConsensusIPPool
+// * doesn't share any pointers with it
+// * can be marshalled to JSON
+// part of the raft snapshotting, getPersistable will be called during Snapshot
+// and the results used during persist (concurrently with Apply)
+func (ipp *ConsensusIPPool) getPersistable() fsmSnapshot {
+	ppm := map[tailcfg.NodeID]persistablePPS{}
+	for k, v := range ipp.perPeerMap.All() {
+		ppm[k] = v.getPersistable()
+	}
+	return fsmSnapshot{
+		IPSet:      getPersistableIPSet(ipp.IPSet),
+		PerPeerMap: ppm,
+	}
+}
+
+// func (s fsmSnapshot)cippDataFromPersisted(d fsmSnapshot) (*netipx.IPSet, syncs.Map[tailcfg.NodeID, *consensusPerPeerState], error) {
+func (s fsmSnapshot) getData() (*netipx.IPSet, syncs.Map[tailcfg.NodeID, *consensusPerPeerState], error) {
+	ppm := syncs.Map[tailcfg.NodeID, *consensusPerPeerState]{}
+	for k, v := range s.PerPeerMap {
+		ppm.Store(k, v.toPerPeerState())
+	}
+	ipset, err := s.IPSet.toIPSet()
+	if err != nil {
+		return nil, ppm, err
+	}
+	return ipset, ppm, nil
+}
+
+// getPersistable returns an object that:
+// * contains all the data in consensusPerPeerState
+// * doesn't share any pointers with it
+// * can be marshalled to JSON
+// part of the raft snapshotting, getPersistable will be called during Snapshot
+// and the results used during persist (concurrently with Apply)
+func (ps *consensusPerPeerState) getPersistable() persistablePPS {
+	dtaCopy := map[string]netip.Addr{}
+	for k, v := range ps.domainToAddr {
+		dtaCopy[k] = v
+	}
+	atd := map[netip.Prefix]whereWhen{}
+	for pfx, ww := range ps.addrToDomain.All() {
+		atd[pfx] = ww
+	}
+	return persistablePPS{
+		AddrToDomain: atd,
+		DomainToAddr: dtaCopy,
+	}
+}
+
+type persistablePPS struct {
+	DomainToAddr map[string]netip.Addr
+	AddrToDomain map[netip.Prefix]whereWhen
+}
+
+func (p persistablePPS) toPerPeerState() *consensusPerPeerState {
+	atd := &bart.Table[whereWhen]{}
+	for k, v := range p.AddrToDomain {
+		atd.Insert(k, v)
+	}
+	return &consensusPerPeerState{
+		domainToAddr: p.DomainToAddr,
+		addrToDomain: atd,
+	}
+}
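
For context on how these methods get exercised: hashicorp/raft only takes snapshots once the FSM is wired up with a snapshot store and snapshot thresholds, and it feeds the newest snapshot back through Restore at startup or when a follower has fallen too far behind the log. The sketch below is a minimal, hypothetical wiring example and is not how cmd/natc actually configures raft; the package name, startRaft, dataDir, and nodeID are placeholders, and fsm can be any raft.FSM, such as a *ConsensusIPPool.

package example

import (
	"os"
	"time"

	"github.com/hashicorp/raft"
)

// startRaft is a hypothetical helper showing where an FSM that implements
// Snapshot and Restore plugs into raft's log-compaction machinery.
func startRaft(dataDir, nodeID string, fsm raft.FSM) (*raft.Raft, error) {
	cfg := raft.DefaultConfig()
	cfg.LocalID = raft.ServerID(nodeID)
	// How often raft checks whether a snapshot is due, and how many new log
	// entries must have accumulated since the last snapshot to trigger one.
	cfg.SnapshotInterval = 30 * time.Second
	cfg.SnapshotThreshold = 1024

	// Snapshots produced by fsm.Snapshot()/Persist() are written here; raft
	// reads the newest one back and hands it to fsm.Restore() at startup.
	snaps, err := raft.NewFileSnapshotStore(dataDir, 2, os.Stderr)
	if err != nil {
		return nil, err
	}

	// In-memory stores and a loopback transport keep the sketch self-contained;
	// a real deployment would use durable stores and a network transport.
	logs := raft.NewInmemStore()
	stable := raft.NewInmemStore()
	_, transport := raft.NewInmemTransport("")

	return raft.NewRaft(cfg, fsm, logs, stable, snaps, transport)
}

With that wiring, once SnapshotThreshold entries have been applied (checked roughly every SnapshotInterval) raft calls Snapshot, writes the result via Persist, and truncates its log, so a restarted node replays Restore plus only the short tail of uncompacted entries rather than the whole history.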