syncs: fix AtomicValue.CompareAndSwap (#16137)

Fix CompareAndSwap in the edge-case where
the underlying sync.AtomicValue is uninitialized
(i.e., Store was never called) and
the oldV is the zero value,
then perform CompareAndSwap with any(nil).

Also, document that T must be comparable.
This is a pre-existing restriction.

Fixes #16135

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
This commit is contained in:
Joe Tsai 2025-05-30 08:06:16 -10:00 committed by GitHub
parent 11e83f9da5
commit 84aa7ff3bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 2 deletions

View File

@ -67,12 +67,18 @@ func (v *AtomicValue[T]) Swap(x T) (old T) {
if oldV != nil {
return oldV.(wrappedValue[T]).v
}
return old
return old // zero value of T
}
// CompareAndSwap executes the compare-and-swap operation for the Value.
// It panics if T is not comparable.
func (v *AtomicValue[T]) CompareAndSwap(oldV, newV T) (swapped bool) {
return v.v.CompareAndSwap(wrappedValue[T]{oldV}, wrappedValue[T]{newV})
var zero T
return v.v.CompareAndSwap(wrappedValue[T]{oldV}, wrappedValue[T]{newV}) ||
// In the edge-case where [atomic.Value.Store] is uninitialized
// and trying to compare with the zero value of T,
// then compare-and-swap with the nil any value.
(any(oldV) == any(zero) && v.v.CompareAndSwap(any(nil), wrappedValue[T]{newV}))
}
// MutexValue is a value protected by a mutex.

View File

@ -64,6 +64,23 @@ func TestAtomicValue(t *testing.T) {
t.Fatalf("LoadOk = (%v, %v), want (nil, true)", got, gotOk)
}
}
{
c1, c2, c3 := make(chan struct{}), make(chan struct{}), make(chan struct{})
var v AtomicValue[chan struct{}]
if v.CompareAndSwap(c1, c2) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if v.CompareAndSwap(nil, c1) != true {
t.Fatalf("CompareAndSwap = false, want true")
}
if v.CompareAndSwap(c2, c3) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if v.CompareAndSwap(c1, c2) != true {
t.Fatalf("CompareAndSwap = false, want true")
}
}
}
func TestMutexValue(t *testing.T) {