tailscale/util/ctxlock/state_test.go
Nick Khyl e744ea41c9
util/ctxlock: enforce mutex lock ordering defined by its rank
Updates #12614

Signed-off-by: Nick Khyl <nickk@tailscale.com>
2025-05-04 23:15:41 -05:00

413 lines
10 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ctxlock
import (
"context"
"fmt"
"strings"
"sync"
"testing"
"tailscale.com/util/ctxkey"
)
type stateType interface {
*checked | unchecked
context.Context
unlock()
}
type lockStateType interface{ lockCallers | unchecked }
type impl[T stateType, S lockStateType] struct {
None func() T
FromContext func(context.Context) T
Lock func(T, *mutex[Reentrant, S]) T
LockCtx func(context.Context, *mutex[Reentrant, S]) T
}
var (
checkedImpl = impl[*checked, lockCallers]{
None: func() *checked { return nil },
FromContext: fromContextChecked,
Lock: lockChecked[Reentrant],
LockCtx: func(ctx context.Context, mu *checkedMutex[Reentrant]) *checked {
return lockChecked(fromContextChecked(ctx), mu)
},
}
uncheckedImpl = impl[unchecked, unchecked]{
None: func() unchecked { return unchecked{} },
FromContext: fromContextUnchecked,
Lock: lockUnchecked[Reentrant],
LockCtx: func(ctx context.Context, mu *mutex[Reentrant, unchecked]) unchecked {
return lockUnchecked(fromContextUnchecked(ctx), mu)
},
}
)
// BenchmarkStateLockUnlock benchmarks the performance of locking and unlocking a mutex.
func BenchmarkStateLockUnlock(b *testing.B) {
b.Run("Checked", func(b *testing.B) {
benchmarkStateLockUnlock(b, checkedImpl)
})
b.Run("Unchecked", func(b *testing.B) {
benchmarkStateLockUnlock(b, uncheckedImpl)
})
b.Run("Reference", func(b *testing.B) {
var mu sync.Mutex
for b.Loop() {
mu.Lock()
mu.Unlock()
}
})
}
func benchmarkStateLockUnlock[T stateType, S lockStateType](b *testing.B, impl impl[T, S]) {
var mu mutex[Reentrant, S]
for b.Loop() {
state := impl.Lock(impl.None(), &mu)
state.unlock()
}
}
// BenchmarkReentrance benchmarks the performance of reentrant locking and unlocking.
func BenchmarkReentrance(b *testing.B) {
b.Run("Checked", func(b *testing.B) {
benchmarkReentrance(b, checkedImpl)
})
b.Run("Unchecked", func(b *testing.B) {
benchmarkReentrance(b, uncheckedImpl)
})
b.Run("Reference", func(b *testing.B) {
var mu sync.Mutex
for b.Loop() {
mu.Lock()
func(mu *sync.Mutex) {
if mu.TryLock() {
mu.Unlock()
}
}(&mu)
mu.Unlock()
}
})
}
func benchmarkReentrance[T stateType, S lockStateType](b *testing.B, impl impl[T, S]) {
var mu mutex[Reentrant, S]
for b.Loop() {
parent := impl.Lock(impl.None(), &mu)
func(ctx T) {
child := impl.Lock(ctx, &mu)
child.unlock()
}(parent)
parent.unlock()
}
}
// TestUncheckedAllocFree tests that the exported implementation of [State] does not allocate memory
// when the ts_omit_ctxlock_checks build tag is set.
func TestUncheckedAllocFree(t *testing.T) {
if IsChecked {
t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks)")
}
t.Run("Simple/WithState", func(t *testing.T) {
var mu ReentrantMutex
mustNotAllocate(t, func() {
mu := Lock(None(), &mu)
mu.Unlock()
})
})
t.Run("Simple/WithContext", func(t *testing.T) {
var mu ReentrantMutex
ctx := context.Background()
mustNotAllocate(t, func() {
mu := Lock(ctx, &mu)
mu.Unlock()
})
})
t.Run("Reentrant/WithState", func(t *testing.T) {
var mu ReentrantMutex
mustNotAllocate(t, func() {
parent := Lock(None(), &mu)
func(state State) {
child := Lock(state, &mu)
child.Unlock()
}(parent.State())
parent.Unlock()
})
})
t.Run("Reentrant/WithContext", func(t *testing.T) {
var mu ReentrantMutex
ctx := context.Background()
mustNotAllocate(t, func() {
parent := Lock(ctx, &mu)
func(state State) {
child := Lock(state, &mu)
child.Unlock()
}(parent.State())
parent.Unlock()
})
})
}
func TestHappyPath(t *testing.T) {
t.Run("Checked", func(t *testing.T) {
testHappyPath(t, checkedImpl)
})
t.Run("Unchecked", func(t *testing.T) {
testHappyPath(t, uncheckedImpl)
})
}
func testHappyPath[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) {
var mu mutex[Reentrant, S]
parent := impl.Lock(impl.None(), &mu)
wantLocked(t, &mu) // mu is locked by parent
child := impl.Lock(parent, &mu)
wantLocked(t, &mu) // mu is still locked by parent
var mu2 mutex[Reentrant, S]
ls2 := impl.Lock(child, &mu2)
wantLocked(t, &mu2) // mu2 is locked by ls2
grandchild := impl.Lock(ls2, &mu)
grandchild.unlock() // no-op; mu is owned by parent
wantLocked(t, &mu) // mu is still locked by parent
ls2.unlock() // unlocks mu2
wantUnlocked(t, &mu2) // mu2 is now unlocked
child.unlock() // noop
wantLocked(t, &mu) // mu is still locked by parent
parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
}
func TestContextWrapping(t *testing.T) {
t.Run("Checked", func(t *testing.T) {
testContextWrapping(t, checkedImpl)
})
t.Run("Unchecked", func(t *testing.T) {
testContextWrapping(t, uncheckedImpl)
})
}
func testContextWrapping[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) {
// Create a [context.Context] with a value set in it.
wantValue := "value"
key := ctxkey.New("key", "")
ctxWithValue := key.WithValue(context.Background(), wantValue)
var mu mutex[Reentrant, S]
parent := impl.LockCtx(ctxWithValue, &mu)
wantLocked(t, &mu) // mu is locked by parent
// Let's assume that we want to call a function that takes a [context.Context].
// [State] is a valid [context.Context], so we can pass it to the function.
ctx := context.Context(parent)
// If / when necessary, we can convert it back to a [State].
// The [State] should carry the same lock state as the parent context.
parentDup := impl.FromContext(ctx)
// We can then create and use a child [State].
child := impl.Lock(parentDup, &mu)
// It still carries all the original context values...
if gotValue := key.Value(child); gotValue != wantValue {
t.Errorf("key.Value() = %s; want %s", gotValue, wantValue)
}
// ... and the lock state.
child.unlock() // no-op; mu is owned by parent
wantLocked(t, &mu) // mu is still locked by parent
parentDup.unlock() // no-op; mu is owned by parent
wantLocked(t, &mu) // mu is still locked by parent
parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
}
func TestNilMutex(t *testing.T) {
impl := checkedImpl
wantPanic(t, "nil mutex", func() { impl.Lock(impl.None(), nil) })
}
func TestUseUnlockedParent_Checked(t *testing.T) {
impl := checkedImpl
var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
wantPanic(t, "use after unlock", func() { impl.Lock(parent, &mu) })
}
func TestUnlockParentFirst_Checked(t *testing.T) {
impl := checkedImpl
var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.FromContext(context.Background()), &mu)
child := impl.Lock(parent, &mu)
parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
wantPanic(t, "parent already unlocked", child.unlock)
}
func TestUnlockTwice_Checked(t *testing.T) {
impl := checkedImpl
unlockTwice := func(t *testing.T, ctx *checked) {
ctx.unlock() // unlocks mu
wantPanic(t, "already unlocked", ctx.unlock)
}
t.Run("Wrapped", func(t *testing.T) {
unlockTwice(t, impl.FromContext(context.Background()))
})
t.Run("Locked", func(t *testing.T) {
var mu checkedMutex[Reentrant]
ctx := impl.Lock(impl.None(), &mu)
unlockTwice(t, ctx)
})
t.Run("Child", func(t *testing.T) {
var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
defer parent.unlock()
child := impl.Lock(parent, &mu)
unlockTwice(t, child)
})
t.Run("Grandchild", func(t *testing.T) {
var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
defer parent.unlock()
child := impl.Lock(parent, &mu)
defer child.unlock()
grandchild := impl.Lock(child, &mu)
unlockTwice(t, grandchild)
})
}
func TestUseUnlocked_Checked(t *testing.T) {
impl := checkedImpl
var mu checkedMutex[Reentrant]
state := lockChecked(impl.None(), &mu)
state.unlock()
// All of these should panic since the state is already unlocked.
wantPanic(t, "*", func() { state.Deadline() })
wantPanic(t, "*", func() { state.Done() })
wantPanic(t, "*", func() { state.Err() })
wantPanic(t, "*", func() { state.unlock() })
wantPanic(t, "*", func() { state.Value("key") })
}
func TestUseZeroState(t *testing.T) {
t.Run("Checked", func(t *testing.T) {
testUseEmptyState(t, checkedImpl.None)
})
t.Run("Unchecked", func(t *testing.T) {
testUseEmptyState(t, uncheckedImpl.None)
})
}
func TestUseWrappedBackground(t *testing.T) {
t.Run("Checked", func(t *testing.T) {
testUseEmptyState(t, getWrappedBackground(t, checkedImpl))
})
t.Run("Unchecked", func(t *testing.T) {
testUseEmptyState(t, getWrappedBackground(t, uncheckedImpl))
})
}
func getWrappedBackground[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) func() T {
t.Helper()
return func() T {
return impl.FromContext(context.Background())
}
}
func testUseEmptyState[T stateType](t *testing.T, getState func() T) {
// Using an empty [State] must not panic or deadlock.
// It should also behave like [context.Background].
for range 2 {
state := getState()
if gotDone := state.Done(); gotDone != nil {
t.Errorf("ctx.Done() = %v; want nil", gotDone)
}
if gotDeadline, ok := state.Deadline(); ok {
t.Errorf("ctx.Deadline() = %v; want !ok", gotDeadline)
}
if gotErr := state.Err(); gotErr != nil {
t.Errorf("ctx.Err() = %v; want nil", gotErr)
}
if gotValue := state.Value("test-key"); gotValue != nil {
t.Errorf("ctx.Value(test-key) = %v; want nil", gotValue)
}
state.unlock()
}
}
func wantPanic(t *testing.T, wantMsg string, fn func()) {
t.Helper()
defer func() {
if r := recover(); wantMsg != "*" {
if gotMsg := trimPanicMessage(r); gotMsg != wantMsg {
t.Errorf("panic: got %q; want %q", r, wantMsg)
}
}
}()
fn()
t.Fatal("failed to panic")
}
func (m *mutex[R, S]) isLockedForTest() bool {
if m.m.TryLock() {
m.m.Unlock()
return false
}
return true
}
func wantLocked[R Rank, S lockStateType](t *testing.T, m *mutex[R, S]) {
t.Helper()
if !m.isLockedForTest() {
t.Fatal("mutex is not locked")
}
}
func wantUnlocked[R Rank, S lockStateType](t *testing.T, m *mutex[R, S]) {
t.Helper()
if m.isLockedForTest() {
t.Fatal("mutex is locked")
}
}
func mustNotAllocate(t *testing.T, steps func()) {
t.Helper()
const runs = 1000
if allocs := testing.AllocsPerRun(runs, steps); allocs != 0 {
t.Errorf("expected 0 allocs, got %f", allocs)
}
}
func trimPanicMessage(r any) string {
msg := fmt.Sprintf("%v", r)
msg = strings.TrimSpace(msg)
if i := strings.IndexByte(msg, '\n'); i >= 0 {
return msg[:i]
}
return msg
}