mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-29 15:23:45 +00:00
util/ctxlock: make ctxlock.Context generic
Updates #12614 Updates #15824 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
f605a99e0b
commit
968e921deb
@ -17,60 +17,62 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
noneCtx = context.Background()
|
||||
noneUnchecked = unchecked{noneCtx, nil}
|
||||
noneCtx = context.Background()
|
||||
)
|
||||
|
||||
type lockerKey struct{ *sync.Mutex }
|
||||
type lockerKey[T any] struct{ key T }
|
||||
|
||||
func lockerKeyOf(mu *sync.Mutex) lockerKey {
|
||||
return lockerKey{mu}
|
||||
func lockerKeyOf[T sync.Locker](mu T) lockerKey[T] {
|
||||
return lockerKey[T]{key: mu}
|
||||
}
|
||||
|
||||
// checked is an implementation of [Context] that performs runtime checks
|
||||
// to ensure that the context is used correctly.
|
||||
type checked struct {
|
||||
type checked[T sync.Locker] struct {
|
||||
context.Context // nil after [checked.Unlock] is called
|
||||
mu *sync.Mutex // nil if the context does not track a mutex lock state
|
||||
parent *checked // nil if the context owns the lock
|
||||
mu T // nil if the context does not track a mutex lock state
|
||||
parent *checked[T] // nil if the context owns the lock
|
||||
}
|
||||
|
||||
func noneChecked() *checked {
|
||||
return &checked{noneCtx, nil, nil}
|
||||
func noneChecked[T sync.Locker]() *checked[T] {
|
||||
var zero T
|
||||
return &checked[T]{noneCtx, zero, nil}
|
||||
}
|
||||
|
||||
func wrapChecked(parent context.Context) *checked {
|
||||
return &checked{parent, nil, nil}
|
||||
func wrapChecked[T sync.Locker](parent context.Context) *checked[T] {
|
||||
var zero T
|
||||
return &checked[T]{parent, zero, nil}
|
||||
}
|
||||
|
||||
func lockChecked(parent *checked, mu *sync.Mutex) *checked {
|
||||
func lockChecked[T, P sync.Locker](parent *checked[P], mu T) *checked[T] {
|
||||
checkLockArgs(parent, mu)
|
||||
if parentLockCtx, ok := parent.Value(lockerKeyOf(mu)).(*checked); ok {
|
||||
if parentLockCtx, ok := parent.Value(lockerKeyOf(mu)).(*checked[T]); ok {
|
||||
if appearsUnlocked(mu) {
|
||||
// The parent still owns the lock, but the mutex is unlocked.
|
||||
panic("mu is already unlocked")
|
||||
}
|
||||
return &checked{parent, mu, parentLockCtx}
|
||||
return &checked[T]{parent, mu, parentLockCtx}
|
||||
}
|
||||
mu.Lock()
|
||||
return &checked{parent, mu, nil}
|
||||
return &checked[T]{parent, mu, nil}
|
||||
}
|
||||
|
||||
func (c *checked) Value(key any) any {
|
||||
func (c *checked[T]) Value(key any) any {
|
||||
if c.Context == nil {
|
||||
panic("use of context after unlock")
|
||||
}
|
||||
if key == lockerKeyOf(c.mu) {
|
||||
if key == any(lockerKeyOf(c.mu)) {
|
||||
return c
|
||||
}
|
||||
return c.Context.Value(key)
|
||||
}
|
||||
|
||||
func (c *checked) Unlock() {
|
||||
func (c *checked[T]) Unlock() {
|
||||
var zero T
|
||||
switch {
|
||||
case c.Context == nil:
|
||||
panic("already unlocked")
|
||||
case c.mu == nil:
|
||||
case any(c.mu) == any(zero):
|
||||
// No-op; the context does not track a mutex lock state,
|
||||
// such as when it was created with [noneChecked] or [wrapChecked].
|
||||
case appearsUnlocked(c.mu):
|
||||
@ -88,45 +90,54 @@ func (c *checked) Unlock() {
|
||||
func checkLockArgs[T interface {
|
||||
context.Context
|
||||
comparable
|
||||
}](parent T, mu *sync.Mutex) {
|
||||
}, L sync.Locker](parent T, mu L) {
|
||||
var zero T
|
||||
var nilLocker L
|
||||
if parent == zero {
|
||||
panic("nil parent context")
|
||||
}
|
||||
if mu == nil {
|
||||
if any(mu) == any(nilLocker) {
|
||||
panic("nil locker")
|
||||
}
|
||||
}
|
||||
|
||||
// unchecked is an implementation of [Context] that trades runtime checks for performance.
|
||||
type unchecked struct {
|
||||
context.Context // always non-nil
|
||||
mu *sync.Mutex // non-nil if locked by this context
|
||||
type unchecked[T sync.Locker] struct {
|
||||
context.Context // always non-nil
|
||||
mu T // non-nil if locked by this context
|
||||
}
|
||||
|
||||
func wrapUnchecked(parent context.Context) unchecked {
|
||||
return unchecked{parent, nil}
|
||||
func noneUnchecked[T sync.Locker]() unchecked[T] {
|
||||
var zero T
|
||||
return unchecked[T]{noneCtx, zero}
|
||||
}
|
||||
|
||||
func lockUnchecked(parent unchecked, mu *sync.Mutex) unchecked {
|
||||
checkLockArgs(parent, mu) // this is cheap, so we do it even in the unchecked case
|
||||
func wrapUnchecked[T sync.Locker](parent context.Context) unchecked[T] {
|
||||
var zero T
|
||||
return unchecked[T]{parent, zero}
|
||||
}
|
||||
|
||||
func lockUnchecked[T, P sync.Locker](parent unchecked[P], mu T) unchecked[T] {
|
||||
checkLockArgs(parent.Context, mu) // this is cheap, so we do it even in the unchecked case
|
||||
if parent.Value(lockerKeyOf(mu)) == nil {
|
||||
mu.Lock()
|
||||
} else {
|
||||
mu = nil // already locked by a parent/ancestor
|
||||
var zero T
|
||||
mu = zero // already locked by a parent/ancestor
|
||||
}
|
||||
return unchecked{parent.Context, mu}
|
||||
return unchecked[T]{parent.Context, mu}
|
||||
}
|
||||
|
||||
func (c unchecked) Value(key any) any {
|
||||
if key == lockerKeyOf(c.mu) {
|
||||
func (c unchecked[T]) Value(key any) any {
|
||||
if any(key) == any(lockerKeyOf(c.mu)) {
|
||||
return key
|
||||
}
|
||||
return c.Context.Value(key)
|
||||
}
|
||||
|
||||
func (c unchecked) Unlock() {
|
||||
if c.mu != nil {
|
||||
func (c unchecked[T]) Unlock() {
|
||||
var zero T
|
||||
if any(c.mu) != any(zero) {
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
@ -17,24 +17,24 @@ import (
|
||||
// Calling [Context.Unlock] on a [Context] unlocks the mutex locked by the context, if any.
|
||||
// It is a runtime error to call [Context.Unlock] more than once,
|
||||
// or use a [Context] after calling [Context.Unlock].
|
||||
type Context struct {
|
||||
*checked
|
||||
type Context[T sync.Locker] struct {
|
||||
*checked[T]
|
||||
}
|
||||
|
||||
// None returns a [Context] that carries no mutex lock state and an empty [context.Context].
|
||||
//
|
||||
// It is typically used by top-level callers that do not have a parent context to pass in,
|
||||
// and is a shorthand for [Context]([context.Background]).
|
||||
func None() Context {
|
||||
return Context{noneChecked()}
|
||||
func None[T sync.Locker]() Context[T] {
|
||||
return Context[T]{noneChecked[T]()}
|
||||
}
|
||||
|
||||
// Wrap returns a derived [Context] that wraps the provided [context.Context].
|
||||
//
|
||||
// It is typically used by callers that already have a [context.Context],
|
||||
// which may or may not be a [Context] tracking a mutex lock state.
|
||||
func Wrap(parent context.Context) Context {
|
||||
return Context{wrapChecked(parent)}
|
||||
func Wrap[T sync.Locker](parent context.Context) Context[T] {
|
||||
return Context[T]{wrapChecked[T](parent)}
|
||||
}
|
||||
|
||||
// Lock returns a derived [Context] that wraps the provided [context.Context]
|
||||
@ -43,6 +43,6 @@ func Wrap(parent context.Context) Context {
|
||||
// It locks the mutex unless it is already held by the parent or an ancestor [Context].
|
||||
// It is a runtime error to pass a nil mutex or to unlock the parent context
|
||||
// before the returned one.
|
||||
func Lock(parent Context, mu *sync.Mutex) Context {
|
||||
return Context{lockChecked(parent.checked, mu)}
|
||||
func Lock[T, P sync.Locker](parent Context[P], mu T) Context[T] {
|
||||
return Context[T]{lockChecked(parent.checked, mu)}
|
||||
}
|
||||
|
@ -23,20 +23,20 @@ type impl[T ctx] struct {
|
||||
}
|
||||
|
||||
var (
|
||||
exportedImpl = impl[Context]{
|
||||
None: None,
|
||||
Wrap: Wrap,
|
||||
Lock: Lock,
|
||||
exportedImpl = impl[Context[*sync.Mutex]]{
|
||||
None: None[*sync.Mutex],
|
||||
Wrap: Wrap[*sync.Mutex],
|
||||
Lock: Lock[*sync.Mutex, *sync.Mutex],
|
||||
}
|
||||
checkedImpl = impl[*checked]{
|
||||
None: noneChecked,
|
||||
Wrap: wrapChecked,
|
||||
Lock: lockChecked,
|
||||
checkedImpl = impl[*checked[*sync.Mutex]]{
|
||||
None: noneChecked[*sync.Mutex],
|
||||
Wrap: wrapChecked[*sync.Mutex],
|
||||
Lock: lockChecked[*sync.Mutex, *sync.Mutex],
|
||||
}
|
||||
uncheckedImpl = impl[unchecked]{
|
||||
None: func() unchecked { return noneUnchecked },
|
||||
Wrap: wrapUnchecked,
|
||||
Lock: lockUnchecked,
|
||||
uncheckedImpl = impl[unchecked[*sync.Mutex]]{
|
||||
None: noneUnchecked[*sync.Mutex],
|
||||
Wrap: wrapUnchecked[*sync.Mutex],
|
||||
Lock: lockUnchecked[*sync.Mutex, *sync.Mutex],
|
||||
}
|
||||
)
|
||||
|
||||
@ -207,7 +207,7 @@ func TestUnlockParentFirst_Checked(t *testing.T) {
|
||||
func TestUnlockTwice_Checked(t *testing.T) {
|
||||
impl := checkedImpl
|
||||
|
||||
doTest := func(t *testing.T, ctx *checked) {
|
||||
doTest := func(t *testing.T, ctx *checked[*sync.Mutex]) {
|
||||
ctx.Unlock() // unlocks mu
|
||||
wantPanic(t, ctx.Unlock) // panics since mu is already unlocked
|
||||
}
|
||||
|
@ -13,18 +13,18 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
unchecked
|
||||
type Context[T sync.Locker] struct {
|
||||
unchecked[T]
|
||||
}
|
||||
|
||||
func None() Context {
|
||||
return Context{noneUnchecked}
|
||||
func None[T sync.Locker]() Context[T] {
|
||||
return Context[T]{noneUnchecked[T]()}
|
||||
}
|
||||
|
||||
func Wrap(parent context.Context) Context {
|
||||
return Context{wrapUnchecked(parent)}
|
||||
func Wrap[T sync.Locker](parent context.Context) Context[T] {
|
||||
return Context[T]{wrapUnchecked[T](parent)}
|
||||
}
|
||||
|
||||
func Lock(parent Context, mu *sync.Mutex) Context {
|
||||
return Context{lockUnchecked(parent.unchecked, mu)}
|
||||
func Lock[T, P sync.Locker](parent Context[P], mu T) Context[T] {
|
||||
return Context[T]{lockUnchecked(parent.unchecked, mu)}
|
||||
}
|
||||
|
@ -16,31 +16,31 @@ type Resource struct {
|
||||
foo, bar string
|
||||
}
|
||||
|
||||
func (r *Resource) GetFoo(ctx ctxlock.Context) string {
|
||||
func (r *Resource) GetFoo(ctx ctxlock.Context[*sync.Mutex]) string {
|
||||
defer ctxlock.Lock(ctx, &r.mu).Unlock() // Lock the mutex if not already held.
|
||||
syncs.AssertLocked(&r.mu) // Panics if mu is still unlocked.
|
||||
return r.foo
|
||||
}
|
||||
|
||||
func (r *Resource) SetFoo(ctx ctxlock.Context, foo string) {
|
||||
func (r *Resource) SetFoo(ctx ctxlock.Context[*sync.Mutex], foo string) {
|
||||
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
||||
syncs.AssertLocked(&r.mu)
|
||||
r.foo = foo
|
||||
}
|
||||
|
||||
func (r *Resource) GetBar(ctx ctxlock.Context) string {
|
||||
func (r *Resource) GetBar(ctx ctxlock.Context[*sync.Mutex]) string {
|
||||
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
||||
syncs.AssertLocked(&r.mu)
|
||||
return r.bar
|
||||
}
|
||||
|
||||
func (r *Resource) SetBar(ctx ctxlock.Context, bar string) {
|
||||
func (r *Resource) SetBar(ctx ctxlock.Context[*sync.Mutex], bar string) {
|
||||
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
||||
syncs.AssertLocked(&r.mu)
|
||||
r.bar = bar
|
||||
}
|
||||
|
||||
func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) {
|
||||
func (r *Resource) WithLock(ctx ctxlock.Context[*sync.Mutex], f func(ctx ctxlock.Context[*sync.Mutex])) {
|
||||
// Lock the mutex if not already held, and get a new context.
|
||||
ctx = ctxlock.Lock(ctx, &r.mu)
|
||||
defer ctx.Unlock()
|
||||
@ -50,27 +50,27 @@ func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) {
|
||||
|
||||
func ExampleContext() {
|
||||
var r Resource
|
||||
r.SetFoo(ctxlock.None(), "foo")
|
||||
r.SetBar(ctxlock.None(), "bar")
|
||||
r.WithLock(ctxlock.None(), func(ctx ctxlock.Context) {
|
||||
r.SetFoo(ctxlock.None[*sync.Mutex](), "foo")
|
||||
r.SetBar(ctxlock.None[*sync.Mutex](), "bar")
|
||||
r.WithLock(ctxlock.None[*sync.Mutex](), func(ctx ctxlock.Context[*sync.Mutex]) {
|
||||
// This callback is invoked with the Resource's lock held,
|
||||
// and the ctx tracks carries the lock state. This means we can safely call
|
||||
// other methods on the Resource using ctx without causing a deadlock.
|
||||
r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx))
|
||||
})
|
||||
fmt.Println(r.GetFoo(ctxlock.None()))
|
||||
fmt.Println(r.GetFoo(ctxlock.None[*sync.Mutex]()))
|
||||
// Output: foobar
|
||||
}
|
||||
|
||||
func ExampleContext_twoResources() {
|
||||
var r1, r2 Resource
|
||||
r1.SetFoo(ctxlock.None(), "foo")
|
||||
r2.SetBar(ctxlock.None(), "bar")
|
||||
r1.WithLock(ctxlock.None(), func(ctx ctxlock.Context) {
|
||||
r1.SetFoo(ctxlock.None[*sync.Mutex](), "foo")
|
||||
r2.SetBar(ctxlock.None[*sync.Mutex](), "bar")
|
||||
r1.WithLock(ctxlock.None[*sync.Mutex](), func(ctx ctxlock.Context[*sync.Mutex]) {
|
||||
// Here, r1's lock is held, but r2's lock is not.
|
||||
// So r2 will be locked when we call r2.SetBar(ctx).
|
||||
r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx))
|
||||
})
|
||||
fmt.Println(r1.GetFoo(ctxlock.None()))
|
||||
fmt.Println(r1.GetFoo(ctxlock.None[*sync.Mutex]()))
|
||||
// Output: foobar
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user