diff --git a/go.toolchain.rev b/go.toolchain.rev
index 7be85deb6..e90440d41 100644
--- a/go.toolchain.rev
+++ b/go.toolchain.rev
@@ -1 +1 @@
-e005697288a8d2fadc87bb7c3e2c74778d08554a
+161c3b79ed91039e65eb148f2547dea6b91e2247
diff --git a/syncs/shardedint.go b/syncs/shardedint.go
new file mode 100644
index 000000000..28c4168d5
--- /dev/null
+++ b/syncs/shardedint.go
@@ -0,0 +1,69 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package syncs
+
+import (
+	"encoding/json"
+	"strconv"
+	"sync/atomic"
+
+	"golang.org/x/sys/cpu"
+)
+
+// ShardedInt provides a sharded atomic int64 value that optimizes high
+// frequency (MHz range and above) writes in highly parallel workloads.
+// The zero value is not safe for use; use [NewShardedInt].
+// ShardedInt implements the expvar.Var interface.
+type ShardedInt struct {
+	sv *ShardValue[intShard]
+}
+
+// NewShardedInt returns a new [ShardedInt].
+func NewShardedInt() *ShardedInt {
+	return &ShardedInt{
+		sv: NewShardValue[intShard](),
+	}
+}
+
+// Add adds delta to the value.
+func (m *ShardedInt) Add(delta int64) {
+	m.sv.One(func(v *intShard) {
+		v.Add(delta)
+	})
+}
+
+type intShard struct {
+	atomic.Int64
+	_ cpu.CacheLinePad // avoid false sharing of neighboring shards
+}
+
+// Value returns the current value.
+func (m *ShardedInt) Value() int64 {
+	var v int64
+	for s := range m.sv.All {
+		v += s.Load()
+	}
+	return v
+}
+
+// GetDistribution returns the current value in each shard.
+// This is intended for observability/debugging only.
+func (m *ShardedInt) GetDistribution() []int64 {
+	v := make([]int64, 0, m.sv.Len())
+	for s := range m.sv.All {
+		v = append(v, s.Load())
+	}
+	return v
+}
+
+// String implements the expvar.Var interface.
+func (m *ShardedInt) String() string {
+	v, _ := json.Marshal(m.Value())
+	return string(v)
+}
+
+// AppendText implements the encoding.TextAppender interface.
+func (m *ShardedInt) AppendText(b []byte) ([]byte, error) {
+	return strconv.AppendInt(b, m.Value(), 10), nil
+}
diff --git a/syncs/shardedint_test.go b/syncs/shardedint_test.go
new file mode 100644
index 000000000..d355a1540
--- /dev/null
+++ b/syncs/shardedint_test.go
@@ -0,0 +1,119 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package syncs
+
+import (
+	"expvar"
+	"sync"
+	"testing"
+
+	"tailscale.com/tstest"
+)
+
+var (
+	_ expvar.Var = (*ShardedInt)(nil)
+	// TODO(raggi): future go version:
+	// _ encoding.TextAppender = (*ShardedInt)(nil)
+)
+
+func BenchmarkShardedInt(b *testing.B) {
+	b.ReportAllocs()
+
+	b.Run("expvar", func(b *testing.B) {
+		var m expvar.Int
+		b.RunParallel(func(pb *testing.PB) {
+			for pb.Next() {
+				m.Add(1)
+			}
+		})
+	})
+
+	b.Run("sharded int", func(b *testing.B) {
+		m := NewShardedInt()
+		b.RunParallel(func(pb *testing.PB) {
+			for pb.Next() {
+				m.Add(1)
+			}
+		})
+	})
+}
+
+func TestShardedInt(t *testing.T) {
+	t.Run("basics", func(t *testing.T) {
+		m := NewShardedInt()
+		if got, want := m.Value(), int64(0); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+		m.Add(1)
+		if got, want := m.Value(), int64(1); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+		m.Add(2)
+		if got, want := m.Value(), int64(3); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+		m.Add(-1)
+		if got, want := m.Value(), int64(2); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+	})
+
+	t.Run("high concurrency", func(t *testing.T) {
+		m := NewShardedInt()
+		wg := sync.WaitGroup{}
+		numWorkers := 1000
+		numIncrements := 1000
+		wg.Add(numWorkers)
+		for i := 0; i < numWorkers; i++ {
+			go func() {
+				defer wg.Done()
+				for i := 0; i < numIncrements; i++ {
+					m.Add(1)
+				}
+			}()
+		}
+		wg.Wait()
+		if got, want := m.Value(), int64(numWorkers*numIncrements); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+		for i, shard := range m.GetDistribution() {
+			t.Logf("shard %d: %d", i, shard)
+		}
+	})
+
+	t.Run("encoding.TextAppender", func(t *testing.T) {
+		m := NewShardedInt()
+		m.Add(1)
+		b := make([]byte, 0, 10)
+		b, err := m.AppendText(b)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if got, want := string(b), "1"; got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+	})
+
+	t.Run("allocs", func(t *testing.T) {
+		m := NewShardedInt()
+		tstest.MinAllocsPerRun(t, 0, func() {
+			m.Add(1)
+			_ = m.Value()
+		})
+
+		// TODO(raggi): fix access to expvar's internal append-based
+		// interface; unfortunately it's not currently exposed for external
+		// use, so this will alloc when it escapes.
+		tstest.MinAllocsPerRun(t, 0, func() {
+			m.Add(1)
+			_ = m.String()
+		})
+
+		b := make([]byte, 0, 10)
+		tstest.MinAllocsPerRun(t, 0, func() {
+			m.Add(1)
+			m.AppendText(b)
+		})
+	})
+}
diff --git a/syncs/shardvalue.go b/syncs/shardvalue.go
new file mode 100644
index 000000000..b1474477c
--- /dev/null
+++ b/syncs/shardvalue.go
@@ -0,0 +1,36 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package syncs
+
+// TODO(raggi): this implementation is still imperfect, as it will still
+// result in cross-CPU sharing periodically. We really want a per-CPU shard
+// key, but the limitations of calling platform code make reaching for even
+// the getcpu vDSO very painful. See https://github.com/golang/go/issues/18802;
+// hopefully one day we can replace this with a primitive that falls out of
+// that work.
+
+// ShardValue contains a value sharded over a set of shards.
+// In order to be useful, T should be aligned to cache lines.
+// Users must ensure that their usage of One and All is concurrency safe.
+// The zero value is not safe for use; use [NewShardValue].
+type ShardValue[T any] struct {
+	shards []T
+
+	//lint:ignore U1000 unused under tailscale_go builds.
+	pool shardValuePool
+}
+
+// Len returns the number of shards.
+func (sp *ShardValue[T]) Len() int {
+	return len(sp.shards)
+}
+
+// All yields a pointer to the value in each shard.
+func (sp *ShardValue[T]) All(yield func(*T) bool) {
+	for i := range sp.shards {
+		if !yield(&sp.shards[i]) {
+			return
+		}
+	}
+}
diff --git a/syncs/shardvalue_go.go b/syncs/shardvalue_go.go
new file mode 100644
index 000000000..9b9d252a7
--- /dev/null
+++ b/syncs/shardvalue_go.go
@@ -0,0 +1,36 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !tailscale_go
+
+package syncs
+
+import (
+	"runtime"
+	"sync"
+	"sync/atomic"
+)
+
+type shardValuePool struct {
+	atomic.Int64
+	sync.Pool
+}
+
+// NewShardValue constructs a new ShardValue[T] with a shard per CPU.
+func NewShardValue[T any]() *ShardValue[T] {
+	sp := &ShardValue[T]{
+		shards: make([]T, runtime.NumCPU()),
+	}
+	sp.pool.New = func() any {
+		i := sp.pool.Add(1) - 1
+		return &sp.shards[i%int64(len(sp.shards))]
+	}
+	return sp
+}
+
+// One yields a pointer to a single shard value with best-effort P-locality.
+func (sp *ShardValue[T]) One(yield func(*T)) {
+	v := sp.pool.Get().(*T)
+	yield(v)
+	sp.pool.Put(v)
+}
diff --git a/syncs/shardvalue_tailscale.go b/syncs/shardvalue_tailscale.go
new file mode 100644
index 000000000..8ef778ff3
--- /dev/null
+++ b/syncs/shardvalue_tailscale.go
@@ -0,0 +1,24 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// TODO(raggi): update build tag after toolchain update
+//go:build tailscale_go
+
+package syncs
+
+import (
+	"runtime"
+)
+
+//lint:ignore U1000 unused under tailscale_go builds.
+type shardValuePool struct{}
+
+// NewShardValue constructs a new ShardValue[T] with a shard per CPU.
+func NewShardValue[T any]() *ShardValue[T] {
+	return &ShardValue[T]{shards: make([]T, runtime.NumCPU())}
+}
+
+// One yields a pointer to a single shard value with best-effort P-locality.
+func (sp *ShardValue[T]) One(f func(*T)) {
+	f(&sp.shards[runtime.TailscaleCurrentP()%len(sp.shards)])
+}
diff --git a/syncs/shardvalue_test.go b/syncs/shardvalue_test.go
new file mode 100644
index 000000000..8f6ac6414
--- /dev/null
+++ b/syncs/shardvalue_test.go
@@ -0,0 +1,119 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package syncs
+
+import (
+	"math"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"testing"
+
+	"golang.org/x/sys/cpu"
+)
+
+func TestShardValue(t *testing.T) {
+	type intVal struct {
+		atomic.Int64
+		_ cpu.CacheLinePad
+	}
+
+	t.Run("One", func(t *testing.T) {
+		sv := NewShardValue[intVal]()
+		sv.One(func(v *intVal) {
+			v.Store(10)
+		})
+
+		var v int64
+		for i := range sv.shards {
+			v += sv.shards[i].Load()
+		}
+		if v != 10 {
+			t.Errorf("got %v, want 10", v)
+		}
+	})
+
+	t.Run("All", func(t *testing.T) {
+		sv := NewShardValue[intVal]()
+		for i := range sv.shards {
+			sv.shards[i].Store(int64(i))
+		}
+
+		var total int64
+		sv.All(func(v *intVal) bool {
+			total += v.Load()
+			return true
+		})
+		// Sum of 0..len-1 (a triangular number), since shard values are 0-indexed.
+		want := int64(len(sv.shards) * (len(sv.shards) - 1) / 2)
+		if total != want {
+			t.Errorf("got %v, want %v", total, want)
+		}
+	})
+
+	t.Run("Len", func(t *testing.T) {
+		sv := NewShardValue[intVal]()
+		if got, want := sv.Len(), runtime.NumCPU(); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+	})
+
+	t.Run("distribution", func(t *testing.T) {
+		sv := NewShardValue[intVal]()
+
+		goroutines := 1000
+		iterations := 10000
+		var wg sync.WaitGroup
+		wg.Add(goroutines)
+		for i := 0; i < goroutines; i++ {
+			go func() {
+				defer wg.Done()
+				for i := 0; i < iterations; i++ {
+					sv.One(func(v *intVal) {
+						v.Add(1)
+					})
+				}
+			}()
+		}
+		wg.Wait()
+
+		var (
+			total        int64
+			distribution []int64
+		)
+		t.Logf("distribution:")
+		sv.All(func(v *intVal) bool {
+			total += v.Load()
+			distribution = append(distribution, v.Load())
+			t.Logf("%d", v.Load())
+			return true
+		})
+
+		if got, want := total, int64(goroutines*iterations); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+		if got, want := len(distribution), runtime.NumCPU(); got != want {
+			t.Errorf("got %v, want %v", got, want)
+		}
+
+		mean := total / int64(len(distribution))
+		for _, v := range distribution {
+			if v < mean/10 || v > mean*10 {
+				t.Logf("distribution is very unbalanced: %v", distribution)
+			}
+		}
+		t.Logf("mean: %d", mean)
+
+		var standardDev int64
+		for _, v := range distribution {
+			standardDev += ((v - mean) * (v - mean))
+		}
+		standardDev = int64(math.Sqrt(float64(standardDev / int64(len(distribution)))))
+		t.Logf("stdev: %d", standardDev)
+
+		if standardDev > mean/3 {
t.Logf("standard deviation is too high: %v", standardDev) + } + }) +}