util/ctxlock: make ctxlock.Context generic

Updates #12614
Updates #15824

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2025-05-01 10:39:52 -05:00
parent f605a99e0b
commit 968e921deb
No known key found for this signature in database
5 changed files with 88 additions and 77 deletions

View File

@ -17,60 +17,62 @@ import (
)
var (
noneCtx = context.Background()
noneUnchecked = unchecked{noneCtx, nil}
noneCtx = context.Background()
)
type lockerKey struct{ *sync.Mutex }
type lockerKey[T any] struct{ key T }
func lockerKeyOf(mu *sync.Mutex) lockerKey {
return lockerKey{mu}
func lockerKeyOf[T sync.Locker](mu T) lockerKey[T] {
return lockerKey[T]{key: mu}
}
// checked is an implementation of [Context] that performs runtime checks
// to ensure that the context is used correctly.
type checked struct {
type checked[T sync.Locker] struct {
context.Context // nil after [checked.Unlock] is called
mu *sync.Mutex // nil if the context does not track a mutex lock state
parent *checked // nil if the context owns the lock
mu T // nil if the context does not track a mutex lock state
parent *checked[T] // nil if the context owns the lock
}
func noneChecked() *checked {
return &checked{noneCtx, nil, nil}
func noneChecked[T sync.Locker]() *checked[T] {
var zero T
return &checked[T]{noneCtx, zero, nil}
}
func wrapChecked(parent context.Context) *checked {
return &checked{parent, nil, nil}
func wrapChecked[T sync.Locker](parent context.Context) *checked[T] {
var zero T
return &checked[T]{parent, zero, nil}
}
func lockChecked(parent *checked, mu *sync.Mutex) *checked {
func lockChecked[T, P sync.Locker](parent *checked[P], mu T) *checked[T] {
checkLockArgs(parent, mu)
if parentLockCtx, ok := parent.Value(lockerKeyOf(mu)).(*checked); ok {
if parentLockCtx, ok := parent.Value(lockerKeyOf(mu)).(*checked[T]); ok {
if appearsUnlocked(mu) {
// The parent still owns the lock, but the mutex is unlocked.
panic("mu is already unlocked")
}
return &checked{parent, mu, parentLockCtx}
return &checked[T]{parent, mu, parentLockCtx}
}
mu.Lock()
return &checked{parent, mu, nil}
return &checked[T]{parent, mu, nil}
}
func (c *checked) Value(key any) any {
func (c *checked[T]) Value(key any) any {
if c.Context == nil {
panic("use of context after unlock")
}
if key == lockerKeyOf(c.mu) {
if key == any(lockerKeyOf(c.mu)) {
return c
}
return c.Context.Value(key)
}
func (c *checked) Unlock() {
func (c *checked[T]) Unlock() {
var zero T
switch {
case c.Context == nil:
panic("already unlocked")
case c.mu == nil:
case any(c.mu) == any(zero):
// No-op; the context does not track a mutex lock state,
// such as when it was created with [noneChecked] or [wrapChecked].
case appearsUnlocked(c.mu):
@ -88,45 +90,54 @@ func (c *checked) Unlock() {
func checkLockArgs[T interface {
context.Context
comparable
}](parent T, mu *sync.Mutex) {
}, L sync.Locker](parent T, mu L) {
var zero T
var nilLocker L
if parent == zero {
panic("nil parent context")
}
if mu == nil {
if any(mu) == any(nilLocker) {
panic("nil locker")
}
}
// unchecked is an implementation of [Context] that trades runtime checks for performance.
type unchecked struct {
context.Context // always non-nil
mu *sync.Mutex // non-nil if locked by this context
type unchecked[T sync.Locker] struct {
context.Context // always non-nil
mu T // non-nil if locked by this context
}
func wrapUnchecked(parent context.Context) unchecked {
return unchecked{parent, nil}
func noneUnchecked[T sync.Locker]() unchecked[T] {
var zero T
return unchecked[T]{noneCtx, zero}
}
func lockUnchecked(parent unchecked, mu *sync.Mutex) unchecked {
checkLockArgs(parent, mu) // this is cheap, so we do it even in the unchecked case
func wrapUnchecked[T sync.Locker](parent context.Context) unchecked[T] {
var zero T
return unchecked[T]{parent, zero}
}
func lockUnchecked[T, P sync.Locker](parent unchecked[P], mu T) unchecked[T] {
checkLockArgs(parent.Context, mu) // this is cheap, so we do it even in the unchecked case
if parent.Value(lockerKeyOf(mu)) == nil {
mu.Lock()
} else {
mu = nil // already locked by a parent/ancestor
var zero T
mu = zero // already locked by a parent/ancestor
}
return unchecked{parent.Context, mu}
return unchecked[T]{parent.Context, mu}
}
func (c unchecked) Value(key any) any {
if key == lockerKeyOf(c.mu) {
func (c unchecked[T]) Value(key any) any {
if any(key) == any(lockerKeyOf(c.mu)) {
return key
}
return c.Context.Value(key)
}
func (c unchecked) Unlock() {
if c.mu != nil {
func (c unchecked[T]) Unlock() {
var zero T
if any(c.mu) != any(zero) {
c.mu.Unlock()
}
}

View File

@ -17,24 +17,24 @@ import (
// Calling [Context.Unlock] on a [Context] unlocks the mutex locked by the context, if any.
// It is a runtime error to call [Context.Unlock] more than once,
// or use a [Context] after calling [Context.Unlock].
type Context struct {
*checked
type Context[T sync.Locker] struct {
*checked[T]
}
// None returns a [Context] that carries no mutex lock state and an empty [context.Context].
//
// It is typically used by top-level callers that do not have a parent context to pass in,
// and is a shorthand for [Context]([context.Background]).
func None() Context {
return Context{noneChecked()}
func None[T sync.Locker]() Context[T] {
return Context[T]{noneChecked[T]()}
}
// Wrap returns a derived [Context] that wraps the provided [context.Context].
//
// It is typically used by callers that already have a [context.Context],
// which may or may not be a [Context] tracking a mutex lock state.
func Wrap(parent context.Context) Context {
return Context{wrapChecked(parent)}
func Wrap[T sync.Locker](parent context.Context) Context[T] {
return Context[T]{wrapChecked[T](parent)}
}
// Lock returns a derived [Context] that wraps the provided [context.Context]
@ -43,6 +43,6 @@ func Wrap(parent context.Context) Context {
// It locks the mutex unless it is already held by the parent or an ancestor [Context].
// It is a runtime error to pass a nil mutex or to unlock the parent context
// before the returned one.
func Lock(parent Context, mu *sync.Mutex) Context {
return Context{lockChecked(parent.checked, mu)}
func Lock[T, P sync.Locker](parent Context[P], mu T) Context[T] {
return Context[T]{lockChecked(parent.checked, mu)}
}

View File

@ -23,20 +23,20 @@ type impl[T ctx] struct {
}
var (
exportedImpl = impl[Context]{
None: None,
Wrap: Wrap,
Lock: Lock,
exportedImpl = impl[Context[*sync.Mutex]]{
None: None[*sync.Mutex],
Wrap: Wrap[*sync.Mutex],
Lock: Lock[*sync.Mutex, *sync.Mutex],
}
checkedImpl = impl[*checked]{
None: noneChecked,
Wrap: wrapChecked,
Lock: lockChecked,
checkedImpl = impl[*checked[*sync.Mutex]]{
None: noneChecked[*sync.Mutex],
Wrap: wrapChecked[*sync.Mutex],
Lock: lockChecked[*sync.Mutex, *sync.Mutex],
}
uncheckedImpl = impl[unchecked]{
None: func() unchecked { return noneUnchecked },
Wrap: wrapUnchecked,
Lock: lockUnchecked,
uncheckedImpl = impl[unchecked[*sync.Mutex]]{
None: noneUnchecked[*sync.Mutex],
Wrap: wrapUnchecked[*sync.Mutex],
Lock: lockUnchecked[*sync.Mutex, *sync.Mutex],
}
)
@ -207,7 +207,7 @@ func TestUnlockParentFirst_Checked(t *testing.T) {
func TestUnlockTwice_Checked(t *testing.T) {
impl := checkedImpl
doTest := func(t *testing.T, ctx *checked) {
doTest := func(t *testing.T, ctx *checked[*sync.Mutex]) {
ctx.Unlock() // unlocks mu
wantPanic(t, ctx.Unlock) // panics since mu is already unlocked
}

View File

@ -13,18 +13,18 @@ import (
"sync"
)
type Context struct {
unchecked
type Context[T sync.Locker] struct {
unchecked[T]
}
func None() Context {
return Context{noneUnchecked}
func None[T sync.Locker]() Context[T] {
return Context[T]{noneUnchecked[T]()}
}
func Wrap(parent context.Context) Context {
return Context{wrapUnchecked(parent)}
func Wrap[T sync.Locker](parent context.Context) Context[T] {
return Context[T]{wrapUnchecked[T](parent)}
}
func Lock(parent Context, mu *sync.Mutex) Context {
return Context{lockUnchecked(parent.unchecked, mu)}
func Lock[T, P sync.Locker](parent Context[P], mu T) Context[T] {
return Context[T]{lockUnchecked(parent.unchecked, mu)}
}

View File

@ -16,31 +16,31 @@ type Resource struct {
foo, bar string
}
func (r *Resource) GetFoo(ctx ctxlock.Context) string {
func (r *Resource) GetFoo(ctx ctxlock.Context[*sync.Mutex]) string {
defer ctxlock.Lock(ctx, &r.mu).Unlock() // Lock the mutex if not already held.
syncs.AssertLocked(&r.mu) // Panics if mu is still unlocked.
return r.foo
}
func (r *Resource) SetFoo(ctx ctxlock.Context, foo string) {
func (r *Resource) SetFoo(ctx ctxlock.Context[*sync.Mutex], foo string) {
defer ctxlock.Lock(ctx, &r.mu).Unlock()
syncs.AssertLocked(&r.mu)
r.foo = foo
}
func (r *Resource) GetBar(ctx ctxlock.Context) string {
func (r *Resource) GetBar(ctx ctxlock.Context[*sync.Mutex]) string {
defer ctxlock.Lock(ctx, &r.mu).Unlock()
syncs.AssertLocked(&r.mu)
return r.bar
}
func (r *Resource) SetBar(ctx ctxlock.Context, bar string) {
func (r *Resource) SetBar(ctx ctxlock.Context[*sync.Mutex], bar string) {
defer ctxlock.Lock(ctx, &r.mu).Unlock()
syncs.AssertLocked(&r.mu)
r.bar = bar
}
func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) {
func (r *Resource) WithLock(ctx ctxlock.Context[*sync.Mutex], f func(ctx ctxlock.Context[*sync.Mutex])) {
// Lock the mutex if not already held, and get a new context.
ctx = ctxlock.Lock(ctx, &r.mu)
defer ctx.Unlock()
@ -50,27 +50,27 @@ func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) {
func ExampleContext() {
var r Resource
r.SetFoo(ctxlock.None(), "foo")
r.SetBar(ctxlock.None(), "bar")
r.WithLock(ctxlock.None(), func(ctx ctxlock.Context) {
r.SetFoo(ctxlock.None[*sync.Mutex](), "foo")
r.SetBar(ctxlock.None[*sync.Mutex](), "bar")
r.WithLock(ctxlock.None[*sync.Mutex](), func(ctx ctxlock.Context[*sync.Mutex]) {
// This callback is invoked with the Resource's lock held,
// and the ctx tracks carries the lock state. This means we can safely call
// other methods on the Resource using ctx without causing a deadlock.
r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx))
})
fmt.Println(r.GetFoo(ctxlock.None()))
fmt.Println(r.GetFoo(ctxlock.None[*sync.Mutex]()))
// Output: foobar
}
func ExampleContext_twoResources() {
var r1, r2 Resource
r1.SetFoo(ctxlock.None(), "foo")
r2.SetBar(ctxlock.None(), "bar")
r1.WithLock(ctxlock.None(), func(ctx ctxlock.Context) {
r1.SetFoo(ctxlock.None[*sync.Mutex](), "foo")
r2.SetBar(ctxlock.None[*sync.Mutex](), "bar")
r1.WithLock(ctxlock.None[*sync.Mutex](), func(ctx ctxlock.Context[*sync.Mutex]) {
// Here, r1's lock is held, but r2's lock is not.
// So r2 will be locked when we call r2.SetBar(ctx).
r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx))
})
fmt.Println(r1.GetFoo(ctxlock.None()))
fmt.Println(r1.GetFoo(ctxlock.None[*sync.Mutex]()))
// Output: foobar
}