util/ctxlock: make ctxlock.Context generic

Updates #12614
Updates #15824

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2025-05-01 10:39:52 -05:00
parent f605a99e0b
commit 968e921deb
No known key found for this signature in database
5 changed files with 88 additions and 77 deletions

View File

@ -17,60 +17,62 @@ import (
) )
var ( var (
noneCtx = context.Background() noneCtx = context.Background()
noneUnchecked = unchecked{noneCtx, nil}
) )
type lockerKey struct{ *sync.Mutex } type lockerKey[T any] struct{ key T }
func lockerKeyOf(mu *sync.Mutex) lockerKey { func lockerKeyOf[T sync.Locker](mu T) lockerKey[T] {
return lockerKey{mu} return lockerKey[T]{key: mu}
} }
// checked is an implementation of [Context] that performs runtime checks // checked is an implementation of [Context] that performs runtime checks
// to ensure that the context is used correctly. // to ensure that the context is used correctly.
type checked struct { type checked[T sync.Locker] struct {
context.Context // nil after [checked.Unlock] is called context.Context // nil after [checked.Unlock] is called
mu *sync.Mutex // nil if the context does not track a mutex lock state mu T // nil if the context does not track a mutex lock state
parent *checked // nil if the context owns the lock parent *checked[T] // nil if the context owns the lock
} }
func noneChecked() *checked { func noneChecked[T sync.Locker]() *checked[T] {
return &checked{noneCtx, nil, nil} var zero T
return &checked[T]{noneCtx, zero, nil}
} }
func wrapChecked(parent context.Context) *checked { func wrapChecked[T sync.Locker](parent context.Context) *checked[T] {
return &checked{parent, nil, nil} var zero T
return &checked[T]{parent, zero, nil}
} }
func lockChecked(parent *checked, mu *sync.Mutex) *checked { func lockChecked[T, P sync.Locker](parent *checked[P], mu T) *checked[T] {
checkLockArgs(parent, mu) checkLockArgs(parent, mu)
if parentLockCtx, ok := parent.Value(lockerKeyOf(mu)).(*checked); ok { if parentLockCtx, ok := parent.Value(lockerKeyOf(mu)).(*checked[T]); ok {
if appearsUnlocked(mu) { if appearsUnlocked(mu) {
// The parent still owns the lock, but the mutex is unlocked. // The parent still owns the lock, but the mutex is unlocked.
panic("mu is already unlocked") panic("mu is already unlocked")
} }
return &checked{parent, mu, parentLockCtx} return &checked[T]{parent, mu, parentLockCtx}
} }
mu.Lock() mu.Lock()
return &checked{parent, mu, nil} return &checked[T]{parent, mu, nil}
} }
func (c *checked) Value(key any) any { func (c *checked[T]) Value(key any) any {
if c.Context == nil { if c.Context == nil {
panic("use of context after unlock") panic("use of context after unlock")
} }
if key == lockerKeyOf(c.mu) { if key == any(lockerKeyOf(c.mu)) {
return c return c
} }
return c.Context.Value(key) return c.Context.Value(key)
} }
func (c *checked) Unlock() { func (c *checked[T]) Unlock() {
var zero T
switch { switch {
case c.Context == nil: case c.Context == nil:
panic("already unlocked") panic("already unlocked")
case c.mu == nil: case any(c.mu) == any(zero):
// No-op; the context does not track a mutex lock state, // No-op; the context does not track a mutex lock state,
// such as when it was created with [noneChecked] or [wrapChecked]. // such as when it was created with [noneChecked] or [wrapChecked].
case appearsUnlocked(c.mu): case appearsUnlocked(c.mu):
@ -88,45 +90,54 @@ func (c *checked) Unlock() {
func checkLockArgs[T interface { func checkLockArgs[T interface {
context.Context context.Context
comparable comparable
}](parent T, mu *sync.Mutex) { }, L sync.Locker](parent T, mu L) {
var zero T var zero T
var nilLocker L
if parent == zero { if parent == zero {
panic("nil parent context") panic("nil parent context")
} }
if mu == nil { if any(mu) == any(nilLocker) {
panic("nil locker") panic("nil locker")
} }
} }
// unchecked is an implementation of [Context] that trades runtime checks for performance. // unchecked is an implementation of [Context] that trades runtime checks for performance.
type unchecked struct { type unchecked[T sync.Locker] struct {
context.Context // always non-nil context.Context // always non-nil
mu *sync.Mutex // non-nil if locked by this context mu T // non-nil if locked by this context
} }
func wrapUnchecked(parent context.Context) unchecked { func noneUnchecked[T sync.Locker]() unchecked[T] {
return unchecked{parent, nil} var zero T
return unchecked[T]{noneCtx, zero}
} }
func lockUnchecked(parent unchecked, mu *sync.Mutex) unchecked { func wrapUnchecked[T sync.Locker](parent context.Context) unchecked[T] {
checkLockArgs(parent, mu) // this is cheap, so we do it even in the unchecked case var zero T
return unchecked[T]{parent, zero}
}
func lockUnchecked[T, P sync.Locker](parent unchecked[P], mu T) unchecked[T] {
checkLockArgs(parent.Context, mu) // this is cheap, so we do it even in the unchecked case
if parent.Value(lockerKeyOf(mu)) == nil { if parent.Value(lockerKeyOf(mu)) == nil {
mu.Lock() mu.Lock()
} else { } else {
mu = nil // already locked by a parent/ancestor var zero T
mu = zero // already locked by a parent/ancestor
} }
return unchecked{parent.Context, mu} return unchecked[T]{parent.Context, mu}
} }
func (c unchecked) Value(key any) any { func (c unchecked[T]) Value(key any) any {
if key == lockerKeyOf(c.mu) { if any(key) == any(lockerKeyOf(c.mu)) {
return key return key
} }
return c.Context.Value(key) return c.Context.Value(key)
} }
func (c unchecked) Unlock() { func (c unchecked[T]) Unlock() {
if c.mu != nil { var zero T
if any(c.mu) != any(zero) {
c.mu.Unlock() c.mu.Unlock()
} }
} }

View File

@ -17,24 +17,24 @@ import (
// Calling [Context.Unlock] on a [Context] unlocks the mutex locked by the context, if any. // Calling [Context.Unlock] on a [Context] unlocks the mutex locked by the context, if any.
// It is a runtime error to call [Context.Unlock] more than once, // It is a runtime error to call [Context.Unlock] more than once,
// or use a [Context] after calling [Context.Unlock]. // or use a [Context] after calling [Context.Unlock].
type Context struct { type Context[T sync.Locker] struct {
*checked *checked[T]
} }
// None returns a [Context] that carries no mutex lock state and an empty [context.Context]. // None returns a [Context] that carries no mutex lock state and an empty [context.Context].
// //
// It is typically used by top-level callers that do not have a parent context to pass in, // It is typically used by top-level callers that do not have a parent context to pass in,
// and is a shorthand for [Context]([context.Background]). // and is a shorthand for [Context]([context.Background]).
func None() Context { func None[T sync.Locker]() Context[T] {
return Context{noneChecked()} return Context[T]{noneChecked[T]()}
} }
// Wrap returns a derived [Context] that wraps the provided [context.Context]. // Wrap returns a derived [Context] that wraps the provided [context.Context].
// //
// It is typically used by callers that already have a [context.Context], // It is typically used by callers that already have a [context.Context],
// which may or may not be a [Context] tracking a mutex lock state. // which may or may not be a [Context] tracking a mutex lock state.
func Wrap(parent context.Context) Context { func Wrap[T sync.Locker](parent context.Context) Context[T] {
return Context{wrapChecked(parent)} return Context[T]{wrapChecked[T](parent)}
} }
// Lock returns a derived [Context] that wraps the provided [context.Context] // Lock returns a derived [Context] that wraps the provided [context.Context]
@ -43,6 +43,6 @@ func Wrap(parent context.Context) Context {
// It locks the mutex unless it is already held by the parent or an ancestor [Context]. // It locks the mutex unless it is already held by the parent or an ancestor [Context].
// It is a runtime error to pass a nil mutex or to unlock the parent context // It is a runtime error to pass a nil mutex or to unlock the parent context
// before the returned one. // before the returned one.
func Lock(parent Context, mu *sync.Mutex) Context { func Lock[T, P sync.Locker](parent Context[P], mu T) Context[T] {
return Context{lockChecked(parent.checked, mu)} return Context[T]{lockChecked(parent.checked, mu)}
} }

View File

@ -23,20 +23,20 @@ type impl[T ctx] struct {
} }
var ( var (
exportedImpl = impl[Context]{ exportedImpl = impl[Context[*sync.Mutex]]{
None: None, None: None[*sync.Mutex],
Wrap: Wrap, Wrap: Wrap[*sync.Mutex],
Lock: Lock, Lock: Lock[*sync.Mutex, *sync.Mutex],
} }
checkedImpl = impl[*checked]{ checkedImpl = impl[*checked[*sync.Mutex]]{
None: noneChecked, None: noneChecked[*sync.Mutex],
Wrap: wrapChecked, Wrap: wrapChecked[*sync.Mutex],
Lock: lockChecked, Lock: lockChecked[*sync.Mutex, *sync.Mutex],
} }
uncheckedImpl = impl[unchecked]{ uncheckedImpl = impl[unchecked[*sync.Mutex]]{
None: func() unchecked { return noneUnchecked }, None: noneUnchecked[*sync.Mutex],
Wrap: wrapUnchecked, Wrap: wrapUnchecked[*sync.Mutex],
Lock: lockUnchecked, Lock: lockUnchecked[*sync.Mutex, *sync.Mutex],
} }
) )
@ -207,7 +207,7 @@ func TestUnlockParentFirst_Checked(t *testing.T) {
func TestUnlockTwice_Checked(t *testing.T) { func TestUnlockTwice_Checked(t *testing.T) {
impl := checkedImpl impl := checkedImpl
doTest := func(t *testing.T, ctx *checked) { doTest := func(t *testing.T, ctx *checked[*sync.Mutex]) {
ctx.Unlock() // unlocks mu ctx.Unlock() // unlocks mu
wantPanic(t, ctx.Unlock) // panics since mu is already unlocked wantPanic(t, ctx.Unlock) // panics since mu is already unlocked
} }

View File

@ -13,18 +13,18 @@ import (
"sync" "sync"
) )
type Context struct { type Context[T sync.Locker] struct {
unchecked unchecked[T]
} }
func None() Context { func None[T sync.Locker]() Context[T] {
return Context{noneUnchecked} return Context[T]{noneUnchecked[T]()}
} }
func Wrap(parent context.Context) Context { func Wrap[T sync.Locker](parent context.Context) Context[T] {
return Context{wrapUnchecked(parent)} return Context[T]{wrapUnchecked[T](parent)}
} }
func Lock(parent Context, mu *sync.Mutex) Context { func Lock[T, P sync.Locker](parent Context[P], mu T) Context[T] {
return Context{lockUnchecked(parent.unchecked, mu)} return Context[T]{lockUnchecked(parent.unchecked, mu)}
} }

View File

@ -16,31 +16,31 @@ type Resource struct {
foo, bar string foo, bar string
} }
func (r *Resource) GetFoo(ctx ctxlock.Context) string { func (r *Resource) GetFoo(ctx ctxlock.Context[*sync.Mutex]) string {
defer ctxlock.Lock(ctx, &r.mu).Unlock() // Lock the mutex if not already held. defer ctxlock.Lock(ctx, &r.mu).Unlock() // Lock the mutex if not already held.
syncs.AssertLocked(&r.mu) // Panics if mu is still unlocked. syncs.AssertLocked(&r.mu) // Panics if mu is still unlocked.
return r.foo return r.foo
} }
func (r *Resource) SetFoo(ctx ctxlock.Context, foo string) { func (r *Resource) SetFoo(ctx ctxlock.Context[*sync.Mutex], foo string) {
defer ctxlock.Lock(ctx, &r.mu).Unlock() defer ctxlock.Lock(ctx, &r.mu).Unlock()
syncs.AssertLocked(&r.mu) syncs.AssertLocked(&r.mu)
r.foo = foo r.foo = foo
} }
func (r *Resource) GetBar(ctx ctxlock.Context) string { func (r *Resource) GetBar(ctx ctxlock.Context[*sync.Mutex]) string {
defer ctxlock.Lock(ctx, &r.mu).Unlock() defer ctxlock.Lock(ctx, &r.mu).Unlock()
syncs.AssertLocked(&r.mu) syncs.AssertLocked(&r.mu)
return r.bar return r.bar
} }
func (r *Resource) SetBar(ctx ctxlock.Context, bar string) { func (r *Resource) SetBar(ctx ctxlock.Context[*sync.Mutex], bar string) {
defer ctxlock.Lock(ctx, &r.mu).Unlock() defer ctxlock.Lock(ctx, &r.mu).Unlock()
syncs.AssertLocked(&r.mu) syncs.AssertLocked(&r.mu)
r.bar = bar r.bar = bar
} }
func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) { func (r *Resource) WithLock(ctx ctxlock.Context[*sync.Mutex], f func(ctx ctxlock.Context[*sync.Mutex])) {
// Lock the mutex if not already held, and get a new context. // Lock the mutex if not already held, and get a new context.
ctx = ctxlock.Lock(ctx, &r.mu) ctx = ctxlock.Lock(ctx, &r.mu)
defer ctx.Unlock() defer ctx.Unlock()
@ -50,27 +50,27 @@ func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) {
func ExampleContext() { func ExampleContext() {
var r Resource var r Resource
r.SetFoo(ctxlock.None(), "foo") r.SetFoo(ctxlock.None[*sync.Mutex](), "foo")
r.SetBar(ctxlock.None(), "bar") r.SetBar(ctxlock.None[*sync.Mutex](), "bar")
r.WithLock(ctxlock.None(), func(ctx ctxlock.Context) { r.WithLock(ctxlock.None[*sync.Mutex](), func(ctx ctxlock.Context[*sync.Mutex]) {
// This callback is invoked with the Resource's lock held, // This callback is invoked with the Resource's lock held,
// and the ctx tracks carries the lock state. This means we can safely call // and the ctx tracks carries the lock state. This means we can safely call
// other methods on the Resource using ctx without causing a deadlock. // other methods on the Resource using ctx without causing a deadlock.
r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx)) r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx))
}) })
fmt.Println(r.GetFoo(ctxlock.None())) fmt.Println(r.GetFoo(ctxlock.None[*sync.Mutex]()))
// Output: foobar // Output: foobar
} }
func ExampleContext_twoResources() { func ExampleContext_twoResources() {
var r1, r2 Resource var r1, r2 Resource
r1.SetFoo(ctxlock.None(), "foo") r1.SetFoo(ctxlock.None[*sync.Mutex](), "foo")
r2.SetBar(ctxlock.None(), "bar") r2.SetBar(ctxlock.None[*sync.Mutex](), "bar")
r1.WithLock(ctxlock.None(), func(ctx ctxlock.Context) { r1.WithLock(ctxlock.None[*sync.Mutex](), func(ctx ctxlock.Context[*sync.Mutex]) {
// Here, r1's lock is held, but r2's lock is not. // Here, r1's lock is held, but r2's lock is not.
// So r2 will be locked when we call r2.SetBar(ctx). // So r2 will be locked when we call r2.SetBar(ctx).
r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx)) r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx))
}) })
fmt.Println(r1.GetFoo(ctxlock.None())) fmt.Println(r1.GetFoo(ctxlock.None[*sync.Mutex]()))
// Output: foobar // Output: foobar
} }