util/ctxlock: enforce mutex lock ordering defined by its rank

Updates #12614

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2025-05-04 23:02:29 -05:00
parent 64e5da8024
commit e744ea41c9
No known key found for this signature in database
11 changed files with 1034 additions and 531 deletions

134
util/ctxlock/doc.go Normal file
View File

@ -0,0 +1,134 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package ctxlock provides a [Mutex] type and allows to define lock ordering
// and reentrancy rules for mutexes using a [Rank]. It then enforces these
// rules at runtime using a [State] hierarchy.
//
// The package has two implementations: checked and unchecked.
//
// Both implementations support reentrancy and lock ordering,
// but the checked implementation performs additional runtime checks
// and ensures that:
// - a parent [LockHandle] is not unlocked before its child,
// - a [LockHandle] is only unlocked once, and
// - a [State] is not used after being unlocked.
//
// The unchecked implementation skips these checks for improved performance,
// and is enabled in builds with the ts_omit_ctxlock_checks build tag.
//
// Example:
//
// type Resource struct {
// mu Mutex[Reentrant]
// value int
// }
//
// func (r *Resource) GetValue(ctx State) int {
// lock := Lock(ctx, &r.mu)
// defer lock.Unlock()
// return r.value
// }
//
// func (r *Resource) SetValue(ctx State, v int) {
// lock := Lock(ctx, &r.mu)
// defer lock.Unlock()
// r.value = v
// }
//
// func (r *Resource) Foo(ctx State, cb func(State) int) int {
// lock := Lock(ctx, &r.mu)
// defer lock.Unlock()
// return cb(lock.State())
// }
//
// func main() {
// r := Resource{}
// r.SetValue(State{}, 42)
// v := r.Foo(State{}, func(ctx State) int {
// return r.GetValue(ctx)
// })
// fmt.Println(v) // prints 42
// }
package ctxlock
import "context"
// IsChecked indicates whether the checked implementation is used.
const IsChecked = useCheckedImpl
// A Mutex is a potentially reentrant mutual exclusion lock
// with a lock hierarchy and reentrancy rules defined by its [Rank].
// The zero value of a Mutex is valid and represents an unlocked mutex.
//
// The lock state of zero or more mutexes held by a given call chain
// is carried by a [State].
//
// A mutex can be locked using [Lock]. The returned [LockHandle] becomes
// the mutex's owner if the mutex wasn't already held by an ancestor [State].
// It can be used to unlock the mutex or access the lock state hierarchy.
//
// It is a runtime error to lock a mutex if its rank's CheckLockAfter
// reports a conflict with any mutex already held along the call chain.
type Mutex[R Rank] struct {
mutex[R, lockState]
}
// ReentrantMutex is a reentrant [Mutex] with no defined lock hierarchy.
type ReentrantMutex = Mutex[Reentrant]
// State is a [context.Context] that carries the lock state of zero or more mutexes.
//
// Its zero value is valid and represents an unlocked state and an empty context.
type State struct {
stateImpl
}
// None returns a zero [State].
func None() State {
return State{}
}
// FromContext returns a [State] that carries the same lock state
// as the given [context.Context].
//
// It's typically used when [context.Context] already handles
// cancellation or deadlines and can be extended to locking as well.
func FromContext(ctx context.Context) State {
return State{fromContext(ctx)}
}
// Lock locks the specified mutex and becomes its owner, unless it is
// already held by the parent or its ancestor. It returns a [LockHandle]
// that can be used to unlock the mutex or access the modified lock [State].
//
// The parent can be either a [State] or a [context.Context].
// A zero State is a valid parent.
//
// It is a runtime error to pass a nil mutex or to unlock the parent's
// [LockHandle] before the returned one.
func Lock[T context.Context, R Rank](parent T, mu *Mutex[R]) LockHandle {
//return LockHandle{lock(parent, &mu.mutex)}
if parent, ok := any(parent).(State); ok {
return LockHandle{lock(parent.stateImpl, &mu.mutex)}
}
return LockHandle{lock(fromContext(parent), &mu.mutex)}
}
// LockHandle allows releasing a mutex acquired with [Lock]
// and provides access to the lock state hierarchy.
type LockHandle struct {
state stateImpl
}
// State returns the current lock state.
func (h LockHandle) State() State {
return State{h.state}
}
// Unlock releases the mutex owned by the handle, if any.
// It is a runtime error to call Unlock more than once on the same handle,
// or to unlock a [LockHandle] while its associated [State] is still in use.
func (h LockHandle) Unlock() {
h.state.unlock()
}

View File

