avoid race

This commit is contained in:
Fran Bull 2025-02-24 10:38:21 -08:00
parent 3ed0736ae9
commit f0223a9dba

View File

@ -19,6 +19,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
@ -39,9 +40,12 @@ import (
type fsm struct {
events []map[string]any
count int
mu sync.Mutex
}
func (f *fsm) Apply(l *raft.Log) any {
f.mu.Lock()
defer f.mu.Unlock()
f.count++
f.events = append(f.events, map[string]any{
"type": "Apply",
@ -52,6 +56,12 @@ func (f *fsm) Apply(l *raft.Log) any {
}
}
func (f *fsm) numEvents() int {
f.mu.Lock()
defer f.mu.Unlock()
return len(f.events)
}
func (f *fsm) Snapshot() (raft.FSMSnapshot, error) {
return nil, nil
}
@ -303,7 +313,7 @@ func TestApply(t *testing.T) {
}
fxBothMachinesHaveTheApply := func() bool {
return len(ps[0].sm.events) == 1 && len(ps[1].sm.events) == 1
return ps[0].sm.numEvents() == 1 && ps[1].sm.numEvents() == 1
}
waitFor(t, "the apply event made it into both state machines", fxBothMachinesHaveTheApply, 10, time.Second*1)
}
@ -327,7 +337,7 @@ func assertCommandsWorkOnAnyNode(t *testing.T, participants []*participant) {
expectedEventsLength := i + 1
fxEventsInAll := func() bool {
for _, pOther := range participants {
if len(pOther.sm.events) != expectedEventsLength {
if pOther.sm.numEvents() != expectedEventsLength {
return false
}
}
@ -396,7 +406,7 @@ func TestFollowerFailover(t *testing.T) {
}
fxAllMachinesHaveTheApplies := func() bool {
return len(ps[0].sm.events) == 2 && len(ps[1].sm.events) == 2 && len(smThree.events) == 2
return ps[0].sm.numEvents() == 2 && ps[1].sm.numEvents() == 2 && smThree.numEvents() == 2
}
waitFor(t, "the apply events made it into all state machines", fxAllMachinesHaveTheApplies, 10, time.Second*1)
@ -427,11 +437,11 @@ func TestFollowerFailover(t *testing.T) {
}
defer rThreeAgain.Stop(ctx)
fxThreeGetsCaughtUp := func() bool {
return len(smThreeAgain.events) == 4
return smThreeAgain.numEvents() == 4
}
waitFor(t, "the apply events made it into the third node when it appeared with an empty state machine", fxThreeGetsCaughtUp, 20, time.Second*2)
if len(smThree.events) != 2 {
t.Fatalf("Expected smThree to remain on 2 events: got %d", len(smThree.events))
if smThree.numEvents() != 2 {
t.Fatalf("Expected smThree to remain on 2 events: got %d", smThree.numEvents())
}
}