mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-29 15:23:45 +00:00
util/ctxlock: make ctxlock.Context generic
Updates #12614 Updates #15824 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
f605a99e0b
commit
968e921deb
@ -17,60 +17,62 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
noneCtx = context.Background()
|
noneCtx = context.Background()
|
||||||
noneUnchecked = unchecked{noneCtx, nil}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type lockerKey struct{ *sync.Mutex }
|
type lockerKey[T any] struct{ key T }
|
||||||
|
|
||||||
func lockerKeyOf(mu *sync.Mutex) lockerKey {
|
func lockerKeyOf[T sync.Locker](mu T) lockerKey[T] {
|
||||||
return lockerKey{mu}
|
return lockerKey[T]{key: mu}
|
||||||
}
|
}
|
||||||
|
|
||||||
// checked is an implementation of [Context] that performs runtime checks
|
// checked is an implementation of [Context] that performs runtime checks
|
||||||
// to ensure that the context is used correctly.
|
// to ensure that the context is used correctly.
|
||||||
type checked struct {
|
type checked[T sync.Locker] struct {
|
||||||
context.Context // nil after [checked.Unlock] is called
|
context.Context // nil after [checked.Unlock] is called
|
||||||
mu *sync.Mutex // nil if the context does not track a mutex lock state
|
mu T // nil if the context does not track a mutex lock state
|
||||||
parent *checked // nil if the context owns the lock
|
parent *checked[T] // nil if the context owns the lock
|
||||||
}
|
}
|
||||||
|
|
||||||
func noneChecked() *checked {
|
func noneChecked[T sync.Locker]() *checked[T] {
|
||||||
return &checked{noneCtx, nil, nil}
|
var zero T
|
||||||
|
return &checked[T]{noneCtx, zero, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapChecked(parent context.Context) *checked {
|
func wrapChecked[T sync.Locker](parent context.Context) *checked[T] {
|
||||||
return &checked{parent, nil, nil}
|
var zero T
|
||||||
|
return &checked[T]{parent, zero, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
func lockChecked(parent *checked, mu *sync.Mutex) *checked {
|
func lockChecked[T, P sync.Locker](parent *checked[P], mu T) *checked[T] {
|
||||||
checkLockArgs(parent, mu)
|
checkLockArgs(parent, mu)
|
||||||
if parentLockCtx, ok := parent.Value(lockerKeyOf(mu)).(*checked); ok {
|
if parentLockCtx, ok := parent.Value(lockerKeyOf(mu)).(*checked[T]); ok {
|
||||||
if appearsUnlocked(mu) {
|
if appearsUnlocked(mu) {
|
||||||
// The parent still owns the lock, but the mutex is unlocked.
|
// The parent still owns the lock, but the mutex is unlocked.
|
||||||
panic("mu is already unlocked")
|
panic("mu is already unlocked")
|
||||||
}
|
}
|
||||||
return &checked{parent, mu, parentLockCtx}
|
return &checked[T]{parent, mu, parentLockCtx}
|
||||||
}
|
}
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
return &checked{parent, mu, nil}
|
return &checked[T]{parent, mu, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *checked) Value(key any) any {
|
func (c *checked[T]) Value(key any) any {
|
||||||
if c.Context == nil {
|
if c.Context == nil {
|
||||||
panic("use of context after unlock")
|
panic("use of context after unlock")
|
||||||
}
|
}
|
||||||
if key == lockerKeyOf(c.mu) {
|
if key == any(lockerKeyOf(c.mu)) {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
return c.Context.Value(key)
|
return c.Context.Value(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *checked) Unlock() {
|
func (c *checked[T]) Unlock() {
|
||||||
|
var zero T
|
||||||
switch {
|
switch {
|
||||||
case c.Context == nil:
|
case c.Context == nil:
|
||||||
panic("already unlocked")
|
panic("already unlocked")
|
||||||
case c.mu == nil:
|
case any(c.mu) == any(zero):
|
||||||
// No-op; the context does not track a mutex lock state,
|
// No-op; the context does not track a mutex lock state,
|
||||||
// such as when it was created with [noneChecked] or [wrapChecked].
|
// such as when it was created with [noneChecked] or [wrapChecked].
|
||||||
case appearsUnlocked(c.mu):
|
case appearsUnlocked(c.mu):
|
||||||
@ -88,45 +90,54 @@ func (c *checked) Unlock() {
|
|||||||
func checkLockArgs[T interface {
|
func checkLockArgs[T interface {
|
||||||
context.Context
|
context.Context
|
||||||
comparable
|
comparable
|
||||||
}](parent T, mu *sync.Mutex) {
|
}, L sync.Locker](parent T, mu L) {
|
||||||
var zero T
|
var zero T
|
||||||
|
var nilLocker L
|
||||||
if parent == zero {
|
if parent == zero {
|
||||||
panic("nil parent context")
|
panic("nil parent context")
|
||||||
}
|
}
|
||||||
if mu == nil {
|
if any(mu) == any(nilLocker) {
|
||||||
panic("nil locker")
|
panic("nil locker")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// unchecked is an implementation of [Context] that trades runtime checks for performance.
|
// unchecked is an implementation of [Context] that trades runtime checks for performance.
|
||||||
type unchecked struct {
|
type unchecked[T sync.Locker] struct {
|
||||||
context.Context // always non-nil
|
context.Context // always non-nil
|
||||||
mu *sync.Mutex // non-nil if locked by this context
|
mu T // non-nil if locked by this context
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapUnchecked(parent context.Context) unchecked {
|
func noneUnchecked[T sync.Locker]() unchecked[T] {
|
||||||
return unchecked{parent, nil}
|
var zero T
|
||||||
|
return unchecked[T]{noneCtx, zero}
|
||||||
}
|
}
|
||||||
|
|
||||||
func lockUnchecked(parent unchecked, mu *sync.Mutex) unchecked {
|
func wrapUnchecked[T sync.Locker](parent context.Context) unchecked[T] {
|
||||||
checkLockArgs(parent, mu) // this is cheap, so we do it even in the unchecked case
|
var zero T
|
||||||
|
return unchecked[T]{parent, zero}
|
||||||
|
}
|
||||||
|
|
||||||
|
func lockUnchecked[T, P sync.Locker](parent unchecked[P], mu T) unchecked[T] {
|
||||||
|
checkLockArgs(parent.Context, mu) // this is cheap, so we do it even in the unchecked case
|
||||||
if parent.Value(lockerKeyOf(mu)) == nil {
|
if parent.Value(lockerKeyOf(mu)) == nil {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
} else {
|
} else {
|
||||||
mu = nil // already locked by a parent/ancestor
|
var zero T
|
||||||
|
mu = zero // already locked by a parent/ancestor
|
||||||
}
|
}
|
||||||
return unchecked{parent.Context, mu}
|
return unchecked[T]{parent.Context, mu}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c unchecked) Value(key any) any {
|
func (c unchecked[T]) Value(key any) any {
|
||||||
if key == lockerKeyOf(c.mu) {
|
if any(key) == any(lockerKeyOf(c.mu)) {
|
||||||
return key
|
return key
|
||||||
}
|
}
|
||||||
return c.Context.Value(key)
|
return c.Context.Value(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c unchecked) Unlock() {
|
func (c unchecked[T]) Unlock() {
|
||||||
if c.mu != nil {
|
var zero T
|
||||||
|
if any(c.mu) != any(zero) {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,24 +17,24 @@ import (
|
|||||||
// Calling [Context.Unlock] on a [Context] unlocks the mutex locked by the context, if any.
|
// Calling [Context.Unlock] on a [Context] unlocks the mutex locked by the context, if any.
|
||||||
// It is a runtime error to call [Context.Unlock] more than once,
|
// It is a runtime error to call [Context.Unlock] more than once,
|
||||||
// or use a [Context] after calling [Context.Unlock].
|
// or use a [Context] after calling [Context.Unlock].
|
||||||
type Context struct {
|
type Context[T sync.Locker] struct {
|
||||||
*checked
|
*checked[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
// None returns a [Context] that carries no mutex lock state and an empty [context.Context].
|
// None returns a [Context] that carries no mutex lock state and an empty [context.Context].
|
||||||
//
|
//
|
||||||
// It is typically used by top-level callers that do not have a parent context to pass in,
|
// It is typically used by top-level callers that do not have a parent context to pass in,
|
||||||
// and is a shorthand for [Context]([context.Background]).
|
// and is a shorthand for [Context]([context.Background]).
|
||||||
func None() Context {
|
func None[T sync.Locker]() Context[T] {
|
||||||
return Context{noneChecked()}
|
return Context[T]{noneChecked[T]()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap returns a derived [Context] that wraps the provided [context.Context].
|
// Wrap returns a derived [Context] that wraps the provided [context.Context].
|
||||||
//
|
//
|
||||||
// It is typically used by callers that already have a [context.Context],
|
// It is typically used by callers that already have a [context.Context],
|
||||||
// which may or may not be a [Context] tracking a mutex lock state.
|
// which may or may not be a [Context] tracking a mutex lock state.
|
||||||
func Wrap(parent context.Context) Context {
|
func Wrap[T sync.Locker](parent context.Context) Context[T] {
|
||||||
return Context{wrapChecked(parent)}
|
return Context[T]{wrapChecked[T](parent)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lock returns a derived [Context] that wraps the provided [context.Context]
|
// Lock returns a derived [Context] that wraps the provided [context.Context]
|
||||||
@ -43,6 +43,6 @@ func Wrap(parent context.Context) Context {
|
|||||||
// It locks the mutex unless it is already held by the parent or an ancestor [Context].
|
// It locks the mutex unless it is already held by the parent or an ancestor [Context].
|
||||||
// It is a runtime error to pass a nil mutex or to unlock the parent context
|
// It is a runtime error to pass a nil mutex or to unlock the parent context
|
||||||
// before the returned one.
|
// before the returned one.
|
||||||
func Lock(parent Context, mu *sync.Mutex) Context {
|
func Lock[T, P sync.Locker](parent Context[P], mu T) Context[T] {
|
||||||
return Context{lockChecked(parent.checked, mu)}
|
return Context[T]{lockChecked(parent.checked, mu)}
|
||||||
}
|
}
|
||||||
|
@ -23,20 +23,20 @@ type impl[T ctx] struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
exportedImpl = impl[Context]{
|
exportedImpl = impl[Context[*sync.Mutex]]{
|
||||||
None: None,
|
None: None[*sync.Mutex],
|
||||||
Wrap: Wrap,
|
Wrap: Wrap[*sync.Mutex],
|
||||||
Lock: Lock,
|
Lock: Lock[*sync.Mutex, *sync.Mutex],
|
||||||
}
|
}
|
||||||
checkedImpl = impl[*checked]{
|
checkedImpl = impl[*checked[*sync.Mutex]]{
|
||||||
None: noneChecked,
|
None: noneChecked[*sync.Mutex],
|
||||||
Wrap: wrapChecked,
|
Wrap: wrapChecked[*sync.Mutex],
|
||||||
Lock: lockChecked,
|
Lock: lockChecked[*sync.Mutex, *sync.Mutex],
|
||||||
}
|
}
|
||||||
uncheckedImpl = impl[unchecked]{
|
uncheckedImpl = impl[unchecked[*sync.Mutex]]{
|
||||||
None: func() unchecked { return noneUnchecked },
|
None: noneUnchecked[*sync.Mutex],
|
||||||
Wrap: wrapUnchecked,
|
Wrap: wrapUnchecked[*sync.Mutex],
|
||||||
Lock: lockUnchecked,
|
Lock: lockUnchecked[*sync.Mutex, *sync.Mutex],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -207,7 +207,7 @@ func TestUnlockParentFirst_Checked(t *testing.T) {
|
|||||||
func TestUnlockTwice_Checked(t *testing.T) {
|
func TestUnlockTwice_Checked(t *testing.T) {
|
||||||
impl := checkedImpl
|
impl := checkedImpl
|
||||||
|
|
||||||
doTest := func(t *testing.T, ctx *checked) {
|
doTest := func(t *testing.T, ctx *checked[*sync.Mutex]) {
|
||||||
ctx.Unlock() // unlocks mu
|
ctx.Unlock() // unlocks mu
|
||||||
wantPanic(t, ctx.Unlock) // panics since mu is already unlocked
|
wantPanic(t, ctx.Unlock) // panics since mu is already unlocked
|
||||||
}
|
}
|
||||||
|
@ -13,18 +13,18 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Context struct {
|
type Context[T sync.Locker] struct {
|
||||||
unchecked
|
unchecked[T]
|
||||||
}
|
}
|
||||||
|
|
||||||
func None() Context {
|
func None[T sync.Locker]() Context[T] {
|
||||||
return Context{noneUnchecked}
|
return Context[T]{noneUnchecked[T]()}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Wrap(parent context.Context) Context {
|
func Wrap[T sync.Locker](parent context.Context) Context[T] {
|
||||||
return Context{wrapUnchecked(parent)}
|
return Context[T]{wrapUnchecked[T](parent)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Lock(parent Context, mu *sync.Mutex) Context {
|
func Lock[T, P sync.Locker](parent Context[P], mu T) Context[T] {
|
||||||
return Context{lockUnchecked(parent.unchecked, mu)}
|
return Context[T]{lockUnchecked(parent.unchecked, mu)}
|
||||||
}
|
}
|
||||||
|
@ -16,31 +16,31 @@ type Resource struct {
|
|||||||
foo, bar string
|
foo, bar string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resource) GetFoo(ctx ctxlock.Context) string {
|
func (r *Resource) GetFoo(ctx ctxlock.Context[*sync.Mutex]) string {
|
||||||
defer ctxlock.Lock(ctx, &r.mu).Unlock() // Lock the mutex if not already held.
|
defer ctxlock.Lock(ctx, &r.mu).Unlock() // Lock the mutex if not already held.
|
||||||
syncs.AssertLocked(&r.mu) // Panics if mu is still unlocked.
|
syncs.AssertLocked(&r.mu) // Panics if mu is still unlocked.
|
||||||
return r.foo
|
return r.foo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resource) SetFoo(ctx ctxlock.Context, foo string) {
|
func (r *Resource) SetFoo(ctx ctxlock.Context[*sync.Mutex], foo string) {
|
||||||
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
||||||
syncs.AssertLocked(&r.mu)
|
syncs.AssertLocked(&r.mu)
|
||||||
r.foo = foo
|
r.foo = foo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resource) GetBar(ctx ctxlock.Context) string {
|
func (r *Resource) GetBar(ctx ctxlock.Context[*sync.Mutex]) string {
|
||||||
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
||||||
syncs.AssertLocked(&r.mu)
|
syncs.AssertLocked(&r.mu)
|
||||||
return r.bar
|
return r.bar
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resource) SetBar(ctx ctxlock.Context, bar string) {
|
func (r *Resource) SetBar(ctx ctxlock.Context[*sync.Mutex], bar string) {
|
||||||
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
defer ctxlock.Lock(ctx, &r.mu).Unlock()
|
||||||
syncs.AssertLocked(&r.mu)
|
syncs.AssertLocked(&r.mu)
|
||||||
r.bar = bar
|
r.bar = bar
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) {
|
func (r *Resource) WithLock(ctx ctxlock.Context[*sync.Mutex], f func(ctx ctxlock.Context[*sync.Mutex])) {
|
||||||
// Lock the mutex if not already held, and get a new context.
|
// Lock the mutex if not already held, and get a new context.
|
||||||
ctx = ctxlock.Lock(ctx, &r.mu)
|
ctx = ctxlock.Lock(ctx, &r.mu)
|
||||||
defer ctx.Unlock()
|
defer ctx.Unlock()
|
||||||
@ -50,27 +50,27 @@ func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) {
|
|||||||
|
|
||||||
func ExampleContext() {
|
func ExampleContext() {
|
||||||
var r Resource
|
var r Resource
|
||||||
r.SetFoo(ctxlock.None(), "foo")
|
r.SetFoo(ctxlock.None[*sync.Mutex](), "foo")
|
||||||
r.SetBar(ctxlock.None(), "bar")
|
r.SetBar(ctxlock.None[*sync.Mutex](), "bar")
|
||||||
r.WithLock(ctxlock.None(), func(ctx ctxlock.Context) {
|
r.WithLock(ctxlock.None[*sync.Mutex](), func(ctx ctxlock.Context[*sync.Mutex]) {
|
||||||
// This callback is invoked with the Resource's lock held,
|
// This callback is invoked with the Resource's lock held,
|
||||||
// and the ctx tracks carries the lock state. This means we can safely call
|
// and the ctx tracks carries the lock state. This means we can safely call
|
||||||
// other methods on the Resource using ctx without causing a deadlock.
|
// other methods on the Resource using ctx without causing a deadlock.
|
||||||
r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx))
|
r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx))
|
||||||
})
|
})
|
||||||
fmt.Println(r.GetFoo(ctxlock.None()))
|
fmt.Println(r.GetFoo(ctxlock.None[*sync.Mutex]()))
|
||||||
// Output: foobar
|
// Output: foobar
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleContext_twoResources() {
|
func ExampleContext_twoResources() {
|
||||||
var r1, r2 Resource
|
var r1, r2 Resource
|
||||||
r1.SetFoo(ctxlock.None(), "foo")
|
r1.SetFoo(ctxlock.None[*sync.Mutex](), "foo")
|
||||||
r2.SetBar(ctxlock.None(), "bar")
|
r2.SetBar(ctxlock.None[*sync.Mutex](), "bar")
|
||||||
r1.WithLock(ctxlock.None(), func(ctx ctxlock.Context) {
|
r1.WithLock(ctxlock.None[*sync.Mutex](), func(ctx ctxlock.Context[*sync.Mutex]) {
|
||||||
// Here, r1's lock is held, but r2's lock is not.
|
// Here, r1's lock is held, but r2's lock is not.
|
||||||
// So r2 will be locked when we call r2.SetBar(ctx).
|
// So r2 will be locked when we call r2.SetBar(ctx).
|
||||||
r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx))
|
r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx))
|
||||||
})
|
})
|
||||||
fmt.Println(r1.GetFoo(ctxlock.None()))
|
fmt.Println(r1.GetFoo(ctxlock.None[*sync.Mutex]()))
|
||||||
// Output: foobar
|
// Output: foobar
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user