diff --git a/syncs/shardedmap.go b/syncs/shardedmap.go new file mode 100644 index 000000000..00ce3aafa --- /dev/null +++ b/syncs/shardedmap.go @@ -0,0 +1,111 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "sync" + + "golang.org/x/sys/cpu" +) + +// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined +// K-sharding function. +// +// The zero value is not safe for use; use NewShardedMap. +type ShardedMap[K comparable, V any] struct { + shardFunc func(K) int + shards []mapShard[K, V] +} + +type mapShard[K comparable, V any] struct { + mu sync.Mutex + m map[K]V + _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes +} + +// NewShardedMap returns a new ShardedMap with the given number of shards and +// sharding function. +// +// The shard func must return a integer in the range [0, shards) purely +// deterministically based on the provided K. +func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { + m := &ShardedMap[K, V]{ + shardFunc: shard, + shards: make([]mapShard[K, V], shards), + } + for i := range m.shards { + m.shards[i].m = make(map[K]V) + } + return m +} + +func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { + return &m.shards[m.shardFunc(key)] +} + +// GetOk returns m[key] and whether it was present. +func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + value, ok = shard.m[key] + return +} + +// Get returns m[key] or the zero value of V if key is not present. +func (m *ShardedMap[K, V]) Get(key K) (value V) { + value, _ = m.GetOk(key) + return +} + +// Set sets m[key] = value. +// +// It reports whether the map grew in size (that is, whether key was not already +// present in m). +func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + shard.m[key] = value + return len(shard.m) > s0 +} + +// Delete removes key from m. +// +// It reports whether the map size shrunk (that is, whether key was present in +// the map). +func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + delete(shard.m, key) + return len(shard.m) < s0 +} + +// Contains reports whether m contains key. +func (m *ShardedMap[K, V]) Contains(key K) bool { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + _, ok := shard.m[key] + return ok +} + +// Len returns the number of elements in m. +// +// It does so by locking shards one at a time, so it's not particularly cheap, +// nor does it give a consistent snapshot of the map. It's mostly intended for +// metrics or testing. +func (m *ShardedMap[K, V]) Len() int { + n := 0 + for i := range m.shards { + shard := &m.shards[i] + shard.mu.Lock() + n += len(shard.m) + shard.mu.Unlock() + } + return n +} diff --git a/syncs/shardedmap_test.go b/syncs/shardedmap_test.go new file mode 100644 index 000000000..b09a268d7 --- /dev/null +++ b/syncs/shardedmap_test.go @@ -0,0 +1,44 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import "testing" + +func TestShardedMap(t *testing.T) { + m := NewShardedMap[int, string](16, func(i int) int { return i % 16 }) + + if m.Contains(1) { + t.Errorf("got contains; want !contains") + } + if !m.Set(1, "one") { + t.Errorf("got !set; want set") + } + if m.Set(1, "one") { + t.Errorf("got set; want !set") + } + if !m.Contains(1) { + t.Errorf("got !contains; want contains") + } + if g, w := m.Get(1), "one"; g != w { + t.Errorf("got %q; want %q", g, w) + } + if _, ok := m.GetOk(1); !ok { + t.Errorf("got ok; want !ok") + } + if _, ok := m.GetOk(2); ok { + t.Errorf("got ok; want !ok") + } + if g, w := m.Len(), 1; g != w { + t.Errorf("got Len %v; want %v", g, w) + } + if m.Delete(2) { + t.Errorf("got deleted; want !deleted") + } + if !m.Delete(1) { + t.Errorf("got !deleted; want deleted") + } + if g, w := m.Len(), 0; g != w { + t.Errorf("got Len %v; want %v", g, w) + } +}