From 0de5e7b94f0bb89bcaed108f656d3ed50da85d02 Mon Sep 17 00:00:00 2001
From: Joe Tsai
Date: Tue, 22 Jul 2025 09:22:17 -1000
Subject: [PATCH] util/set: add IntSet (#16602)

IntSet is a set optimized for integers.

Updates tailscale/corp#29809

Signed-off-by: Joe Tsai
---
 util/set/intset.go      | 172 +++++++++++++++++++++++++++++++++++++++
 util/set/intset_test.go | 174 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 346 insertions(+)
 create mode 100644 util/set/intset.go
 create mode 100644 util/set/intset_test.go

diff --git a/util/set/intset.go b/util/set/intset.go
new file mode 100644
index 000000000..b747d3bff
--- /dev/null
+++ b/util/set/intset.go
@@ -0,0 +1,172 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package set
+
+import (
+	"iter"
+	"maps"
+	"math/bits"
+	"math/rand/v2"
+
+	"golang.org/x/exp/constraints"
+	"tailscale.com/util/mak"
+)
+
+// IntSet is a set optimized for integer values close to zero
+// or sets of integers that are close in value.
+// Elements are stored zigzag-encoded (see encodeZigZag), one bit each.
+type IntSet[T constraints.Integer] struct {
+	// bits is a [bitSet] for encoded numbers less than [bits.UintSize].
+	bits bitSet
+
+	// extra is a mapping of [bitSet] for encoded numbers not in bits, keyed by
+	// the number divided by [bits.UintSize]; the remainder is the bit index.
+	extra map[uint64]bitSet
+
+	// extraLen is the count of numbers in extra since len(extra)
+	// does not reflect that each bitSet may have multiple numbers.
+	extraLen int
+}
+
+// Values returns an iterator over the elements of the set.
+// The iterator will yield the elements in no particular order.
+func (s IntSet[T]) Values() iter.Seq[T] {
+	return func(yield func(T) bool) {
+		if s.bits != 0 { // nothing to scan when no small value is present
+			for i := range s.bits.values() {
+				if !yield(decodeZigZag[T](i)) {
+					return
+				}
+			}
+		}
+		// Ranging over a nil map is a no-op, so no nil check is needed.
+		for hi, bs := range s.extra {
+			for lo := range bs.values() {
+				if !yield(decodeZigZag[T](hi*bits.UintSize + lo)) {
+					return
+				}
+			}
+		}
+	}
+}
+
+// Contains reports whether e is in the set.
+func (s IntSet[T]) Contains(e T) bool {
+	v := encodeZigZag(e)
+	if v < bits.UintSize {
+		return s.bits.contains(v)
+	}
+	// A missing (or nil-map) entry yields the zero bitSet, which is empty.
+	return s.extra[v/bits.UintSize].contains(v % bits.UintSize)
+}
+
+// Add adds e to the set.
+//
+// When storing an IntSet in a map as a value type,
+// it is important to re-assign the map entry after calling Add or Delete,
+// as the IntSet's representation may change.
+func (s *IntSet[T]) Add(e T) {
+	if v := encodeZigZag(e); v < bits.UintSize {
+		s.bits.add(v)
+	} else {
+		hi, lo := v/uint64(bits.UintSize), v%uint64(bits.UintSize)
+		if bs := s.extra[hi]; !bs.contains(lo) {
+			bs.add(lo)
+			mak.Set(&s.extra, hi, bs) // makes the map when it is still nil
+			s.extraLen++
+		}
+	}
+}
+
+// AddSeq adds the values from seq to the set.
+func (s *IntSet[T]) AddSeq(seq iter.Seq[T]) {
+	for e := range seq {
+		s.Add(e)
+	}
+}
+
+// Len reports the number of elements in the set.
+func (s IntSet[T]) Len() int {
+	return s.bits.len() + s.extraLen
+}
+
+// Delete removes e from the set.
+//
+// When storing an IntSet in a map as a value type,
+// it is important to re-assign the map entry after calling Add or Delete,
+// as the IntSet's representation may change.
+func (s *IntSet[T]) Delete(e T) {
+	if v := encodeZigZag(e); v < bits.UintSize {
+		s.bits.delete(v)
+	} else {
+		hi, lo := v/uint64(bits.UintSize), v%uint64(bits.UintSize)
+		if bs := s.extra[hi]; bs.contains(lo) {
+			if bs.delete(lo); bs == 0 {
+				delete(s.extra, hi) // drop empty entries instead of keeping them forever
+			} else {
+				s.extra[hi] = bs // s.extra is non-nil since it held lo
+			}
+			s.extraLen--
+		}
+	}
+}
+
+// Clone returns a copy of s that doesn't alias the original.
+func (s IntSet[T]) Clone() IntSet[T] {
+	return IntSet[T]{bits: s.bits, extra: maps.Clone(s.extra), extraLen: s.extraLen}
+}
+
+// bitSet is a set of the integers in [0, bits.UintSize),
+// where bit i being set means that i is a member.
+type bitSet uint
+
+// values returns an iterator over the members of s.
+func (s bitSet) values() iter.Seq[uint64] {
+	return func(yield func(uint64) bool) {
+		// Hyrum-proofing: randomly iterate in forwards or reverse.
+		if rand.Uint64()%2 == 0 {
+			for i := 0; i < bits.UintSize; i++ {
+				if s.contains(uint64(i)) && !yield(uint64(i)) {
+					return
+				}
+			}
+		} else {
+			// Start at the highest valid bit index, not one past it.
+			for i := bits.UintSize - 1; i >= 0; i-- {
+				if s.contains(uint64(i)) && !yield(uint64(i)) {
+					return
+				}
+			}
+		}
+	}
+}
+func (s bitSet) len() int               { return bits.OnesCount(uint(s)) }
+func (s bitSet) contains(i uint64) bool { return s&(1<<i) != 0 }
+func (s *bitSet) add(i uint64)          { *s |= 1 << i }
+func (s *bitSet) delete(i uint64)       { *s &= ^(1 << i) }
+
+// encodeZigZag encodes an integer as an unsigned integer ensuring that
+// negative integers near zero still have a near zero positive value.
+// For unsigned integers, it returns the value verbatim.
+func encodeZigZag[T constraints.Integer](v T) uint64 {
+	var zero T
+	if ^zero >= 0 { // must be constraints.Unsigned
+		return uint64(v)
+	}
+	// Signed: see [google.golang.org/protobuf/encoding/protowire.EncodeZigZag]
+	return uint64(int64(v)<<1) ^ uint64(int64(v)>>63)
+}
+
+// decodeZigZag is the inverse of encodeZigZag: it maps the unsigned
+// encoding back to the original (possibly negative) integer.
+// For unsigned integers, it returns the value verbatim.
+func decodeZigZag[T constraints.Integer](v uint64) T {
+	var zero T
+	if ^zero >= 0 { // must be constraints.Unsigned
+		return T(v)
+	}
+	// Must be constraints.Signed.
+	// See [google.golang.org/protobuf/encoding/protowire.DecodeZigZag]
+	return T(int64(v>>1) ^ int64(v)<<63>>63)
+}
diff --git a/util/set/intset_test.go b/util/set/intset_test.go
new file mode 100644
index 000000000..9523fe88d
--- /dev/null
+++ b/util/set/intset_test.go
@@ -0,0 +1,174 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package set
+
+import (
+	"maps"
+	"math"
+	"slices"
+	"testing"
+
+	"golang.org/x/exp/constraints"
+)
+
+// TestIntSet replays the same operations on an IntSet and a reference Set.
+func TestIntSet(t *testing.T) {
+	t.Run("Int64", func(t *testing.T) {
+		ss := make(Set[int64])
+		var si IntSet[int64]
+		intValues(t, ss, si)
+		deleteInt(t, ss, &si, -5)
+		deleteInt(t, ss, &si, 2)
+		deleteInt(t, ss, &si, 75)
+		intValues(t, ss, si)
+		addInt(t, ss, &si, 2)
+		addInt(t, ss, &si, 75)
+		addInt(t, ss, &si, 75)
+		addInt(t, ss, &si, -3)
+		addInt(t, ss, &si, -3)
+		addInt(t, ss, &si, -3)
+		addInt(t, ss, &si, math.MinInt64)
+		addInt(t, ss, &si, 8)
+		intValues(t, ss, si)
+		addInt(t, ss, &si, 77)
+		addInt(t, ss, &si, 76)
+		addInt(t, ss, &si, 76)
+		addInt(t, ss, &si, 76)
+		intValues(t, ss, si)
+		addInt(t, ss, &si, -5)
+		addInt(t, ss, &si, 7)
+		addInt(t, ss, &si, -83)
+		addInt(t, ss, &si, math.MaxInt64)
+		intValues(t, ss, si)
+		deleteInt(t, ss, &si, -5)
+		deleteInt(t, ss, &si, 2)
+		deleteInt(t, ss, &si, 75)
+		intValues(t, ss, si)
+		deleteInt(t, ss, &si, math.MinInt64)
+		deleteInt(t, ss, &si, math.MaxInt64)
+		intValues(t, ss, si)
+	})
+
+	t.Run("Uint64", func(t *testing.T) {
+		ss := make(Set[uint64])
+		var si IntSet[uint64]
+		intValues(t, ss, si)
+		deleteInt(t, ss, &si, 5)
+		deleteInt(t, ss, &si, 2)
+		deleteInt(t, ss, &si, 75)
+		intValues(t, ss, si)
+		addInt(t, ss, &si, 2)
+		addInt(t, ss, &si, 75)
+		addInt(t, ss, &si, 75)
+		addInt(t, ss, &si, 3)
+		addInt(t, ss, &si, 3)
+		addInt(t, ss, &si, 8)
+		intValues(t, ss, si)
+		addInt(t, ss, &si, 77)
+		addInt(t, ss, &si, 76)
+		addInt(t, ss, &si, 76)
+		addInt(t, ss, &si, 76)
+		intValues(t, ss, si)
+		addInt(t, ss, &si, 5)
+		addInt(t, ss, &si, 7)
+		addInt(t, ss, &si, 83)
+		addInt(t, ss, &si, math.MaxInt64)
+		intValues(t, ss, si)
+		deleteInt(t, ss, &si, 5)
+		deleteInt(t, ss, &si, 2)
+		deleteInt(t, ss, &si, 75)
+		intValues(t, ss, si)
+		deleteInt(t, ss, &si, math.MaxInt64)
+		intValues(t, ss, si)
+	})
+}
+
+// intValues checks that si holds exactly the elements and length of ss.
+func intValues[T constraints.Integer](t testing.TB, ss Set[T], si IntSet[T]) {
+	t.Helper()
+	want := slices.Sorted(maps.Keys(ss)) // ss is the reference implementation
+	got := slices.Sorted(si.Values())
+	if !slices.Equal(got, want) {
+		t.Fatalf("Values mismatch:\n\tgot %v\n\twant %v", got, want)
+	}
+	if got, want := si.Len(), ss.Len(); got != want {
+		t.Fatalf("Len() = %v, want %v", got, want)
+	}
+}
+
+// addInt adds v to both sets, checking Contains and Len against ss.
+func addInt[T constraints.Integer](t testing.TB, ss Set[T], si *IntSet[T], v T) {
+	t.Helper()
+	if got, want := si.Contains(v), ss.Contains(v); got != want {
+		t.Fatalf("Contains(%v) = %v, want %v", v, got, want)
+	}
+	ss.Add(v)
+	si.Add(v)
+	if !si.Contains(v) {
+		t.Fatalf("Contains(%v) = false, want true", v)
+	}
+	if got, want := si.Len(), ss.Len(); got != want {
+		t.Fatalf("Len() = %v, want %v", got, want)
+	}
+}
+
+// deleteInt deletes v from both sets, checking Contains and Len against ss.
+func deleteInt[T constraints.Integer](t testing.TB, ss Set[T], si *IntSet[T], v T) {
+	t.Helper()
+	if got, want := si.Contains(v), ss.Contains(v); got != want {
+		t.Fatalf("Contains(%v) = %v, want %v", v, got, want)
+	}
+	ss.Delete(v)
+	si.Delete(v)
+	if si.Contains(v) {
+		t.Fatalf("Contains(%v) = true, want false", v)
+	}
+	if got, want := si.Len(), ss.Len(); got != want {
+		t.Fatalf("Len() = %v, want %v", got, want)
+	}
+}
+
+// TestZigZag checks encodeZigZag and decodeZigZag against known pairs.
+func TestZigZag(t *testing.T) {
+	t.Run("Int64", func(t *testing.T) {
+		for _, tt := range []struct {
+			decoded int64
+			encoded uint64
+		}{
+			{math.MinInt64, math.MaxUint64},
+			{-2, 3},
+			{-1, 1},
+			{0, 0},
+			{1, 2},
+			{2, 4},
+			{math.MaxInt64, math.MaxUint64 - 1},
+		} {
+			if encoded := encodeZigZag(tt.decoded); encoded != tt.encoded {
+				t.Errorf("encodeZigZag(%v) = %v, want %v", tt.decoded, encoded, tt.encoded)
+			}
+			if decoded := decodeZigZag[int64](tt.encoded); decoded != tt.decoded {
+				t.Errorf("decodeZigZag(%v) = %v, want %v", tt.encoded, decoded, tt.decoded)
+			}
+		}
+	})
+	t.Run("Uint64", func(t *testing.T) {
+		for _, tt := range []struct {
+			decoded uint64
+			encoded uint64
+		}{
+			{0, 0},
+			{1, 1},
+			{2, 2},
+			{math.MaxInt64, math.MaxInt64},
+			{math.MaxUint64, math.MaxUint64},
+		} {
+			if encoded := encodeZigZag(tt.decoded); encoded != tt.encoded {
+				t.Errorf("encodeZigZag(%v) = %v, want %v", tt.decoded, encoded, tt.encoded)
+			}
+			if decoded := decodeZigZag[uint64](tt.encoded); decoded != tt.decoded {
+				t.Errorf("decodeZigZag(%v) = %v, want %v", tt.encoded, decoded, tt.decoded)
+			}
+		}
+	})
+}