snapshot and restore

Fran Bull 2025-04-29 11:32:40 -07:00
parent 8be3387199
commit e548a515ca
3 changed files with 281 additions and 12 deletions


@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/netip"
"sync"
@@ -305,17 +304,6 @@ func (ipp *ConsensusIPPool) Apply(l *raft.Log) interface{} {
}
}
// TODO(fran) what exactly would we gain by implementing Snapshot and Restore?
// Snapshot is part of the raft.FSM interface.
func (ipp *ConsensusIPPool) Snapshot() (raft.FSMSnapshot, error) {
return nil, nil
}
// Restore is part of the raft.FSM interface.
func (ipp *ConsensusIPPool) Restore(rc io.ReadCloser) error {
return nil
}
// commandExecutor is an interface covering the routing parts of consensus
// used to allow a fake in the tests
type commandExecutor interface {

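For reference, the contract the deleted stubs were filling in, and which the new file below implements, is hashicorp/raft's FSM interface. A sketch abridged from the raft package (types shown unqualified, as they appear inside that package; not part of this commit):

// Abridged from github.com/hashicorp/raft; not part of this commit.
type FSM interface {
	Apply(*Log) interface{}         // apply a committed log entry to the state machine
	Snapshot() (FSMSnapshot, error) // capture a point-in-time view of the state
	Restore(io.ReadCloser) error    // replace all state from a persisted snapshot
}

type FSMSnapshot interface {
	Persist(sink SnapshotSink) error // write the captured state to the sink
	Release()                        // called when raft is done with the snapshot
}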

@@ -1,8 +1,10 @@
package ippool

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/netip"
	"testing"
	"time"
@@ -183,3 +185,108 @@ func TestConsensusDomainForIP(t *testing.T) {
t.Fatalf("expected domain to be found for IP that was handed out for it")
}
}

func TestConsensusSnapshot(t *testing.T) {
pfx := netip.MustParsePrefix("100.64.0.0/16")
ipp := makePool(pfx)
domain := "example.com"
expectedAddr := netip.MustParseAddr("100.64.0.0")
expectedFrom := expectedAddr
expectedTo := netip.MustParseAddr("100.64.255.255")
from := tailcfg.NodeID(1)
// pool allocates first addr for from
ipp.IPForDomain(from, domain)
// take a snapshot
fsmSnap, err := ipp.Snapshot()
if err != nil {
t.Fatal(err)
}
snap := fsmSnap.(fsmSnapshot)
// verify snapshot state matches the state we know ipp will have
// ipset matches ipp.IPSet
if len(snap.IPSet.Ranges) != 1 {
t.Fatalf("expected 1, got %d", len(snap.IPSet.Ranges))
}
if snap.IPSet.Ranges[0].From != expectedFrom {
t.Fatalf("want %s, got %s", expectedFrom, snap.IPSet.Ranges[0].From)
}
if snap.IPSet.Ranges[0].To != expectedTo {
t.Fatalf("want %s, got %s", expectedTo, snap.IPSet.Ranges[0].To)
}
// perPeerMap has one entry, for from
if len(snap.PerPeerMap) != 1 {
t.Fatalf("expected 1, got %d", len(snap.PerPeerMap))
}
ps, ok := snap.PerPeerMap[from]
if !ok {
t.Fatalf("expected per-peer state for %v", from)
}
// the one peer state has allocated one address, the first in the prefix
if len(ps.DomainToAddr) != 1 {
t.Fatalf("expected 1, got %d", len(ps.DomainToAddr))
}
addr := ps.DomainToAddr[domain]
if addr != expectedAddr {
t.Fatalf("want %s, got %s", expectedAddr.String(), addr.String())
}
if len(ps.AddrToDomain) != 1 {
t.Fatalf("expected 1, got %d", len(ps.AddrToDomain))
}
addrPfx, err := addr.Prefix(32)
if err != nil {
t.Fatal(err)
}
ww, ok := ps.AddrToDomain[addrPfx]
if !ok {
t.Fatalf("expected an AddrToDomain entry for %s", addrPfx)
}
if ww.Domain != domain {
t.Fatalf("want %s, got %s", domain, ww.Domain)
}
}

func TestConsensusRestore(t *testing.T) {
pfx := netip.MustParsePrefix("100.64.0.0/16")
ipp := makePool(pfx)
domain := "example.com"
expectedAddr := netip.MustParseAddr("100.64.0.0")
from := tailcfg.NodeID(1)
ipp.IPForDomain(from, domain)
// take the snapshot after only 1 addr allocated
fsmSnap, err := ipp.Snapshot()
if err != nil {
t.Fatal(err)
}
snap := fsmSnap.(fsmSnapshot)
ipp.IPForDomain(from, "b.example.com")
ipp.IPForDomain(from, "c.example.com")
ipp.IPForDomain(from, "d.example.com")
// ipp now has 4 entries in domainToAddr
ps, ok := ipp.perPeerMap.Load(from)
if !ok {
t.Fatalf("expected per-peer state for %v", from)
}
if len(ps.domainToAddr) != 4 {
t.Fatalf("want 4, got %d", len(ps.domainToAddr))
}
// restore the snapshot
bs, err := json.Marshal(snap)
if err != nil {
t.Fatal(err)
}
if err := ipp.Restore(io.NopCloser(bytes.NewBuffer(bs))); err != nil {
t.Fatal(err)
}
// everything should be as it was when the snapshot was taken
if ipp.perPeerMap.Len() != 1 {
t.Fatalf("want 1, got %d", ipp.perPeerMap.Len())
}
psAfter, ok := ipp.perPeerMap.Load(from)
if !ok {
t.Fatalf("expected per-peer state for %v after restore", from)
}
if len(psAfter.domainToAddr) != 1 {
t.Fatalf("want 1, got %d", len(psAfter.domainToAddr))
}
if psAfter.domainToAddr[domain] != expectedAddr {
t.Fatalf("want %s, got %s", expectedAddr, psAfter.domainToAddr[domain])
}
ww, ok := psAfter.addrToDomain.Lookup(expectedAddr)
if !ok {
t.Fatalf("expected a domain for %s after restore", expectedAddr)
}
if ww.Domain != domain {
t.Fatalf("want %s, got %s", domain, ww.Domain)
}
}
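
The tests above marshal the snapshot to JSON by hand; raft itself drives the same bytes through Persist. A minimal sketch of that full round trip (not part of this commit), using a hypothetical memSink as a stand-in for raft's SnapshotSink, which is just io.WriteCloser plus ID and Cancel:

package ippool

import (
	"bytes"
	"io"
)

// memSink is a hypothetical in-memory raft.SnapshotSink, for exercising
// Persist without a real snapshot store.
type memSink struct{ bytes.Buffer }

func (s *memSink) ID() string    { return "mem" }
func (s *memSink) Cancel() error { return nil }
func (s *memSink) Close() error  { return nil }

// snapshotRoundTrip persists a snapshot of ipp and immediately restores it,
// exercising Snapshot -> Persist -> Restore end to end.
func snapshotRoundTrip(ipp *ConsensusIPPool) error {
	snap, err := ipp.Snapshot()
	if err != nil {
		return err
	}
	sink := &memSink{}
	if err := snap.Persist(sink); err != nil {
		return err
	}
	snap.Release()
	// sink.Buffer now holds the JSON document that Restore expects.
	return ipp.Restore(io.NopCloser(&sink.Buffer))
}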


@@ -0,0 +1,174 @@
package ippool

import (
	"encoding/json"
	"io"
	"net/netip"

	"github.com/gaissmai/bart"
	"github.com/hashicorp/raft"
	"go4.org/netipx"
	"tailscale.com/syncs"
	"tailscale.com/tailcfg"
)

// Snapshot and Restore enable the raft library to do log compaction.
// https://pkg.go.dev/github.com/hashicorp/raft#FSM

// Snapshot is part of the raft.FSM interface.
// According to the docs it:
//   - should return quickly
//   - will not be called concurrently with Apply
//   - returns a snapshot that will have Persist called on it concurrently
//     with Apply, so the snapshot must not share pointers with the original
//     data that is being mutated
func (ipp *ConsensusIPPool) Snapshot() (raft.FSMSnapshot, error) {
	// Everything is safe for concurrent reads, and Snapshot is not called
	// concurrently with Apply, the only writer, so no locking is needed.
	return ipp.getPersistable(), nil
}

type persistableIPSet struct {
	Ranges []persistableIPRange
}

func getPersistableIPSet(i *netipx.IPSet) persistableIPSet {
	rs := []persistableIPRange{}
	for _, r := range i.Ranges() {
		rs = append(rs, getPersistableIPRange(r))
	}
	return persistableIPSet{Ranges: rs}
}

func (mips *persistableIPSet) toIPSet() (*netipx.IPSet, error) {
	b := netipx.IPSetBuilder{}
	for _, r := range mips.Ranges {
		b.AddRange(r.toIPRange())
	}
	return b.IPSet()
}

type persistableIPRange struct {
	From netip.Addr
	To   netip.Addr
}

func getPersistableIPRange(r netipx.IPRange) persistableIPRange {
	return persistableIPRange{
		From: r.From(),
		To:   r.To(),
	}
}

func (mipr *persistableIPRange) toIPRange() netipx.IPRange {
	return netipx.IPRangeFrom(mipr.From, mipr.To)
}

// Restore is part of the raft.FSM interface.
// According to the docs it:
//   - will not be called concurrently with any other command
//   - requires the FSM to discard all previous state before restoring
func (ipp *ConsensusIPPool) Restore(rc io.ReadCloser) error {
	var snap fsmSnapshot
	if err := json.NewDecoder(rc).Decode(&snap); err != nil {
		return err
	}
	ipset, ppm, err := snap.getData()
	if err != nil {
		return err
	}
	// Overwriting IPSet and perPeerMap discards all previous state, as required.
	ipp.IPSet = ipset
	ipp.perPeerMap = ppm
	return nil
}

type fsmSnapshot struct {
	IPSet      persistableIPSet
	PerPeerMap map[tailcfg.NodeID]persistablePPS
}

// Persist is part of the raft.FSMSnapshot interface.
// According to the docs, Persist may be called concurrently with Apply.
func (f fsmSnapshot) Persist(sink raft.SnapshotSink) error {
	b, err := json.Marshal(f)
	if err != nil {
		sink.Cancel()
		return err
	}
	if _, err := sink.Write(b); err != nil {
		sink.Cancel()
		return err
	}
	return sink.Close()
}

// Release is part of the raft.FSMSnapshot interface.
func (f fsmSnapshot) Release() {}

// getPersistable returns an object that:
//   - contains all the data in ConsensusIPPool
//   - doesn't share any pointers with it
//   - can be marshalled to JSON
// As part of raft snapshotting, getPersistable is called during Snapshot and
// the result is used during Persist (concurrently with Apply).
func (ipp *ConsensusIPPool) getPersistable() fsmSnapshot {
	ppm := map[tailcfg.NodeID]persistablePPS{}
	for k, v := range ipp.perPeerMap.All() {
		ppm[k] = v.getPersistable()
	}
	return fsmSnapshot{
		IPSet:      getPersistableIPSet(ipp.IPSet),
		PerPeerMap: ppm,
	}
}

// getData converts the persisted snapshot back into the structures
// ConsensusIPPool uses at runtime.
func (s fsmSnapshot) getData() (*netipx.IPSet, syncs.Map[tailcfg.NodeID, *consensusPerPeerState], error) {
	ppm := syncs.Map[tailcfg.NodeID, *consensusPerPeerState]{}
	for k, v := range s.PerPeerMap {
		ppm.Store(k, v.toPerPeerState())
	}
	ipset, err := s.IPSet.toIPSet()
	if err != nil {
		return nil, ppm, err
	}
	return ipset, ppm, nil
}

// getPersistable returns an object that:
//   - contains all the data in consensusPerPeerState
//   - doesn't share any pointers with it
//   - can be marshalled to JSON
// As part of raft snapshotting, getPersistable is called during Snapshot and
// the result is used during Persist (concurrently with Apply).
func (ps *consensusPerPeerState) getPersistable() persistablePPS {
	dtaCopy := map[string]netip.Addr{}
	for k, v := range ps.domainToAddr {
		dtaCopy[k] = v
	}
	atd := map[netip.Prefix]whereWhen{}
	for pfx, ww := range ps.addrToDomain.All() {
		atd[pfx] = ww
	}
	return persistablePPS{
		AddrToDomain: atd,
		DomainToAddr: dtaCopy,
	}
}

type persistablePPS struct {
	DomainToAddr map[string]netip.Addr
	AddrToDomain map[netip.Prefix]whereWhen
}

func (p persistablePPS) toPerPeerState() *consensusPerPeerState {
	atd := &bart.Table[whereWhen]{}
	for k, v := range p.AddrToDomain {
		atd.Insert(k, v)
	}
	return &consensusPerPeerState{
		domainToAddr: p.DomainToAddr,
		addrToDomain: atd,
	}
}
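
For context, a sketch of how these methods actually get invoked: raft snapshots the FSM on a timer or after enough log growth, persists it via a snapshot store, and then compacts the log. This is not part of this commit; the node ID, directory, retention count, and tuning values are illustrative assumptions.

package main

import (
	"os"
	"time"

	"github.com/hashicorp/raft"
)

func startRaft(fsm raft.FSM) (*raft.Raft, error) {
	config := raft.DefaultConfig()
	config.LocalID = raft.ServerID("node1")    // illustrative node ID
	config.SnapshotInterval = 30 * time.Second // how often raft checks whether to snapshot
	config.SnapshotThreshold = 1024            // outstanding log entries before snapshotting

	// Persist writes snapshots here; keep the two most recent on disk.
	snaps, err := raft.NewFileSnapshotStore("/tmp/raft-snapshots", 2, os.Stderr)
	if err != nil {
		return nil, err
	}

	// In-memory stores keep the sketch short; a real deployment would use
	// durable implementations (e.g. raft-boltdb).
	logs := raft.NewInmemStore()
	stable := raft.NewInmemStore()
	_, transport := raft.NewInmemTransport("")

	return raft.NewRaft(config, fsm, logs, stable, snaps, transport)
}

Since ConsensusIPPool implements raft.FSM, it can be passed directly as the fsm argument here.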