mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-10 10:03:43 +00:00
120 lines
2.4 KiB
Go
120 lines
2.4 KiB
Go
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
package syncs
|
||
|
|
||
|
import (
|
||
|
"math"
|
||
|
"runtime"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"testing"
|
||
|
|
||
|
"golang.org/x/sys/cpu"
|
||
|
)
|
||
|
|
||
|
func TestShardValue(t *testing.T) {
|
||
|
type intVal struct {
|
||
|
atomic.Int64
|
||
|
_ cpu.CacheLinePad
|
||
|
}
|
||
|
|
||
|
t.Run("One", func(t *testing.T) {
|
||
|
sv := NewShardValue[intVal]()
|
||
|
sv.One(func(v *intVal) {
|
||
|
v.Store(10)
|
||
|
})
|
||
|
|
||
|
var v int64
|
||
|
for i := range sv.shards {
|
||
|
v += sv.shards[i].Load()
|
||
|
}
|
||
|
if v != 10 {
|
||
|
t.Errorf("got %v, want 10", v)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("All", func(t *testing.T) {
|
||
|
sv := NewShardValue[intVal]()
|
||
|
for i := range sv.shards {
|
||
|
sv.shards[i].Store(int64(i))
|
||
|
}
|
||
|
|
||
|
var total int64
|
||
|
sv.All(func(v *intVal) bool {
|
||
|
total += v.Load()
|
||
|
return true
|
||
|
})
|
||
|
// triangle coefficient lower one order due to 0 index
|
||
|
want := int64(len(sv.shards) * (len(sv.shards) - 1) / 2)
|
||
|
if total != want {
|
||
|
t.Errorf("got %v, want %v", total, want)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("Len", func(t *testing.T) {
|
||
|
sv := NewShardValue[intVal]()
|
||
|
if got, want := sv.Len(), runtime.NumCPU(); got != want {
|
||
|
t.Errorf("got %v, want %v", got, want)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("distribution", func(t *testing.T) {
|
||
|
sv := NewShardValue[intVal]()
|
||
|
|
||
|
goroutines := 1000
|
||
|
iterations := 10000
|
||
|
var wg sync.WaitGroup
|
||
|
wg.Add(goroutines)
|
||
|
for i := 0; i < goroutines; i++ {
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
for i := 0; i < iterations; i++ {
|
||
|
sv.One(func(v *intVal) {
|
||
|
v.Add(1)
|
||
|
})
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
wg.Wait()
|
||
|
|
||
|
var (
|
||
|
total int64
|
||
|
distribution []int64
|
||
|
)
|
||
|
t.Logf("distribution:")
|
||
|
sv.All(func(v *intVal) bool {
|
||
|
total += v.Load()
|
||
|
distribution = append(distribution, v.Load())
|
||
|
t.Logf("%d", v.Load())
|
||
|
return true
|
||
|
})
|
||
|
|
||
|
if got, want := total, int64(goroutines*iterations); got != want {
|
||
|
t.Errorf("got %v, want %v", got, want)
|
||
|
}
|
||
|
if got, want := len(distribution), runtime.NumCPU(); got != want {
|
||
|
t.Errorf("got %v, want %v", got, want)
|
||
|
}
|
||
|
|
||
|
mean := total / int64(len(distribution))
|
||
|
for _, v := range distribution {
|
||
|
if v < mean/10 || v > mean*10 {
|
||
|
t.Logf("distribution is very unbalanced: %v", distribution)
|
||
|
}
|
||
|
}
|
||
|
t.Logf("mean: %d", mean)
|
||
|
|
||
|
var standardDev int64
|
||
|
for _, v := range distribution {
|
||
|
standardDev += ((v - mean) * (v - mean))
|
||
|
}
|
||
|
standardDev = int64(math.Sqrt(float64(standardDev / int64(len(distribution)))))
|
||
|
t.Logf("stdev: %d", standardDev)
|
||
|
|
||
|
if standardDev > mean/3 {
|
||
|
t.Logf("standard deviation is too high: %v", standardDev)
|
||
|
}
|
||
|
})
|
||
|
}
|