@ -6,75 +6,130 @@ package ctxlock_test
import (
"context"
"fmt"
"sync"
"strings"
"testing"
"tailscale.com/util/ctxlock"
)
type Resource struct {
mu sync.Mutex
foo, bar string
func ExampleMutex_reentrant() {
var mu ctxlock.ReentrantMutex // shorthand for ctxlock.Mutex[ctxlock.Reentrant]
// The mutex is reentrant, so foo can be called with or without holding the mu.
// If mu is not already held, it will be locked on entry and unlocked on exit.
// The [ctxlock.State] parameter carries the current lock state.
foo := func(ctx ctxlock.State, msg string) {
lock := ctxlock.Lock(ctx, &mu)
defer lock.Unlock()
fmt.Println(msg)
}
// Calling foo without holding the lock.
foo(ctxlock.None(), "no lock")
// Locking the mutex and calling foo again.
lock := ctxlock.Lock(ctxlock.None(), &mu)
foo(lock.State(), "with lock")
defer lock.Unlock()
// Output:
// no lock
// with lock
}
func (r *Resource) GetFoo(ctx ctxlock.State) string {
// Lock the mutex if not already held.
defer ctxlock.Lock(ctx, &r.mu).Unlock()
return r.foo
func ExampleMutex_nonReentrant() {
var mu ctxlock.Mutex[ctxlock.NonReentrant]
// The mutex is non-reentrant, so foo must only be called without holding the mu.
// If mu is already held, it will panic attempting to lock it again.
foo := func(ctx ctxlock.State, msg string) {
defer func() {
if r := recover(); r != nil {
fmt.Println("panic:", trimPanicMessage(r))
}
}()
lock := ctxlock.Lock(ctx, &mu)
defer lock.Unlock()
fmt.Println(msg)
}
// Calling foo without holding the lock.
foo(ctxlock.None(), "no lock")
// Locking the mutex and calling foo again.
// This will panic because the mutex is non-reentrant.
lock := ctxlock.Lock(ctxlock.None(), &mu)
foo(lock.State(), "with lock")
defer lock.Unlock()
// Output:
// no lock
// panic: non-reentrant mutex already locked
}
func (r *Resource) SetFoo(ctx ctxlock.State, foo string) {
// You can do it this way, if you prefer
// or if you need to pass the state to another function.
ctx = ctxlock.Lock(ctx, &r.mu)
defer ctx.Unlock()
r.foo = foo
func ExampleRank() {
var mu1 ctxlock.Mutex[rank1] // cannot be locked after mu2 or mu3
var mu2 ctxlock.Mutex[rank2] // cannot be locked after mu3
var mu3 ctxlock.Mutex[rank3]
lock := ctxlock.Lock(ctxlock.None(), &mu1)
defer lock.Unlock()
fmt.Println("locked mu1")
lock = ctxlock.Lock(lock.State(), &mu2)
defer lock.Unlock()
fmt.Println("locked mu2")
lock = ctxlock.Lock(lock.State(), &mu3)
defer lock.Unlock()
fmt.Println("locked mu3")
// Output:
// locked mu1
// locked mu2
// locked mu3
}
func (r *Resource) GetBar(ctx ctxlock.State) string {
defer ctxlock.Lock(ctx, &r.mu).Unlock()
return r.bar
func ExampleRank_lockOrderViolation() {
var mu1 ctxlock.Mutex[rank1] // cannot be locked after mu2 or mu3
var mu2 ctxlock.Mutex[rank2] // cannot be locked after mu3
var mu3 ctxlock.Mutex[rank3]
defer func() {
if r := recover(); r != nil {
fmt.Println("panic:", trimPanicMessage(r))
}
}()
// While we can lock mu2 first...
lock := ctxlock.Lock(ctxlock.None(), &mu2)
defer lock.Unlock()
fmt.Println("locked mu2")
// ...and then mu3...
lock = ctxlock.Lock(lock.State(), &mu3)
defer lock.Unlock()
fmt.Println("locked mu3")
// It is a lock order violation to lock mu1
// after either mu2 or mu3.
lock = ctxlock.Lock(lock.State(), &mu1)
defer lock.Unlock()
fmt.Println("locked mu1")
// Output:
// locked mu2
// locked mu3
// panic: cannot lock ctxlock_test.rank1 after ctxlock_test.rank3
}
func (r *Resource) SetBar(ctx ctxlock.State, bar string) {
defer ctxlock.Lock(ctx, &r.mu).Unlock()
r.bar = bar
}
func (r *Resource) WithLock(ctx ctxlock.State, f func(ctx ctxlock.State)) {
// Lock the mutex if not already held, and get a new state.
ctx = ctxlock.Lock(ctx, &r.mu)
defer ctx.Unlock()
f(ctx) // Call the callback with the new lock state.
}
func (r *Resource) HandleRequest(ctx context.Context, foo, bar string, f func(ls ctxlock.State) string) string {
// Same, but with a standard [context.Context] instead of [ctxlock.State].
// [ctxlock.Lock] is generic and works with both without allocating.
// The ctx can be used for cancellation, etc.
mu := ctxlock.Lock(ctx, &r.mu)
defer mu.Unlock()
r.foo = foo
r.bar = bar
return f(mu)
}
func (r *Resource) HandleIntRequest(ctx context.Context, foo, bar string, f func(ls ctxlock.State) int) int {
// Same, but returns an int instead of a string,
// and must not allocate with the unchecked implementation.
mu := ctxlock.Lock(ctx, &r.mu)
defer mu.Unlock()
r.foo = foo
r.bar = bar
return f(mu)
}
func ExampleState() {
func ExampleState_resource() {
var r Resource
r.SetFoo(ctxlock.None(), "foo")
r.SetBar(ctxlock.None(), "bar")
r.WithLock(ctxlock.None(), func(ctx ctxlock.State) {
// This callback is invoked with r's lock held,
// This callback is invoked with r's mutex held,
// and ctx carries the lock state. This means we can safely call
// other methods on r using ctx without causing a deadlock.
r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx))
@ -88,7 +143,7 @@ func ExampleState_twoResources() {
r1.SetFoo(ctxlock.None(), "foo")
r2.SetBar(ctxlock.None(), "bar")
r1.WithLock(ctxlock.None(), func(ctx ctxlock.State) {
// Here, r1's lock is held, but r2's lock is not.
// Here, r1's mutex is held, but r2's mutex is not.
// So r2 will be locked when we call r2.GetBar(ctx).
r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx))
})
@ -96,29 +151,27 @@ func ExampleState_twoResources() {
// Output: foobar
}
func ExampleState_stdContext() {
func ExampleState_withStdContext() {
var r Resource
ctx := context.Background()
result := r.HandleRequest(ctx, "foo", "bar", func(ctx ctxlock.State) string {
// The r's lock is held, and ctx carries the lock state.
// The r's mutex is held, and ctx carries the lock state.
return r.GetFoo(ctx) + r.GetBar(ctx)
})
fmt.Println(result)
// Output: foobar
}
func TestAllocFree(t *testing.T) {
if ctxlock.Checked {
t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks)")
func TestEndToEndAllocFree(t *testing.T) {
if ctxlock.IsChecked {
t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks).")
}
var r Resource
ctx := context.Background()
const runs = 1000
if allocs := testing.AllocsPerRun(runs, func() {
res := r.HandleIntRequest(ctx, "foo", "bar", func(ctx ctxlock.State) int {
// The r's lock is held, and ctx carries the lock state.
const N = 1000
if allocs := testing.AllocsPerRun(N, func() {
res := r.HandleIntRequest(context.Background(), "foo", "bar", func(ctx ctxlock.State) int {
// The r's mutex is held, and ctx carries the lock state.
return len(r.GetFoo(ctx) + r.GetBar(ctx))
})
if res != 6 {
@ -128,3 +181,102 @@ func TestAllocFree(t *testing.T) {
t.Errorf("expected 0 allocs, got %f", allocs)
}
}
type (
rank1 struct{}
rank2 struct{}
rank3 struct{}
)
// CheckLockAfter implements [ctxlock.Rank].
func (r rank1) CheckLockAfter(r2 ctxlock.Rank) error {
switch r2.(type) {
case rank2, rank3:
return fmt.Errorf("cannot lock %T after %T", r, r2)
default:
return nil
}
}
// CheckLockAfter implements [ctxlock.Rank].
func (r rank2) CheckLockAfter(r2 ctxlock.Rank) error {
switch r2.(type) {
case rank2, rank3:
return fmt.Errorf("cannot lock %T after %T", r, r2)
default:
return nil
}
}
// CheckLockAfter implements [ctxlock.Rank].
func (a rank3) CheckLockAfter(b ctxlock.Rank) error {
return nil
}
type Resource struct {
mu ctxlock.ReentrantMutex
foo, bar string
}
func (r *Resource) GetFoo(ctx ctxlock.State) string {
// Lock the mutex if not already held,
// and unlock it when the function returns.
defer ctxlock.Lock(ctx, &r.mu).Unlock()
return r.foo
}
func (r *Resource) SetFoo(ctx ctxlock.State, foo string) {
// You can do it this way, if you prefer.
mu := ctxlock.Lock(ctx, &r.mu)
defer mu.Unlock()
r.foo = foo
}
func (r *Resource) GetBar(ctx ctxlock.State) string {
mu := ctxlock.Lock(ctx, &r.mu)
defer mu.Unlock()
return r.bar
}
func (r *Resource) SetBar(ctx ctxlock.State, bar string) {
mu := ctxlock.Lock(ctx, &r.mu)
defer mu.Unlock()
r.bar = bar
}
func (r *Resource) WithLock(ctx ctxlock.State, f func(ctx ctxlock.State)) {
mu := ctxlock.Lock(ctx, &r.mu)
defer mu.Unlock()
// Call the callback with the new lock state.
f(mu.State())
}
func (r *Resource) HandleRequest(ctx context.Context, foo, bar string, f func(ls ctxlock.State) string) string {
// Same, but with a standard [context.Context] instead of [ctxlock.State].
// [ctxlock.Lock] is generic and works with both without allocating.
// The ctx can be used for cancellation, etc.
mu := ctxlock.Lock(ctx, &r.mu)
defer mu.Unlock()
r.foo = foo
r.bar = bar
return f(mu.State())
}
func (r *Resource) HandleIntRequest(ctx context.Context, foo, bar string, f func(ls ctxlock.State) int) int {
// Same, but returns an int instead of a string.
// It must not allocate with the checked implementation.
mu := ctxlock.Lock(ctx, &r.mu)
defer mu.Unlock()
r.foo = foo
r.bar = bar
return f(mu.State())
}
func trimPanicMessage(r any) string {
msg := fmt.Sprintf("%v", r)
msg = strings.TrimSpace(msg)
if i := strings.IndexByte(msg, '\n'); i >= 0 {
return msg[:i]
}
return msg
}

45
util/ctxlock/mutex.go Normal file
View File

@ -0,0 +1,45 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ctxlock
import (
"sync"
)
// mutex is a wrapper around [sync.Mutex] that associates a [Rank] with the mutex
// and provides storage for an arbitrary value (of type S) to be used by the state
// that owns the lock while it is held. It's exported as [Mutex] in the package API.
type mutex[R Rank, S any] struct {
// r is the rank of the mutex, used to check lock order.
r R
// m is the underlying mutex that provides the locking mechanism.
m sync.Mutex
// lockState is a memory region used by the state that owns the lock while it is held.
// It serves as pre-allocated lockState to avoid (in the [unchecked] case)
// or reduce (in the [checked] case) memory allocations.
lockState S
}
func (m *mutex[R, S]) rank() Rank {
return m.r
}
func (m *mutex[R, S]) lock() {
m.m.Lock()
}
func (m *mutex[R, S]) state() any {
return &m.lockState
}
func (m *mutex[R, S]) unlock() {
m.m.Unlock()
}
// mutexHandle is a subset of the [mutex] methods that are used once the mutex is locked.
type mutexHandle interface {
rank() Rank
state() any
unlock()
}

120
util/ctxlock/mutex_test.go Normal file
View File

@ -0,0 +1,120 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ctxlock
import (
"context"
"fmt"
"testing"
)
func BenchmarkReentrantMutex(b *testing.B) {
b.ReportAllocs()
// Does not allocate with --tags=ts_omit_ctxlock_checks.
b.Run("ctxlock.State", func(b *testing.B) {
var mu ReentrantMutex
for b.Loop() {
reentrantMutexLockUnlock(&mu, None)
}
})
b.Run("context.Context", func(b *testing.B) {
var mu ReentrantMutex
for b.Loop() {
reentrantMutexLockUnlock(&mu, context.Background)
}
})
}
func TestReentrantMutexAllocFree(t *testing.T) {
if IsChecked {
t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks).")
}
const N = 1000
t.Run("ctxlock.State", func(t *testing.T) {
var mu ReentrantMutex
if allocs := testing.AllocsPerRun(N, func() {
reentrantMutexLockUnlock(&mu, None)
}); allocs != 0 {
t.Errorf("expected 0 allocs, got %f", allocs)
}
})
t.Run("context.Context", func(t *testing.T) {
var mu ReentrantMutex
if allocs := testing.AllocsPerRun(N, func() {
reentrantMutexLockUnlock(&mu, context.Background)
}); allocs != 0 {
t.Errorf("expected 0 allocs, got %f", allocs)
}
})
}
func reentrantMutexLockUnlock[T context.Context](mu *ReentrantMutex, rootState func() T) {
parent := Lock(rootState(), mu)
func(state State) {
child := Lock(state, mu)
child.Unlock()
}(parent.State())
parent.Unlock()
}
func TestMutexRank(t *testing.T) {
var m1 mutex1
var m2 mutex2
var m3 mutex3
// Locking m1, m2, and m3 in order is valid.
lock := Lock(None(), &m1)
defer lock.Unlock()
lock = Lock(lock.State(), &m2)
defer lock.Unlock()
lock = Lock(lock.State(), &m3)
defer lock.Unlock()
}
func TestMutexLockOrderViolation(t *testing.T) {
var m1 mutex1
var m2 mutex2
var m3 mutex3
// Locking m2 m3, and then m1 is invalid.
lock := Lock(None(), &m2)
defer lock.Unlock()
lock = Lock(lock.State(), &m3)
defer lock.Unlock()
wantPanic(t, "cannot lock ctxlock.testRank1 after ctxlock.testRank3", func() {
lock := Lock(lock.State(), &m1)
defer lock.Unlock()
})
}
type (
testRank1 struct{}
testRank2 struct{}
testRank3 struct{}
mutex1 = Mutex[testRank1]
mutex2 = Mutex[testRank2]
mutex3 = Mutex[testRank3]
)
func (r testRank1) CheckLockAfter(r2 Rank) error {
switch r2.(type) {
case testRank2, testRank3:
return fmt.Errorf("cannot lock %T after %T", r, r2)
default:
return nil
}
}
func (r testRank2) CheckLockAfter(r2 Rank) error {
switch r2.(type) {
case testRank2, testRank3:
return fmt.Errorf("cannot lock %T after %T", r, r2)
default:
return nil
}
}
func (a testRank3) CheckLockAfter(b Rank) error {
return nil
}

58
util/ctxlock/rank.go Normal file
View File

@ -0,0 +1,58 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ctxlock
// A Rank defines the locking rules for a [Mutex].
//
// Typically, a distinct [Rank] type is defined for each mutex
// that requires specific locking order.
//
// Example:
//
// type (
// fooRank struct{} // fooRank must not be locked after barRank
// barRank struct{}
// )
//
// func (r fooRank) CheckLockAfter(r2 Rank) error {
// switch r2.(type) {
// case barRank:
// return fmt.Errorf("cannot lock %T after %T", r, r2)
// default:
// return nil
// }
// }
//
// func (r barRank) CheckLockAfter(r2 Rank) error {
// return nil // barRank can be locked anytime
// }
//
// type Foo struct {
// mu Mutex[fooRank]
// }
//
// type Bar struct {
// mu Mutex[barRank]
// }
type Rank interface {
// CheckLockAfter returns an error if locking the receiver
// after the given rank would violate lock ordering or reentrancy rules.
CheckLockAfter(Rank) error
}
// Reentrant is a [Rank] that does not enforce any locking order and allows reentrancy.
//
// It is used by a pre-defined [ReentrantMutex] type.
type Reentrant struct {
noRank
}
// NonReentrant is a [Rank] that does not enforce any locking order, but disallows reentrancy.
type NonReentrant struct {
noRank
}
type noRank struct{}
func (noRank) CheckLockAfter(Rank) error { return nil }

View File

@ -1,228 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package ctxlock provides a [context.Context] implementation that carries mutex lock state
// and enables reentrant locking. It offers two implementations: checked and unchecked.
// The checked implementation performs runtime validation to ensure that:
// - a parent context is not unlocked before its child,
// - a context is only unlocked once, and
// - a context is not used after being unlocked.
// The unchecked implementation skips these checks for improved performance.
// It defaults to the checked implementation unless the ts_omit_ctxlock_checks build tag is set.
package ctxlock
// This file contains both the [checked] and [unchecked] implementations of [State].
import (
"context"
"fmt"
"reflect"
"sync"
"time"
)
type ctxKey struct{ *sync.Mutex }
func ctxKeyOf(mu *sync.Mutex) ctxKey {
return ctxKey{mu}
}
// checked is an implementation of [State] that performs runtime checks
// to ensure the correct order of locking and unlocking.
//
// Its zero value and a nil pointer are valid and carry no lock state
// and an empty [context.Context].
type checked struct {
context.Context // nil means an empty context
// mu is the mutex tracked by this state,
// or nil if it wasn't created with [Lock].
mu *sync.Mutex
// parent is an ancestor State associated with the same mutex.
// It may or may not own the lock (the lock could be held by a further ancestor).
// The parent is nil if this State is the root of the hierarchy,
// meaning it owns the lock.
parent *checked
// unlocked is whether [checked.Unlock] was called on this state.
unlocked bool
}
func fromContextChecked(ctx context.Context) *checked {
return &checked{ctx, nil, nil, false}
}
func lockChecked(parent *checked, mu *sync.Mutex) *checked {
panicIfNil(mu)
if parentState, ok := parent.Value(ctxKeyOf(mu)).(*checked); ok {
if appearsUnlocked(mu) {
// The parent is already unlocked, but the mutex is not.
panic(fmt.Sprintf("%T is spuriously unlocked", mu))
}
return &checked{parent, mu, parentState, false}
}
mu.Lock()
return &checked{parent, mu, nil, false}
}
func (c *checked) Deadline() (deadline time.Time, ok bool) {
c.panicIfUnlocked()
if c == nil || c.Context == nil {
return time.Time{}, false
}
return c.Context.Deadline()
}
func (c *checked) Done() <-chan struct{} {
c.panicIfUnlocked()
if c == nil || c.Context == nil {
return nil
}
return c.Context.Done()
}
func (c *checked) Err() error {
c.panicIfUnlocked()
if c == nil || c.Context == nil {
return nil
}
return c.Context.Err()
}
func (c *checked) Value(key any) any {
c.panicIfUnlocked()
if c == nil {
// No-op; zero state.
return nil
}
if key, ok := key.(ctxKey); ok && key.Mutex == c.mu {
// This is the mutex tracked by this state.
return c
}
if c.Context != nil {
// Forward the call to the parent context,
// which may or may not be a [checked] state.
return c.Context.Value(key)
}
return nil
}
func (c *checked) Unlock() {
switch {
case c == nil:
// No-op; zero state.
return
case c.unlocked:
panic("already unlocked")
case c.mu == nil:
// No-op; the state does not track a mutex lock state,
// meaning it was not created with [Lock].
case c.parent == nil:
// The state own the mutex's lock; we must unlock it.
// This triggers a fatal error if the mutex is already unlocked.
c.mu.Unlock()
case c.parent.unlocked:
// The parent state is already unlocked.
// The mutex may or may not be locked;
// something else may have already locked it.
panic("parent already unlocked")
case appearsUnlocked(c.mu):
// The mutex itself is unlocked,
// even though the parent state is still locked.
// It may be unlocked by an ancestor state
// or by something else entirely.
panic("mutex is not locked")
default:
// No-op; a parent or ancestor will handle unlocking.
}
c.unlocked = true // mark this state as unlocked
}
func (c *checked) panicIfUnlocked() {
if c != nil && c.unlocked {
panic("use after unlock")
}
}
func panicIfNil[T comparable](v T) {
if reflect.ValueOf(v).IsNil() {
panic(fmt.Sprintf("nil %T", v))
}
}
// unchecked is an implementation of [State] that trades runtime checks for performance.
//
// Its zero value carries no mutex lock state and an empty [context.Context].
type unchecked struct {
context.Context // nil means an empty context
mu *sync.Mutex // non-nil if owned by this state
}
func fromContextUnchecked(ctx context.Context) unchecked {
return unchecked{ctx, nil}
}
func lockUnchecked(parent unchecked, mu *sync.Mutex) unchecked {
if parent.Value(ctxKeyOf(mu)) == nil {
// There's no ancestor state associated with this mutex,
// so we can lock it.
mu.Lock()
} else {
// The mutex is already locked by a parent/ancestor state.
mu = nil
}
return unchecked{parent.Context, mu}
}
func (c unchecked) Deadline() (deadline time.Time, ok bool) {
if c.Context == nil {
return time.Time{}, false
}
return c.Context.Deadline()
}
func (c unchecked) Done() <-chan struct{} {
if c.Context == nil {
return nil
}
return c.Context.Done()
}
func (c unchecked) Err() error {
if c.Context == nil {
return nil
}
return c.Context.Err()
}
func (c unchecked) Value(key any) any {
if key, ok := key.(ctxKey); ok && key.Mutex == c.mu {
return key
}
if c.Context == nil {
return nil
}
return c.Context.Value(key)
}
func (c unchecked) Unlock() {
if c.mu != nil {
c.mu.Unlock()
}
}
type tryLocker interface {
TryLock() bool
Unlock()
}
// appearsUnlocked reports whether m is unlocked.
// It may return a false negative if m does not have a TryLock method.
func appearsUnlocked[T sync.Locker](m T) bool {
if m, ok := any(m).(tryLocker); ok && m.TryLock() {
m.Unlock()
return true
}
return false
}

View File

@ -1,53 +1,229 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// This file exports default, unoptimized implementation of the [State] that includes runtime checks.
// It is used unless the build tag ts_omit_ctxlock_checks is set.
//go:build !ts_omit_ctxlock_checks
package ctxlock
import (
"context"
"sync"
"errors"
"fmt"
"runtime"
"time"
)
// Checked indicates whether runtime checks are enabled for this package.
const Checked = true
// State carries the lock state of zero or more mutexes and an optional [context.Context].
// Its zero value is valid and represents an unlocked state and an empty context.
// checked is an implementation of [State] with additional runtime checks.
//
// Calling [Lock] returns a derived State with the specified mutex locked. The State is considered
// the owner of the lock if it wasn't already acquired by a parent State. Calling [State.Unlock]
// releases the lock owned by the state. It is a runtime error to call Unlock more than once,
// to use the State after it has been unlocked, or to unlock a parent State before its child.
type State struct {
*checked
// Its zero value and a nil pointer are valid and carry no lock state
// and an empty [context.Context].
type checked struct {
context.Context // nil means an empty context
// mu is the mutex locked (or re-locked) by this state,
// or nil if it wasn't created with [Lock].
mu mutexHandle
// parent is the next state in the hierarchy associated with the same mutex.
// It may or may not own the lock (the lock could be held by a further ancestor).
//
// The parent is nil if this state owns the lock, or if it's a zero state.
parent *checked
// unlocked is whether [checked.Unlock] was called on this state.
unlocked bool
// lockedBy are the program counters of function invocations
// that locked the mutex, or nil if mu is not owned by this state.
lockedBy *lockCallers
}
// None returns a [State] that carries no lock state and an empty [context.Context].
func None() State {
return State{}
type (
lockCallers [5]uintptr
checkedMutex[R Rank] = mutex[R, lockCallers]
)
func fromContextChecked(ctx context.Context) *checked {
return &checked{Context: ctx}
}
// FromContext returns a [State] that carries the same lock state as the provided [context.Context].
//
// It is typically used by methods that already accept a [context.Context] for cancellation or deadline
// management, and would like to use it for locking as well.
func FromContext(ctx context.Context) State {
return State{fromContextChecked(ctx)}
}
// Lock acquires the specified mutex and becomes its owner, unless it is already held by a parent.
// The parent can be either a [State] or a [context.Context]. A zero [State] is a valid parent.
// It returns a new [State] that augments the parent with the additional lock state.
//
// It is a runtime error to pass a nil mutex or to unlock the parent state before the returned one.
func Lock[T context.Context](parent T, mu *sync.Mutex) State {
if parent, ok := any(parent).(State); ok {
return State{lockChecked(parent.checked, mu)}
func lockChecked[R Rank](parent *checked, mu *checkedMutex[R]) *checked {
if mu == nil {
panic("nil mutex")
}
return State{lockChecked(fromContextChecked(parent), mu)}
if parentState, ok := parent.isAlreadyLocked(mu); ok {
return &checked{parent, mu, parentState, false, nil}
}
mu.lock()
runtime.Callers(4, mu.lockState[:])
return &checked{parent, mu, nil, false, nil}
}
func (c *checked) isAlreadyLocked(m mutexHandle) (parent *checked, ok bool) {
switch val := c.Value(m).(type) {
case nil:
// No ancestor state associated with this mutex,
// and locking it does not violate the lock ordering.
return nil, false
case error:
// There's a lock ordering or reentrancy violation.
panic(val)
case *checked:
// The mutex is reentrant and is already held by a parent
// or ancestor state.
return val, true
default:
panic("unreachable")
}
}
func (c *checked) unlock() {
switch {
case c == nil:
// No-op; zero state.
return
case c.unlocked:
panic("already unlocked")
case c.mu == nil:
// No-op; the state does not track a mutex lock state,
// meaning it was not created with [Lock].
case c.parent == nil:
// The state own the mutex's lock; we must unlock it.
// This triggers a fatal error if the mutex is already unlocked.
c.mu.unlock()
case c.parent.unlocked:
// The parent state is already unlocked.
// The mutex may or may not be locked;
// something else may have already locked it.
panic("parent already unlocked")
default:
// No-op; a parent or ancestor will handle unlocking.
}
c.unlocked = true // mark this state as unlocked
}
func (c *checked) panicIfUnlocked() {
if c != nil && c.unlocked {
panic("use after unlock")
}
}
// Deadline implements [context.Context].
func (c *checked) Deadline() (deadline time.Time, ok bool) {
c.panicIfUnlocked()
if c == nil || c.Context == nil {
return time.Time{}, false
}
return c.Context.Deadline()
}
// Done implements [context.Context].
func (c *checked) Done() <-chan struct{} {
c.panicIfUnlocked()
if c == nil || c.Context == nil {
return nil
}
return c.Context.Done()
}
// Err implements [context.Context].
func (c *checked) Err() error {
c.panicIfUnlocked()
if c == nil || c.Context == nil {
return nil
}
return c.Context.Err()
}
// Value implements [context.Context].
func (c *checked) Value(key any) any {
c.panicIfUnlocked()
if c == nil {
// No-op; zero state.
return nil
}
if mu, ok := key.(mutexHandle); ok {
// Checks whether mu can be acquired after c.mu.
if res, done := checkLockOrder(mu, c.mu, c); done {
// We have a definite answer.
switch res := res.(type) {
case error:
// There's a lock ordering or reentrancy violation.
// Enrich the error with the call stack when the other mutex was locked.
if lockedBy, ok := c.mu.state().(*lockCallers); ok {
return LockOrderError{res, *lockedBy}
}
default:
// A reentrant mutex is already locked by a parent or ancestor state.
return res
}
}
}
if c.Context != nil {
// Forward the call to the parent context,
// which may or may not be a [checked] state.
return c.Context.Value(key)
}
return nil
}
var errAlreadyLocked = errors.New("non-reentrant mutex already locked")
// checkLockOrder determines whether m1 can be acquired after m2.
// It returns an error and true if there's a lock ordering or reentrancy violation,
// or the provided alreadyLocked value and true if m1 and m2 are the same and reentrancy is allowed,
// or nil and false if the caller should continue checking against the next locked mutex.
func checkLockOrder[T any](m1, m2 mutexHandle, alreadyLocked T) (res any, done bool) {
if m2 == nil {
// Nothing to check; continue search.
return nil, false
}
r1, r2 := m1.rank(), m2.rank()
if err := r1.CheckLockAfter(r2); err != nil {
// There's a lock ordering (or reentrancy) violation.
return err, true
}
if m1 != m2 {
// There's no lock ordering violation,
// but the mutex being locked is not the same as the one
// already locked. We need to continue checking.
return nil, false
}
if _, ok := r1.(NonReentrant); ok {
// Special handling for the [NonReentrant] rank.
//
// For user-defined ranks, reentrancy rules are enforced
// by the rank implementation itself, since each mutex
// is expected to have a distinct rank, and the rank
// can define its own rules. However, the predefined
// [NonReentrant] rank is shared by multiple mutexes.
return errAlreadyLocked, true
}
// The locking mutex is the same as the one already locked,
// and the rank allows reentrancy. We found a match.
return alreadyLocked, true
}
// LockOrderError represents a violation of mutex lock ordering.
//
// This error is not returned directly; it is used in panics to indicate a programming error
// when lock acquisition violates the expected order.
type LockOrderError struct {
error
violatedBy lockCallers // the call stack when the other mutex was locked
}
func (e LockOrderError) Error() string {
return fmt.Sprintf("%s\n\nConflicting lock held at:\n%s", e.error, e.violatedBy)
}
func (c lockCallers) String() string {
var output string
frames := runtime.CallersFrames(c[:])
for {
frame, more := frames.Next()
output += fmt.Sprintf("%s\n\t%s:%d\n", frame.Function, frame.File, frame.Line)
if !more {
break
}
}
return output
}

View File

@ -5,62 +5,58 @@ package ctxlock
import (
"context"
"fmt"
"strings"
"sync"
"testing"
"tailscale.com/util/ctxkey"
)
type state interface {
type stateType interface {
*checked | unchecked
context.Context
Unlock()
unlock()
}
type impl[T state] struct {
type lockStateType interface{ lockCallers | unchecked }
type impl[T stateType, S lockStateType] struct {
None func() T
FromContext func(context.Context) T
Lock func(T, *sync.Mutex) T
LockCtx func(context.Context, *sync.Mutex) T
Lock func(T, *mutex[Reentrant, S]) T
LockCtx func(context.Context, *mutex[Reentrant, S]) T
}
var (
exportedImpl = impl[State]{
None: None,
FromContext: FromContext,
Lock: Lock[State],
LockCtx: Lock[context.Context],
}
checkedImpl = impl[*checked]{
checkedImpl = impl[*checked, lockCallers]{
None: func() *checked { return nil },
FromContext: fromContextChecked,
Lock: lockChecked,
LockCtx: func(ctx context.Context, mu *sync.Mutex) *checked {
Lock: lockChecked[Reentrant],
LockCtx: func(ctx context.Context, mu *checkedMutex[Reentrant]) *checked {
return lockChecked(fromContextChecked(ctx), mu)
},
}
uncheckedImpl = impl[unchecked]{
uncheckedImpl = impl[unchecked, unchecked]{
None: func() unchecked { return unchecked{} },
FromContext: fromContextUnchecked,
Lock: lockUnchecked,
LockCtx: func(ctx context.Context, mu *sync.Mutex) unchecked {
Lock: lockUnchecked[Reentrant],
LockCtx: func(ctx context.Context, mu *mutex[Reentrant, unchecked]) unchecked {
return lockUnchecked(fromContextUnchecked(ctx), mu)
},
}
)
// BenchmarkLockUnlock benchmarks the performance of locking and unlocking a mutex.
func BenchmarkLockUnlock(b *testing.B) {
var mu sync.Mutex
b.Run("Exported", func(b *testing.B) {
benchmarkLockUnlock(b, exportedImpl)
})
// BenchmarkStateLockUnlock benchmarks the performance of locking and unlocking a mutex.
func BenchmarkStateLockUnlock(b *testing.B) {
b.Run("Checked", func(b *testing.B) {
benchmarkLockUnlock(b, checkedImpl)
benchmarkStateLockUnlock(b, checkedImpl)
})
b.Run("Unchecked", func(b *testing.B) {
benchmarkLockUnlock(b, uncheckedImpl)
benchmarkStateLockUnlock(b, uncheckedImpl)
})
b.Run("Reference", func(b *testing.B) {
var mu sync.Mutex
for b.Loop() {
mu.Lock()
mu.Unlock()
@ -68,21 +64,16 @@ func BenchmarkLockUnlock(b *testing.B) {
})
}
func benchmarkLockUnlock[T state](b *testing.B, impl impl[T]) {
var mu sync.Mutex
func benchmarkStateLockUnlock[T stateType, S lockStateType](b *testing.B, impl impl[T, S]) {
var mu mutex[Reentrant, S]
for b.Loop() {
ctx := impl.Lock(impl.None(), &mu)
ctx.Unlock()
state := impl.Lock(impl.None(), &mu)
state.unlock()
}
}
// BenchmarkReentrance benchmarks the performance of reentrant locking and unlocking.
func BenchmarkReentrance(b *testing.B) {
var mu sync.Mutex
b.Run("Exported", func(b *testing.B) {
benchmarkReentrance(b, exportedImpl)
})
b.Run("Checked", func(b *testing.B) {
benchmarkReentrance(b, checkedImpl)
})
@ -90,6 +81,7 @@ func BenchmarkReentrance(b *testing.B) {
benchmarkReentrance(b, uncheckedImpl)
})
b.Run("Reference", func(b *testing.B) {
var mu sync.Mutex
for b.Loop() {
mu.Lock()
func(mu *sync.Mutex) {
@ -102,102 +94,68 @@ func BenchmarkReentrance(b *testing.B) {
})
}
func benchmarkReentrance[T state](b *testing.B, impl impl[T]) {
var mu sync.Mutex
func benchmarkReentrance[T stateType, S lockStateType](b *testing.B, impl impl[T, S]) {
var mu mutex[Reentrant, S]
for b.Loop() {
parent := impl.Lock(impl.None(), &mu)
func(ctx T) {
child := impl.Lock(ctx, &mu)
child.Unlock()
child.unlock()
}(parent)
parent.Unlock()
parent.unlock()
}
}
// BenchmarkGenericLock benchmarks the performance of the generic [Lock] function
// that works with both [State] and [context.Context].
func BenchmarkGenericLock(b *testing.B) {
// Does not allocate with --tags=ts_omit_ctxlock_checks.
b.Run("State", func(b *testing.B) {
var mu sync.Mutex
var ctx State
for b.Loop() {
parent := Lock(ctx, &mu)
func(ctx State) {
child := Lock(ctx, &mu)
child.Unlock()
}(parent)
parent.Unlock()
}
})
b.Run("StdContext", func(b *testing.B) {
var mu sync.Mutex
ctx := context.Background()
for b.Loop() {
parent := Lock(ctx, &mu)
func(ctx State) {
child := Lock(ctx, &mu)
child.Unlock()
}(parent)
parent.Unlock()
}
})
}
// TestUncheckedAllocFree tests that the exported implementation of [State] does not allocate memory
// when the ts_omit_ctxlock_checks build tag is set.
func TestUncheckedAllocFree(t *testing.T) {
if Checked {
if IsChecked {
t.Skip("Exported implementation is not alloc-free (use --tags=ts_omit_ctxlock_checks)")
}
t.Run("Simple/WithState", func(t *testing.T) {
var mu sync.Mutex
var mu ReentrantMutex
mustNotAllocate(t, func() {
ctx := Lock(None(), &mu)
ctx.Unlock()
mu := Lock(None(), &mu)
mu.Unlock()
})
})
t.Run("Simple/WithContext", func(t *testing.T) {
var mu sync.Mutex
var mu ReentrantMutex
ctx := context.Background()
mustNotAllocate(t, func() {
ctx := Lock(ctx, &mu)
ctx.Unlock()
mu := Lock(ctx, &mu)
mu.Unlock()
})
})
t.Run("Reentrant/WithState", func(t *testing.T) {
var mu sync.Mutex
var mu ReentrantMutex
mustNotAllocate(t, func() {
parent := Lock(None(), &mu)
func(ctx State) {
child := Lock(parent, &mu)
func(state State) {
child := Lock(state, &mu)
child.Unlock()
}(parent)
}(parent.State())
parent.Unlock()
})
})
t.Run("Reentrant/WithContext", func(t *testing.T) {
var mu sync.Mutex
var mu ReentrantMutex
ctx := context.Background()
mustNotAllocate(t, func() {
parent := Lock(ctx, &mu)
func(ctx State) {
child := Lock(ctx, &mu)
func(state State) {
child := Lock(state, &mu)
child.Unlock()
}(parent)
}(parent.State())
parent.Unlock()
})
})
}
func TestHappyPath(t *testing.T) {
t.Run("Exported", func(t *testing.T) {
testHappyPath(t, exportedImpl)
})
t.Run("Checked", func(t *testing.T) {
testHappyPath(t, checkedImpl)
})
@ -207,32 +165,33 @@ func TestHappyPath(t *testing.T) {
})
}
func testHappyPath[T state](t *testing.T, impl impl[T]) {
var mu sync.Mutex
func testHappyPath[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) {
var mu mutex[Reentrant, S]
parent := impl.Lock(impl.None(), &mu)
wantLocked(t, &mu) // mu is locked by parent
child := impl.Lock(parent, &mu)
wantLocked(t, &mu) // mu is still locked by parent
var mu2 sync.Mutex
var mu2 mutex[Reentrant, S]
ls2 := impl.Lock(child, &mu2)
wantLocked(t, &mu2) // mu2 is locked by ls2
ls2.Unlock() // unlocks mu2
wantLocked(t, &mu2) // mu2 is locked by ls2
grandchild := impl.Lock(ls2, &mu)
grandchild.unlock() // no-op; mu is owned by parent
wantLocked(t, &mu) // mu is still locked by parent
ls2.unlock() // unlocks mu2
wantUnlocked(t, &mu2) // mu2 is now unlocked
child.Unlock() // noop
child.unlock() // noop
wantLocked(t, &mu) // mu is still locked by parent
parent.Unlock() // unlocks mu
parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
}
func TestContextWrapping(t *testing.T) {
t.Run("Exported", func(t *testing.T) {
testContextWrapping(t, exportedImpl)
})
t.Run("Checked", func(t *testing.T) {
testContextWrapping(t, checkedImpl)
})
@ -242,13 +201,13 @@ func TestContextWrapping(t *testing.T) {
})
}
func testContextWrapping[T state](t *testing.T, impl impl[T]) {
func testContextWrapping[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) {
// Create a [context.Context] with a value set in it.
wantValue := "value"
key := ctxkey.New("key", "")
ctxWithValue := key.WithValue(context.Background(), wantValue)
var mu sync.Mutex
var mu mutex[Reentrant, S]
parent := impl.LockCtx(ctxWithValue, &mu)
wantLocked(t, &mu) // mu is locked by parent
@ -268,103 +227,72 @@ func testContextWrapping[T state](t *testing.T, impl impl[T]) {
}
// ... and the lock state.
child.Unlock() // no-op; mu is owned by parent
child.unlock() // no-op; mu is owned by parent
wantLocked(t, &mu) // mu is still locked by parent
parentDup.Unlock() // no-op; mu is owned by parent
parentDup.unlock() // no-op; mu is owned by parent
wantLocked(t, &mu) // mu is still locked by parent
parent.Unlock() // unlocks mu
parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
}
func TestNilMutex(t *testing.T) {
impl := checkedImpl
wantPanic(t, "nil *sync.Mutex", func() { impl.Lock(impl.None(), nil) })
wantPanic(t, "nil mutex", func() { impl.Lock(impl.None(), nil) })
}
func TestUseUnlockedParent_Checked(t *testing.T) {
impl := checkedImpl
var mu sync.Mutex
var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
parent.Unlock() // unlocks mu
parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
wantPanic(t, "use after unlock", func() { impl.Lock(parent, &mu) })
}
func TestUseUnlockedMutex_Checked(t *testing.T) {
impl := checkedImpl
var mu sync.Mutex
parent := impl.Lock(impl.None(), &mu)
mu.Unlock() // unlock mu directly without unlocking parent
wantPanic(t, "*sync.Mutex is spuriously unlocked", func() { impl.Lock(parent, &mu) })
}
func TestUnlockParentFirst_Checked(t *testing.T) {
impl := checkedImpl
var mu sync.Mutex
var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.FromContext(context.Background()), &mu)
child := impl.Lock(parent, &mu)
parent.Unlock() // unlocks mu
parent.unlock() // unlocks mu
wantUnlocked(t, &mu) // mu is now unlocked
wantPanic(t, "parent already unlocked", child.Unlock)
wantPanic(t, "parent already unlocked", child.unlock)
}
func TestUnlockTwice_Checked(t *testing.T) {
impl := checkedImpl
unlockTwice := func(t *testing.T, ctx *checked) {
ctx.Unlock() // unlocks mu
wantPanic(t, "already unlocked", ctx.Unlock)
ctx.unlock() // unlocks mu
wantPanic(t, "already unlocked", ctx.unlock)
}
t.Run("Wrapped", func(t *testing.T) {
unlockTwice(t, impl.FromContext(context.Background()))
})
t.Run("Locked", func(t *testing.T) {
var mu sync.Mutex
var mu checkedMutex[Reentrant]
ctx := impl.Lock(impl.None(), &mu)
unlockTwice(t, ctx)
})
t.Run("Locked/WithReloc", func(t *testing.T) {
var mu sync.Mutex
ctx := impl.Lock(impl.None(), &mu)
ctx.Unlock() // unlocks mu
mu.Lock() // re-locks mu, but not by the state
wantPanic(t, "already unlocked", ctx.Unlock)
})
t.Run("Child", func(t *testing.T) {
var mu sync.Mutex
var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
defer parent.Unlock()
defer parent.unlock()
child := impl.Lock(parent, &mu)
unlockTwice(t, child)
})
t.Run("Child/WithReloc", func(t *testing.T) {
var mu sync.Mutex
parent := impl.Lock(impl.None(), &mu)
child := impl.Lock(parent, &mu)
parent.Unlock()
mu.Lock() // re-locks mu, but not the parent state
wantPanic(t, "parent already unlocked", child.Unlock)
})
t.Run("Child/WithManualUnlock", func(t *testing.T) {
var mu sync.Mutex
parent := impl.Lock(impl.None(), &mu)
child := impl.Lock(parent, &mu)
mu.Unlock() // unlocks mu, but not the parent state
wantPanic(t, "mutex is not locked", child.Unlock)
})
t.Run("Grandchild", func(t *testing.T) {
var mu sync.Mutex
var mu checkedMutex[Reentrant]
parent := impl.Lock(impl.None(), &mu)
defer parent.Unlock()
defer parent.unlock()
child := impl.Lock(parent, &mu)
defer child.Unlock()
defer child.unlock()
grandchild := impl.Lock(child, &mu)
unlockTwice(t, grandchild)
})
@ -373,76 +301,70 @@ func TestUnlockTwice_Checked(t *testing.T) {
func TestUseUnlocked_Checked(t *testing.T) {
impl := checkedImpl
var mu sync.Mutex
var mu checkedMutex[Reentrant]
state := lockChecked(impl.None(), &mu)
state.Unlock()
state.unlock()
// All of these should panic since the state is already unlocked.
wantPanic(t, "", func() { state.Deadline() })
wantPanic(t, "", func() { state.Done() })
wantPanic(t, "", func() { state.Err() })
wantPanic(t, "", func() { state.Unlock() })
wantPanic(t, "", func() { state.Value("key") })
wantPanic(t, "*", func() { state.Deadline() })
wantPanic(t, "*", func() { state.Done() })
wantPanic(t, "*", func() { state.Err() })
wantPanic(t, "*", func() { state.unlock() })
wantPanic(t, "*", func() { state.Value("key") })
}
func TestUseZeroState(t *testing.T) {
t.Run("Exported", func(t *testing.T) {
testUseEmptyState(t, exportedImpl.None, exportedImpl)
})
t.Run("Checked", func(t *testing.T) {
testUseEmptyState(t, checkedImpl.None, checkedImpl)
testUseEmptyState(t, checkedImpl.None)
})
t.Run("Unchecked", func(t *testing.T) {
testUseEmptyState(t, uncheckedImpl.None, uncheckedImpl)
testUseEmptyState(t, uncheckedImpl.None)
})
}
func TestUseWrappedBackground(t *testing.T) {
t.Run("Exported", func(t *testing.T) {
testUseEmptyState(t, getWrappedBackground(t, exportedImpl), exportedImpl)
})
t.Run("Checked", func(t *testing.T) {
testUseEmptyState(t, getWrappedBackground(t, checkedImpl), checkedImpl)
testUseEmptyState(t, getWrappedBackground(t, checkedImpl))
})
t.Run("Unchecked", func(t *testing.T) {
testUseEmptyState(t, getWrappedBackground(t, uncheckedImpl), uncheckedImpl)
testUseEmptyState(t, getWrappedBackground(t, uncheckedImpl))
})
}
func getWrappedBackground[T state](t *testing.T, impl impl[T]) func() T {
func getWrappedBackground[T stateType, S lockStateType](t *testing.T, impl impl[T, S]) func() T {
t.Helper()
return func() T {
return impl.FromContext(context.Background())
}
}
func testUseEmptyState[T state](t *testing.T, getCtx func() T, impl impl[T]) {
// Using aan empty [State] must not panic or deadlock.
func testUseEmptyState[T stateType](t *testing.T, getState func() T) {
// Using an empty [State] must not panic or deadlock.
// It should also behave like [context.Background].
for range 2 {
ctx := getCtx()
if gotDone := ctx.Done(); gotDone != nil {
state := getState()
if gotDone := state.Done(); gotDone != nil {
t.Errorf("ctx.Done() = %v; want nil", gotDone)
}
if gotDeadline, ok := ctx.Deadline(); ok {
if gotDeadline, ok := state.Deadline(); ok {
t.Errorf("ctx.Deadline() = %v; want !ok", gotDeadline)
}
if gotErr := ctx.Err(); gotErr != nil {
if gotErr := state.Err(); gotErr != nil {
t.Errorf("ctx.Err() = %v; want nil", gotErr)
}
if gotValue := ctx.Value("test-key"); gotValue != nil {
if gotValue := state.Value("test-key"); gotValue != nil {
t.Errorf("ctx.Value(test-key) = %v; want nil", gotValue)
}
ctx.Unlock()
state.unlock()
}
}
func wantPanic(t *testing.T, wantMsg string, fn func()) {
t.Helper()
defer func() {
if r := recover(); wantMsg != "" {
if gotMsg, ok := r.(string); !ok || gotMsg != wantMsg {
t.Errorf("panic: %v; want %q", r, wantMsg)
if r := recover(); wantMsg != "*" {
if gotMsg := trimPanicMessage(r); gotMsg != wantMsg {
t.Errorf("panic: got %q; want %q", r, wantMsg)
}
}
}()
@ -450,19 +372,26 @@ func wantPanic(t *testing.T, wantMsg string, fn func()) {
t.Fatal("failed to panic")
}
func wantLocked(t *testing.T, m *sync.Mutex) {
if m.TryLock() {
m.Unlock()
func (m *mutex[R, S]) isLockedForTest() bool {
if m.m.TryLock() {
m.m.Unlock()
return false
}
return true
}
func wantLocked[R Rank, S lockStateType](t *testing.T, m *mutex[R, S]) {
t.Helper()
if !m.isLockedForTest() {
t.Fatal("mutex is not locked")
}
}
func wantUnlocked(t *testing.T, m *sync.Mutex) {
func wantUnlocked[R Rank, S lockStateType](t *testing.T, m *mutex[R, S]) {
t.Helper()
if !m.TryLock() {
if m.isLockedForTest() {
t.Fatal("mutex is locked")
}
m.Unlock()
}
func mustNotAllocate(t *testing.T, steps func()) {
@ -472,3 +401,12 @@ func mustNotAllocate(t *testing.T, steps func()) {
t.Errorf("expected 0 allocs, got %f", allocs)
}
}
func trimPanicMessage(r any) string {
msg := fmt.Sprintf("%v", r)
msg = strings.TrimSpace(msg)
if i := strings.IndexByte(msg, '\n'); i >= 0 {
return msg[:i]
}
return msg
}

View File

@ -1,35 +1,103 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// This file exports optimized implementation of the [State] that omits runtime checks.
// It is used when the build tag ts_omit_ctxlock_checks is set.
//go:build ts_omit_ctxlock_checks
package ctxlock
import (
"context"
"sync"
"time"
)
const Checked = false
type State struct {
unchecked
// unchecked is an implementation of [State] that trades additional runtime checks
// for performance.
//
// Its zero value carries no mutex lock state and an empty [context.Context].
type unchecked struct {
context.Context // nil means an empty context
mu mutexHandle // non-nil if owned by this state
}
func None() State {
return State{}
type (
alreadyLocked struct{}
uncheckedMutex[R Rank] = mutex[R, unchecked]
)
func fromContextUnchecked(ctx context.Context) unchecked {
return unchecked{ctx, nil}
}
func FromContext(parent context.Context) State {
return State{fromContextUnchecked(parent)}
}
func lockUnchecked[R Rank](parent unchecked, mu *uncheckedMutex[R]) unchecked {
if !parent.isAlreadyLocked(mu) {
mu.lock()
// Locking a mutex creates a new state that must be accessible from any derived state.
// Normally, this state would be heap-allocated, but we want to avoid allocating new memory
// on every lock. Instead, we use a storage region within the mutex itself.
mu.lockState = unchecked{parent.Context, mu}
return unchecked{&mu.lockState, mu}
func Lock[T context.Context](parent T, mu *sync.Mutex) State {
if parent, ok := any(parent).(State); ok {
return State{lockUnchecked(parent.unchecked, mu)}
}
return State{lockUnchecked(fromContextUnchecked(parent), mu)}
// The mutex is already locked by a parent or ancestor state.
return unchecked{parent.Context, nil}
}
func (c unchecked) isAlreadyLocked(m mutexHandle) bool {
switch val := c.Value(m).(type) {
case nil:
// No ancestor state associated with this mutex,
// and locking it does not violate the lock ordering.
return false
case error:
// There's a lock ordering or reentrancy violation.
panic(val)
case alreadyLocked:
// The mutex is reentrant and is already held by a parent
// or ancestor state.
return true
default:
panic("unreachable")
}
}
func (c unchecked) unlock() {
if c.mu != nil {
c.mu.unlock()
}
}
// Deadline implements [context.Context].
func (c unchecked) Deadline() (deadline time.Time, ok bool) {
if c.Context == nil {
return time.Time{}, false
}
return c.Context.Deadline()
}
// Done implements [context.Context].
func (c unchecked) Done() <-chan struct{} {
if c.Context == nil {
return nil
}
return c.Context.Done()
}
// Err implements [context.Context].
func (c unchecked) Err() error {
if c.Context == nil {
return nil
}
return c.Context.Err()
}
// Err implements [context.Context].
func (c unchecked) Value(key any) any {
if mu, ok := key.(mutexHandle); ok {
if res, done := checkLockOrder(mu, c.mu, alreadyLocked{}); done {
// We have a definite answer.
return res
}
}
if c.Context == nil {
return nil
}
return c.Context.Value(key)
}

View File

@ -0,0 +1,20 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !ts_omit_ctxlock_checks
package ctxlock
const useCheckedImpl = true
type (
stateImpl = *checked
lockState = lockCallers
_ = lockState
)
var fromContext = fromContextChecked
func lock[R Rank](parent stateImpl, mu *checkedMutex[R]) stateImpl {
return lockChecked(parent, mu)
}

View File

@ -0,0 +1,20 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build ts_omit_ctxlock_checks
package ctxlock
const useCheckedImpl = false
type (
stateImpl = unchecked
lockState = unchecked
_ = lockState
)
var fromContext = fromContextUnchecked
func lock[R Rank](parent stateImpl, mu *uncheckedMutex[R]) stateImpl {
return lockUnchecked(parent, mu)
}