tailscale/syncs/shardvalue_test.go
James Tucker e8f1721147 syncs: add ShardedInt expvar.Var type
ShardedInt provides an int type expvar.Var that supports more efficient
writes at high frequencies (one order of magnitude on an M1 Max, much
more on NUMA systems).

There are two implementations of ShardValue, one that abuses sync.Pool
that will work on current public Go versions, and one that takes a
dependency on a runtime.TailscaleP function exposed in Tailscale's Go
fork. The sync.Pool variant has about 10x the throughput of a single
atomic integer on an M1 Max, and the runtime.TailscaleP variant is about
10x faster than the sync.Pool variant.

Neither variant has perfect distribution, or always perfectly avoids
cross-CPU sharing, as there is no locking or affinity to ensure that the
time of yield is on the same core as the time of core biasing, but in
the average case the distributions are enough to provide substantially
better performance.

See golang/go#18802 for a related upstream proposal.

Updates tailscale/go#109
Updates tailscale/corp#25450

Signed-off-by: James Tucker <james@tailscale.com>
2024-12-19 14:58:28 -08:00

120 lines
2.4 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syncs
import (
"math"
"runtime"
"sync"
"sync/atomic"
"testing"
"golang.org/x/sys/cpu"
)
// TestShardValue exercises ShardValue's One, All and Len methods and logs how
// evenly concurrent One calls end up spread across the shards.
func TestShardValue(t *testing.T) {
	// counter is the per-shard payload used by every subtest: an atomic
	// integer followed by cpu.CacheLinePad padding.
	type counter struct {
		atomic.Int64
		_ cpu.CacheLinePad
	}

	t.Run("One", func(t *testing.T) {
		sv := NewShardValue[counter]()
		sv.One(func(c *counter) { c.Store(10) })

		// The test expects a single One call to touch one shard only, so
		// the sum over every shard must equal the value that was stored.
		var sum int64
		for i := 0; i < len(sv.shards); i++ {
			sum += sv.shards[i].Load()
		}
		if sum != 10 {
			t.Errorf("got %v, want 10", sum)
		}
	})

	t.Run("All", func(t *testing.T) {
		sv := NewShardValue[counter]()
		n := len(sv.shards)
		for i := 0; i < n; i++ {
			sv.shards[i].Store(int64(i))
		}

		var sum int64
		sv.All(func(c *counter) bool {
			sum += c.Load()
			return true
		})

		// The shards hold 0..n-1, whose sum is n*(n-1)/2.
		if want := int64(n * (n - 1) / 2); sum != want {
			t.Errorf("got %v, want %v", sum, want)
		}
	})

	t.Run("Len", func(t *testing.T) {
		got := NewShardValue[counter]().Len()
		want := runtime.NumCPU()
		if got != want {
			t.Errorf("got %v, want %v", got, want)
		}
	})

	t.Run("distribution", func(t *testing.T) {
		const (
			goroutines = 1000
			iterations = 10000
		)
		sv := NewShardValue[counter]()
		increment := func(c *counter) { c.Add(1) }

		// Hammer One from many goroutines at once.
		var wg sync.WaitGroup
		for g := 0; g < goroutines; g++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				for n := 0; n < iterations; n++ {
					sv.One(increment)
				}
			}()
		}
		wg.Wait()

		var (
			sum      int64
			perShard []int64
		)
		t.Logf("distribution:")
		sv.All(func(c *counter) bool {
			// Every writer has finished, so a single Load is stable.
			count := c.Load()
			sum += count
			perShard = append(perShard, count)
			t.Logf("%d", count)
			return true
		})

		// Hard requirements: no increment may be lost, and All must visit
		// one shard per CPU.
		if want := int64(goroutines * iterations); sum != want {
			t.Errorf("got %v, want %v", sum, want)
		}
		if got, want := len(perShard), runtime.NumCPU(); got != want {
			t.Errorf("got %v, want %v", got, want)
		}

		// The balance checks below only log; an uneven spread never fails
		// the test.
		shardCount := int64(len(perShard))
		mean := sum / shardCount
		low, high := mean/10, mean*10
		for _, v := range perShard {
			if v >= low && v <= high {
				continue
			}
			t.Logf("distribution is very unbalanced: %v", perShard)
		}
		t.Logf("mean: %d", mean)

		// Population standard deviation, computed in integer arithmetic
		// apart from the square root itself.
		var sumSquares int64
		for _, v := range perShard {
			delta := v - mean
			sumSquares += delta * delta
		}
		stdev := int64(math.Sqrt(float64(sumSquares / shardCount)))
		t.Logf("stdev: %d", stdev)
		if stdev > mean/3 {
			t.Logf("standard deviation is too high: %v", stdev)
		}
	})
}