From dca4036a207b5f7edeb1d54cce30c7dfe1914499 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 27 May 2025 13:31:39 -0700 Subject: [PATCH] util/set: add SmallSet Updates tailscale/corp#29093 Change-Id: I0e07e83dee51b4915597a913b0583c99756d90e2 Signed-off-by: Brad Fitzpatrick --- util/set/smallset.go | 134 ++++++++++++++++++++++++++++++++++++++ util/set/smallset_test.go | 91 ++++++++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 util/set/smallset.go create mode 100644 util/set/smallset_test.go diff --git a/util/set/smallset.go b/util/set/smallset.go new file mode 100644 index 000000000..51cad6a25 --- /dev/null +++ b/util/set/smallset.go @@ -0,0 +1,134 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "iter" + "maps" + + "tailscale.com/types/structs" +) + +// SmallSet is a set that is optimized for reducing memory overhead when the +// expected size of the set is 0 or 1 elements. +// +// The zero value of SmallSet is a usable empty set. +// +// When storing a SmallSet in a map as a value type, it is important to re-assign +// the map entry after calling Add or Delete, as the SmallSet's representation +// may change. +// +// Copying a SmallSet by value may alias the previous value. Use the Clone method +// to create a new SmallSet with the same contents. +type SmallSet[T comparable] struct { + _ structs.Incomparable // to prevent == mistakes + one T // if non-zero, then single item in set + m Set[T] // if non-nil, the set of items, which might be size 1 if it's the zero value of T +} + +// Values returns an iterator over the elements of the set. +// The iterator will yield the elements in no particular order. +func (s SmallSet[T]) Values() iter.Seq[T] { + if s.m != nil { + return maps.Keys(s.m) + } + var zero T + return func(yield func(T) bool) { + if s.one != zero { + yield(s.one) + } + } +} + +// Contains reports whether e is in the set. +func (s SmallSet[T]) Contains(e T) bool { + if s.m != nil { + return s.m.Contains(e) + } + var zero T + return e != zero && s.one == e +} + +// Add adds e to the set. +// +// When storing a SmallSet in a map as a value type, it is important to +// re-assign the map entry after calling Add or Delete, as the SmallSet's +// representation may change. +func (s *SmallSet[T]) Add(e T) { + var zero T + if s.m != nil { + s.m.Add(e) + return + } + // Size zero to one non-zero element. + if s.one == zero && e != zero { + s.one = e + return + } + // Need to make a multi map, either + // because we now have two items, or + // because e is the zero value. + s.m = Set[T]{} + if s.one != zero { + s.m.Add(s.one) // move single item to multi + } + s.m.Add(e) // add new item + s.one = zero +} + +// Len reports the number of elements in the set. +func (s SmallSet[T]) Len() int { + var zero T + if s.m != nil { + return s.m.Len() + } + if s.one != zero { + return 1 + } + return 0 +} + +// Delete removes e from the set. +// +// When storing a SmallSet in a map as a value type, it is important to +// re-assign the map entry after calling Add or Delete, as the SmallSet's +// representation may change. +func (s *SmallSet[T]) Delete(e T) { + var zero T + if s.m == nil { + if s.one == e { + s.one = zero + } + return + } + s.m.Delete(e) + + // If the map size drops to zero, that means + // it only contained the zero value of T. + if s.m.Len() == 0 { + s.m = nil + return + } + + // If the map size drops to one element and doesn't + // contain the zero value, we can switch back to the + // single-item representation. + if s.m.Len() == 1 { + for v := range s.m { + if v != zero { + s.one = v + s.m = nil + } + } + } + return +} + +// Clone returns a copy of s that doesn't alias the original. +func (s SmallSet[T]) Clone() SmallSet[T] { + return SmallSet[T]{ + one: s.one, + m: maps.Clone(s.m), // preserves nilness + } +} diff --git a/util/set/smallset_test.go b/util/set/smallset_test.go new file mode 100644 index 000000000..2635bc893 --- /dev/null +++ b/util/set/smallset_test.go @@ -0,0 +1,91 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "fmt" + "iter" + "maps" + "reflect" + "slices" + "testing" +) + +func TestSmallSet(t *testing.T) { + t.Parallel() + + wantSize := reflect.TypeFor[int64]().Size() + reflect.TypeFor[map[int]struct{}]().Size() + if wantSize > 16 { + t.Errorf("wantSize should be no more than 16") // it might be smaller on 32-bit systems + } + if size := reflect.TypeFor[SmallSet[int64]]().Size(); size != wantSize { + t.Errorf("SmallSet[int64] size is %d, want %v", size, wantSize) + } + + type op struct { + add bool + v int + } + ops := iter.Seq[op](func(yield func(op) bool) { + for _, add := range []bool{false, true} { + for v := range 4 { + if !yield(op{add: add, v: v}) { + return + } + } + } + }) + type setLike interface { + Add(int) + Delete(int) + } + apply := func(s setLike, o op) { + if o.add { + s.Add(o.v) + } else { + s.Delete(o.v) + } + } + + // For all combinations of 4 operations, + // apply them to both a regular map and SmallSet + // and make sure all the invariants hold. + + for op1 := range ops { + for op2 := range ops { + for op3 := range ops { + for op4 := range ops { + + normal := Set[int]{} + small := &SmallSet[int]{} + for _, op := range []op{op1, op2, op3, op4} { + apply(normal, op) + apply(small, op) + } + + name := func() string { + return fmt.Sprintf("op1=%v, op2=%v, op3=%v, op4=%v", op1, op2, op3, op4) + } + if normal.Len() != small.Len() { + t.Errorf("len mismatch after ops %s: normal=%d, small=%d", name(), normal.Len(), small.Len()) + } + if got := small.Clone().Len(); normal.Len() != got { + t.Errorf("len mismatch after ops %s: normal=%d, clone=%d", name(), normal.Len(), got) + } + + normalEle := slices.Sorted(maps.Keys(normal)) + smallEle := slices.Sorted(small.Values()) + if !slices.Equal(normalEle, smallEle) { + t.Errorf("elements mismatch after ops %s: normal=%v, small=%v", name(), normalEle, smallEle) + } + for e := range 5 { + if normal.Contains(e) != small.Contains(e) { + t.Errorf("contains(%v) mismatch after ops %s: normal=%v, small=%v", e, name(), normal.Contains(e), small.Contains(e)) + } + } + } + } + } + } +}