diff --git a/util/ctxlock/ctx.go b/util/ctxlock/ctx.go new file mode 100644 index 000000000..fa224322b --- /dev/null +++ b/util/ctxlock/ctx.go @@ -0,0 +1,157 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ctxlock provides a [context.Context] implementation that carries mutex lock state +// and enables reentrant locking. It offers two implementations: checked and unchecked. +// The checked implementation performs runtime validation to ensure that: +// - a parent context is not unlocked before its child, +// - a context is only unlocked once, and +// - a context is not used after being unlocked. +// The unchecked implementation skips these checks for improved performance. +// It defaults to the checked implementation unless the ts_omit_ctxlock_checks build tag is set. +package ctxlock + +import ( + "context" + "fmt" + "sync" +) + +var ( + noneCtx = context.Background() + noneUnchecked = unchecked{noneCtx, nil} +) + +type ctxKey struct{ *sync.Mutex } + +func ctxKeyOf(mu *sync.Mutex) ctxKey { + return ctxKey{mu} +} + +// checked is an implementation of [Context] that performs runtime checks +// to ensure that the context is used correctly. +type checked struct { + context.Context // nil after [checked.Unlock] is called + mu *sync.Mutex // nil if the context does not track a mutex lock state + parent *checked // nil if the context owns the lock +} + +func noneChecked() *checked { + return &checked{noneCtx, nil, nil} +} + +func wrapChecked(parent context.Context) *checked { + return &checked{parent, nil, nil} +} + +func lockChecked(parent *checked, mu *sync.Mutex) *checked { + checkLockArgs(parent, mu) + if parentLockCtx, ok := parent.Value(ctxKeyOf(mu)).(*checked); ok { + if appearsUnlocked(mu) { + // The parent still owns the lock, but the mutex is unlocked. + panic("mu is spuriously unlocked") + } + return &checked{parent, mu, parentLockCtx} + } + mu.Lock() + return &checked{parent, mu, nil} +} + +func (c *checked) Value(key any) any { + if c.Context == nil { + panic("use of context after unlock") + } + if key == ctxKeyOf(c.mu) { + return c + } + return c.Context.Value(key) +} + +func (c *checked) Unlock() { + switch { + case c.Context == nil: + panic("already unlocked") + case c.mu == nil: + // No-op; the context does not track a mutex lock state, + // such as when it was created with [noneChecked] or [wrapChecked]. + case c.parent == nil: + // We own the lock; let's unlock it. + // This panics if the mutex is already unlocked. + c.mu.Unlock() + case c.parent.Context == nil: + // The parent context is already unlocked. + // The mutex may or may not be locked; + // something else may have already locked it. + panic("parent already unlocked") + case appearsUnlocked(c.mu): + // The mutex itself is unlocked, + // even though the parent context is still locked. + // It may be unlocked by an ancestor context + // or by something else entirely. + panic("mutex is not locked") + default: + // No-op; a parent or ancestor will handle unlocking. + } + c.Context = nil +} + +func checkLockArgs[T interface { + context.Context + comparable +}](parent T, mu *sync.Mutex) { + var zero T + if parent == zero { + panic("nil parent context") + } + if mu == nil { + panic(fmt.Sprintf("nil %T", mu)) + } +} + +// unchecked is an implementation of [Context] that trades runtime checks for performance. +type unchecked struct { + context.Context // always non-nil + mu *sync.Mutex // non-nil if locked by this context +} + +func wrapUnchecked(parent context.Context) unchecked { + return unchecked{parent, nil} +} + +func lockUnchecked(parent unchecked, mu *sync.Mutex) unchecked { + checkLockArgs(parent, mu) // this is cheap, so we do it even in the unchecked case + if parent.Value(ctxKeyOf(mu)) == nil { + mu.Lock() + } else { + mu = nil // already locked by a parent/ancestor + } + return unchecked{parent.Context, mu} +} + +func (c unchecked) Value(key any) any { + if key == ctxKeyOf(c.mu) { + return key + } + return c.Context.Value(key) +} + +func (c unchecked) Unlock() { + if c.mu != nil { + c.mu.Unlock() + } +} + +type tryLocker interface { + TryLock() bool + Unlock() +} + +// appearsUnlocked reports whether m is unlocked. +// It may return a false negative if m does not have a TryLock method. +func appearsUnlocked[T sync.Locker](m T) bool { + if m, ok := any(m).(tryLocker); ok && m.TryLock() { + m.Unlock() + return true + } + return false +} diff --git a/util/ctxlock/ctx_checked.go b/util/ctxlock/ctx_checked.go new file mode 100644 index 000000000..ac6936f4b --- /dev/null +++ b/util/ctxlock/ctx_checked.go @@ -0,0 +1,48 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file exports default, unoptimized implementation of the [Context] that includes runtime checks. +// It is used unless the build tag ts_omit_ctxlock_checks is set. + +//go:build !ts_omit_ctxlock_checks + +package ctxlock + +import ( + "context" + "sync" +) + +// Context is a [context.Context] that can carry a [sync.Mutex] lock state. +// Calling [Context.Unlock] on a [Context] unlocks the mutex locked by the context, if any. +// It is a runtime error to call [Context.Unlock] more than once, +// or use a [Context] after calling [Context.Unlock]. +type Context struct { + *checked +} + +// None returns a [Context] that carries no mutex lock state and an empty [context.Context]. +// +// It is typically used by top-level callers that do not have a parent context to pass in, +// and is a shorthand for [Context]([context.Background]). +func None() Context { + return Context{noneChecked()} +} + +// Wrap returns a derived [Context] that wraps the provided [context.Context]. +// +// It is typically used by callers that already have a [context.Context], +// which may or may not be a [Context] tracking a mutex lock state. +func Wrap(parent context.Context) Context { + return Context{wrapChecked(parent)} +} + +// Lock returns a derived [Context] that wraps the provided [context.Context] +// and carries the mutex lock state. +// +// It locks the mutex unless it is already held by the parent or an ancestor [Context]. +// It is a runtime error to pass a nil mutex or to unlock the parent context +// before the returned one. +func Lock(parent Context, mu *sync.Mutex) Context { + return Context{lockChecked(parent.checked, mu)} +} diff --git a/util/ctxlock/ctx_test.go b/util/ctxlock/ctx_test.go new file mode 100644 index 000000000..ff418a139 --- /dev/null +++ b/util/ctxlock/ctx_test.go @@ -0,0 +1,368 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ctxlock + +import ( + "context" + "sync" + "testing" + + "tailscale.com/util/ctxkey" +) + +type ctx interface { + context.Context + Unlock() +} + +type impl[T ctx] struct { + None func() T + Wrap func(context.Context) T + Lock func(T, *sync.Mutex) T +} + +var ( + exportedImpl = impl[Context]{ + None: None, + Wrap: Wrap, + Lock: Lock, + } + checkedImpl = impl[*checked]{ + None: noneChecked, + Wrap: wrapChecked, + Lock: lockChecked, + } + uncheckedImpl = impl[unchecked]{ + None: func() unchecked { return noneUnchecked }, + Wrap: wrapUnchecked, + Lock: lockUnchecked, + } +) + +func BenchmarkReentrance(b *testing.B) { + var mu sync.Mutex + + b.Run("Exported", func(b *testing.B) { + benchmarkReentrance(b, exportedImpl) + }) + b.Run("Checked", func(b *testing.B) { + benchmarkReentrance(b, checkedImpl) + }) + b.Run("Unchecked", func(b *testing.B) { + benchmarkReentrance(b, uncheckedImpl) + }) + b.Run("Reference", func(b *testing.B) { + for b.Loop() { + mu.Lock() + func(mu *sync.Mutex) { + if mu.TryLock() { + mu.Unlock() + } + }(&mu) + mu.Unlock() + } + }) +} + +func benchmarkReentrance[T ctx](b *testing.B, impl impl[T]) { + var mu sync.Mutex + for b.Loop() { + parent := impl.Lock(impl.None(), &mu) + func(ctx T) { + child := impl.Lock(ctx, &mu) + child.Unlock() + }(parent) + parent.Unlock() + } +} + +func TestHappyPath(t *testing.T) { + t.Run("Exported", func(t *testing.T) { + testHappyPath(t, exportedImpl) + }) + + t.Run("Checked", func(t *testing.T) { + testHappyPath(t, checkedImpl) + }) + + t.Run("Unchecked", func(t *testing.T) { + testHappyPath(t, uncheckedImpl) + }) +} + +func testHappyPath[T ctx](t *testing.T, impl impl[T]) { + var mu sync.Mutex + parent := impl.Lock(impl.None(), &mu) + wantLocked(t, &mu) // mu is locked by parent + + child := impl.Lock(parent, &mu) + wantLocked(t, &mu) // mu is still locked by parent + + var mu2 sync.Mutex + context2 := impl.Lock(child, &mu2) + wantLocked(t, &mu2) // mu2 is locked by context2 + context2.Unlock() // unlocks mu2 + wantUnlocked(t, &mu2) // mu2 is now unlocked + + child.Unlock() // noop + wantLocked(t, &mu) // mu is still locked by parent + + parent.Unlock() // unlocks mu + wantUnlocked(t, &mu) // mu is now unlocked +} + +func TestWrappedLockContext(t *testing.T) { + t.Run("Exported", func(t *testing.T) { + testWrappedLockContext(t, exportedImpl) + }) + + t.Run("Checked", func(t *testing.T) { + testWrappedLockContext(t, checkedImpl) + }) + + t.Run("Unchecked", func(t *testing.T) { + testWrappedLockContext(t, uncheckedImpl) + }) +} + +func testWrappedLockContext[T ctx](t *testing.T, impl impl[T]) { + wantValue := "value" + key := ctxkey.New("key", "") + ctxWithValue := key.WithValue(context.Background(), wantValue) + root := impl.Wrap(ctxWithValue) + + var mu sync.Mutex + parent := impl.Lock(root, &mu) + wantLocked(t, &mu) // mu is locked by parent + + // Wrap the parent context as if it were a regular [context.Context], + // then create a child context from it. + // The child should still recognize the parent as the mutex owner, + // and not panic or deadlock attempting to lock it again. + wrapped := impl.Wrap(parent) + child := impl.Lock(wrapped, &mu) + + // We should be able to access the value set in the root context. + if gotValue := key.Value(child); gotValue != wantValue { + t.Errorf("key.Value() = %s; want %s", gotValue, wantValue) + } + + child.Unlock() // no-op; mu is owned by parent + wantLocked(t, &mu) // mu is still locked by parent + + wrapped.Unlock() // no-op; mu is owned by parent + wantLocked(t, &mu) // mu is still locked by parent + + parent.Unlock() // unlocks mu + wantUnlocked(t, &mu) // mu is now unlocked +} + +func TestNilContextAndMutex(t *testing.T) { + t.Run("Exported", func(t *testing.T) { + testNilContextAndMutex(t, exportedImpl) + }) + + t.Run("Checked", func(t *testing.T) { + testNilContextAndMutex(t, checkedImpl) + }) + + t.Run("Unchecked", func(t *testing.T) { + testNilContextAndMutex(t, uncheckedImpl) + }) +} + +func testNilContextAndMutex[T ctx](t *testing.T, impl impl[T]) { + t.Run("NilContext", func(t *testing.T) { + var zero T + wantPanic(t, "nil parent context", func() { impl.Lock(zero, &sync.Mutex{}) }) + }) + t.Run("NilMutex", func(t *testing.T) { + wantPanic(t, "nil *sync.Mutex", func() { impl.Lock(impl.None(), nil) }) + }) +} + +func TestUseUnlockedParent_Checked(t *testing.T) { + impl := checkedImpl + + var mu sync.Mutex + parent := impl.Lock(impl.None(), &mu) + parent.Unlock() // unlocks mu + wantUnlocked(t, &mu) // mu is now unlocked + wantPanic(t, "use of context after unlock", func() { impl.Lock(parent, &mu) }) +} + +func TestUseUnlockedMutex_Checked(t *testing.T) { + impl := checkedImpl + + var mu sync.Mutex + parent := impl.Lock(impl.None(), &mu) + mu.Unlock() // unlock mu directly without unlocking parent + wantPanic(t, "mu is spuriously unlocked", func() { impl.Lock(parent, &mu) }) +} + +func TestUnlockParentFirst_Checked(t *testing.T) { + impl := checkedImpl + + var mu sync.Mutex + parent := impl.Lock(impl.None(), &mu) + child := impl.Lock(parent, &mu) + + parent.Unlock() // unlocks mu + wantUnlocked(t, &mu) // mu is now unlocked + wantPanic(t, "parent already unlocked", child.Unlock) +} + +func TestUnlockTwice_Checked(t *testing.T) { + impl := checkedImpl + + unlockTwice := func(t *testing.T, ctx *checked) { + ctx.Unlock() // unlocks mu + wantPanic(t, "already unlocked", ctx.Unlock) + } + + t.Run("None", func(t *testing.T) { + unlockTwice(t, impl.None()) + }) + t.Run("Wrapped", func(t *testing.T) { + unlockTwice(t, impl.Wrap(context.Background())) + }) + t.Run("Locked", func(t *testing.T) { + var mu sync.Mutex + ctx := impl.Lock(impl.None(), &mu) + unlockTwice(t, ctx) + }) + t.Run("Locked/WithReloc", func(t *testing.T) { + var mu sync.Mutex + ctx := impl.Lock(impl.None(), &mu) + ctx.Unlock() // unlocks mu + mu.Lock() // re-locks mu, but not by the context + wantPanic(t, "already unlocked", ctx.Unlock) + }) + t.Run("Child", func(t *testing.T) { + var mu sync.Mutex + parent := impl.Lock(impl.None(), &mu) + defer parent.Unlock() + child := impl.Lock(parent, &mu) + unlockTwice(t, child) + }) + t.Run("Child/WithReloc", func(t *testing.T) { + var mu sync.Mutex + parent := impl.Lock(impl.None(), &mu) + child := impl.Lock(parent, &mu) + parent.Unlock() + mu.Lock() // re-locks mu, but not the parent context + wantPanic(t, "parent already unlocked", child.Unlock) + }) + t.Run("Child/WithManualUnlock", func(t *testing.T) { + var mu sync.Mutex + parent := impl.Lock(impl.None(), &mu) + child := impl.Lock(parent, &mu) + mu.Unlock() // unlocks mu, but not the parent context + wantPanic(t, "mutex is not locked", child.Unlock) + }) + t.Run("Grandchild", func(t *testing.T) { + var mu sync.Mutex + parent := impl.Lock(impl.None(), &mu) + defer parent.Unlock() + child := impl.Lock(parent, &mu) + defer child.Unlock() + grandchild := impl.Lock(child, &mu) + unlockTwice(t, grandchild) + }) +} + +func TestUseUnlocked_Checked(t *testing.T) { + impl := checkedImpl + + var mu sync.Mutex + ctx := lockChecked(impl.None(), &mu) + ctx.Unlock() + + // All of these should panic since the context is already unlocked. + wantPanic(t, "", func() { ctx.Deadline() }) + wantPanic(t, "", func() { ctx.Done() }) + wantPanic(t, "", func() { ctx.Err() }) + wantPanic(t, "", func() { ctx.Unlock() }) + wantPanic(t, "", func() { ctx.Value("key") }) +} + +func TestUseNoneContext(t *testing.T) { + t.Run("Exported", func(t *testing.T) { + testUseEmptyContext(t, exportedImpl.None, exportedImpl) + }) + t.Run("Checked", func(t *testing.T) { + testUseEmptyContext(t, checkedImpl.None, checkedImpl) + }) + t.Run("Unchecked", func(t *testing.T) { + testUseEmptyContext(t, uncheckedImpl.None, uncheckedImpl) + }) +} + +func TestUseWrappedBackground(t *testing.T) { + t.Run("Exported", func(t *testing.T) { + testUseEmptyContext(t, getWrappedBackground(t, exportedImpl), exportedImpl) + }) + t.Run("Checked", func(t *testing.T) { + testUseEmptyContext(t, getWrappedBackground(t, checkedImpl), checkedImpl) + }) + t.Run("Unchecked", func(t *testing.T) { + testUseEmptyContext(t, getWrappedBackground(t, uncheckedImpl), uncheckedImpl) + }) +} + +func getWrappedBackground[T ctx](t *testing.T, impl impl[T]) func() T { + t.Helper() + return func() T { + return impl.Wrap(context.Background()) + } +} + +func testUseEmptyContext[T ctx](t *testing.T, getCtx func() T, impl impl[T]) { + // Using a None context must not panic or deadlock. + // It should also behave like [context.Background]. + for range 2 { + ctx := getCtx() + if gotDone := ctx.Done(); gotDone != nil { + t.Errorf("ctx.Done() = %v; want nil", gotDone) + } + if gotDeadline, ok := ctx.Deadline(); ok { + t.Errorf("ctx.Deadline() = %v; want !ok", gotDeadline) + } + if gotErr := ctx.Err(); gotErr != nil { + t.Errorf("ctx.Err() = %v; want nil", gotErr) + } + if gotValue := ctx.Value("test-key"); gotValue != nil { + t.Errorf("ctx.Value(test-key) = %v; want nil", gotValue) + } + ctx.Unlock() + } +} + +func wantPanic(t *testing.T, wantMsg string, fn func()) { + t.Helper() + defer func() { + if r := recover(); wantMsg != "" { + if gotMsg, ok := r.(string); !ok || gotMsg != wantMsg { + t.Errorf("panic: %v; want %q", r, wantMsg) + } + } + }() + fn() + t.Fatal("failed to panic") +} + +func wantLocked(t *testing.T, m *sync.Mutex) { + if m.TryLock() { + m.Unlock() + t.Fatal("mutex is not locked") + } +} + +func wantUnlocked(t *testing.T, m *sync.Mutex) { + t.Helper() + if !m.TryLock() { + t.Fatal("mutex is locked") + } + m.Unlock() +} diff --git a/util/ctxlock/ctx_unchecked.go b/util/ctxlock/ctx_unchecked.go new file mode 100644 index 000000000..fa5c06f77 --- /dev/null +++ b/util/ctxlock/ctx_unchecked.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file exports optimized implementation of the [Context] that omits runtime checks. +// It is used when the build tag ts_omit_ctxlock_checks is set. + +//go:build ts_omit_ctxlock_checks + +package ctxlock + +import ( + "context" + "sync" +) + +type Context struct { + unchecked +} + +func None() Context { + return Context{noneUnchecked} +} + +func Wrap(parent context.Context) Context { + return Context{wrapUnchecked(parent)} +} + +func Lock(parent Context, mu *sync.Mutex) Context { + return Context{lockUnchecked(parent.unchecked, mu)} +} diff --git a/util/ctxlock/doc_test.go b/util/ctxlock/doc_test.go new file mode 100644 index 000000000..69b0c6ebe --- /dev/null +++ b/util/ctxlock/doc_test.go @@ -0,0 +1,76 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ctxlock_test + +import ( + "fmt" + "sync" + + "tailscale.com/syncs" + "tailscale.com/util/ctxlock" +) + +type Resource struct { + mu sync.Mutex + foo, bar string +} + +func (r *Resource) GetFoo(ctx ctxlock.Context) string { + defer ctxlock.Lock(ctx, &r.mu).Unlock() // Lock the mutex if not already held. + syncs.AssertLocked(&r.mu) // Panic if mu is still unlocked. + return r.foo +} + +func (r *Resource) SetFoo(ctx ctxlock.Context, foo string) { + defer ctxlock.Lock(ctx, &r.mu).Unlock() + syncs.AssertLocked(&r.mu) + r.foo = foo +} + +func (r *Resource) GetBar(ctx ctxlock.Context) string { + defer ctxlock.Lock(ctx, &r.mu).Unlock() + syncs.AssertLocked(&r.mu) + return r.bar +} + +func (r *Resource) SetBar(ctx ctxlock.Context, bar string) { + defer ctxlock.Lock(ctx, &r.mu).Unlock() + syncs.AssertLocked(&r.mu) + r.bar = bar +} + +func (r *Resource) WithLock(ctx ctxlock.Context, f func(ctx ctxlock.Context)) { + // Lock the mutex if not already held, and get a new context. + ctx = ctxlock.Lock(ctx, &r.mu) + defer ctx.Unlock() + syncs.AssertLocked(&r.mu) + f(ctx) // Call the callback with the new context. +} + +func ExampleContext() { + var r Resource + r.SetFoo(ctxlock.None(), "foo") + r.SetBar(ctxlock.None(), "bar") + r.WithLock(ctxlock.None(), func(ctx ctxlock.Context) { + // This callback is invoked with r's lock held, + // and ctx carries the lock state. This means we can safely call + // other methods on r using ctx without causing a deadlock. + r.SetFoo(ctx, r.GetFoo(ctx)+r.GetBar(ctx)) + }) + fmt.Println(r.GetFoo(ctxlock.None())) + // Output: foobar +} + +func ExampleContext_twoResources() { + var r1, r2 Resource + r1.SetFoo(ctxlock.None(), "foo") + r2.SetBar(ctxlock.None(), "bar") + r1.WithLock(ctxlock.None(), func(ctx ctxlock.Context) { + // Here, r1's lock is held, but r2's lock is not. + // So r2 will be locked when we call r2.GetBar(ctx). + r1.SetFoo(ctx, r1.GetFoo(ctx)+r2.GetBar(ctx)) + }) + fmt.Println(r1.GetFoo(ctxlock.None())) + // Output: foobar +}