mirror of
https://github.com/tailscale/tailscale.git
synced 2025-02-18 02:48:40 +00:00
types/lazy: add (*SyncValue[T]).SetForTest method
It is sometimes necessary to change a global lazy.SyncValue for the duration of a test. This PR adds a (*SyncValue[T]).SetForTest method to facilitate that. Updates #12687 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
d500a92926
commit
5d09649b0b
@ -154,3 +154,34 @@ func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) {
|
||||
return v, err
|
||||
}
|
||||
}
|
||||
|
||||
// TB is a subset of testing.TB that we use to set up test helpers.
|
||||
// It's defined here to avoid pulling in the testing package.
|
||||
type TB interface {
|
||||
Helper()
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// SetForTest sets z's value and error.
|
||||
// It's used in tests only and reverts z's state back when tb and all its
|
||||
// subtests complete.
|
||||
// It is not safe for concurrent use and must not be called concurrently with
|
||||
// any SyncValue methods, including another call to itself.
|
||||
func (z *SyncValue[T]) SetForTest(tb TB, val T, err error) {
|
||||
tb.Helper()
|
||||
|
||||
z.once.Do(func() {})
|
||||
oldErr, oldVal := z.err.Load(), z.v
|
||||
|
||||
z.v = val
|
||||
if err != nil {
|
||||
z.err.Store(ptr.To(err))
|
||||
} else {
|
||||
z.err.Store(nilErrPtr)
|
||||
}
|
||||
|
||||
tb.Cleanup(func() {
|
||||
z.v = oldVal
|
||||
z.err.Store(oldErr)
|
||||
})
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/types/opt"
|
||||
)
|
||||
|
||||
func TestSyncValue(t *testing.T) {
|
||||
@ -147,6 +149,196 @@ func TestSyncValueConcurrent(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSyncValueSetForTest(t *testing.T) {
|
||||
testErr := errors.New("boom")
|
||||
tests := []struct {
|
||||
name string
|
||||
initValue opt.Value[int]
|
||||
initErr opt.Value[error]
|
||||
setForTestValue int
|
||||
setForTestErr error
|
||||
getValue int
|
||||
getErr opt.Value[error]
|
||||
wantValue int
|
||||
wantErr error
|
||||
routines int
|
||||
}{
|
||||
{
|
||||
name: "GetOk",
|
||||
setForTestValue: 42,
|
||||
getValue: 8,
|
||||
wantValue: 42,
|
||||
},
|
||||
{
|
||||
name: "GetOk/WithInit",
|
||||
initValue: opt.ValueOf(4),
|
||||
setForTestValue: 42,
|
||||
getValue: 8,
|
||||
wantValue: 42,
|
||||
},
|
||||
{
|
||||
name: "GetOk/WithInitErr",
|
||||
initValue: opt.ValueOf(4),
|
||||
initErr: opt.ValueOf(errors.New("blast")),
|
||||
setForTestValue: 42,
|
||||
getValue: 8,
|
||||
wantValue: 42,
|
||||
},
|
||||
{
|
||||
name: "GetErr",
|
||||
setForTestValue: 42,
|
||||
setForTestErr: testErr,
|
||||
getValue: 8,
|
||||
getErr: opt.ValueOf(errors.New("ka-boom")),
|
||||
wantValue: 42,
|
||||
wantErr: testErr,
|
||||
},
|
||||
{
|
||||
name: "GetErr/NilError",
|
||||
setForTestValue: 42,
|
||||
setForTestErr: nil,
|
||||
getValue: 8,
|
||||
getErr: opt.ValueOf(errors.New("ka-boom")),
|
||||
wantValue: 42,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "GetErr/WithInitErr",
|
||||
initValue: opt.ValueOf(4),
|
||||
initErr: opt.ValueOf(errors.New("blast")),
|
||||
setForTestValue: 42,
|
||||
setForTestErr: testErr,
|
||||
getValue: 8,
|
||||
getErr: opt.ValueOf(errors.New("ka-boom")),
|
||||
wantValue: 42,
|
||||
wantErr: testErr,
|
||||
},
|
||||
{
|
||||
name: "Concurrent/GetOk",
|
||||
setForTestValue: 42,
|
||||
getValue: 8,
|
||||
wantValue: 42,
|
||||
routines: 10000,
|
||||
},
|
||||
{
|
||||
name: "Concurrent/GetOk/WithInitErr",
|
||||
initValue: opt.ValueOf(4),
|
||||
initErr: opt.ValueOf(errors.New("blast")),
|
||||
setForTestValue: 42,
|
||||
getValue: 8,
|
||||
wantValue: 42,
|
||||
routines: 10000,
|
||||
},
|
||||
{
|
||||
name: "Concurrent/GetErr",
|
||||
setForTestValue: 42,
|
||||
setForTestErr: testErr,
|
||||
getValue: 8,
|
||||
getErr: opt.ValueOf(errors.New("ka-boom")),
|
||||
wantValue: 42,
|
||||
wantErr: testErr,
|
||||
routines: 10000,
|
||||
},
|
||||
{
|
||||
name: "Concurrent/GetErr/WithInitErr",
|
||||
initValue: opt.ValueOf(4),
|
||||
initErr: opt.ValueOf(errors.New("blast")),
|
||||
setForTestValue: 42,
|
||||
setForTestErr: testErr,
|
||||
getValue: 8,
|
||||
getErr: opt.ValueOf(errors.New("ka-boom")),
|
||||
wantValue: 42,
|
||||
wantErr: testErr,
|
||||
routines: 10000,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var v SyncValue[int]
|
||||
|
||||
// Initialize the sync value with the specified value and/or error,
|
||||
// if required by the test.
|
||||
if initValue, ok := tt.initValue.GetOk(); ok {
|
||||
var wantInitErr, gotInitErr error
|
||||
var wantInitValue, gotInitValue int
|
||||
wantInitValue = initValue
|
||||
if initErr, ok := tt.initErr.GetOk(); ok {
|
||||
wantInitErr = initErr
|
||||
gotInitValue, gotInitErr = v.GetErr(func() (int, error) { return initValue, initErr })
|
||||
} else {
|
||||
gotInitValue = v.Get(func() int { return initValue })
|
||||
}
|
||||
|
||||
if gotInitErr != wantInitErr {
|
||||
t.Fatalf("InitErr: got %v; want %v", gotInitErr, wantInitErr)
|
||||
}
|
||||
if gotInitValue != wantInitValue {
|
||||
t.Fatalf("InitValue: got %v; want %v", gotInitValue, wantInitValue)
|
||||
}
|
||||
|
||||
// Verify that SetForTest reverted the error and the value during the test cleanup.
|
||||
t.Cleanup(func() {
|
||||
wantCleanupValue, wantCleanupErr := wantInitValue, wantInitErr
|
||||
gotCleanupValue, gotCleanupErr, ok := v.PeekErr()
|
||||
if !ok {
|
||||
t.Fatal("SyncValue is not set after cleanup")
|
||||
}
|
||||
if gotCleanupErr != wantCleanupErr {
|
||||
t.Fatalf("CleanupErr: got %v; want %v", gotCleanupErr, wantCleanupErr)
|
||||
}
|
||||
if gotCleanupValue != wantCleanupValue {
|
||||
t.Fatalf("CleanupValue: got %v; want %v", gotCleanupValue, wantCleanupValue)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Set the test value and/or error.
|
||||
v.SetForTest(t, tt.setForTestValue, tt.setForTestErr)
|
||||
|
||||
// Verify that the value and/or error have been set.
|
||||
// This will run on either the current goroutine
|
||||
// or concurrently depending on the tt.routines value.
|
||||
checkSyncValue := func() {
|
||||
var gotValue int
|
||||
var gotErr error
|
||||
if getErr, ok := tt.getErr.GetOk(); ok {
|
||||
gotValue, gotErr = v.GetErr(func() (int, error) { return tt.getValue, getErr })
|
||||
} else {
|
||||
gotValue = v.Get(func() int { return tt.getValue })
|
||||
}
|
||||
|
||||
if gotErr != tt.wantErr {
|
||||
t.Errorf("Err: got %v; want %v", gotErr, tt.wantErr)
|
||||
}
|
||||
if gotValue != tt.wantValue {
|
||||
t.Errorf("Value: got %v; want %v", gotValue, tt.wantValue)
|
||||
}
|
||||
}
|
||||
|
||||
switch tt.routines {
|
||||
case 0:
|
||||
checkSyncValue()
|
||||
default:
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(tt.routines)
|
||||
start := make(chan struct{})
|
||||
for range tt.routines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// Every goroutine waits for the go signal, so that more of them
|
||||
// have a chance to race on the initial Get than with sequential
|
||||
// goroutine starts.
|
||||
<-start
|
||||
checkSyncValue()
|
||||
}()
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncFunc(t *testing.T) {
|
||||
f := SyncFunc(fortyTwo)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user