diff --git a/util/set/smallset.go b/util/set/smallset.go index 51cad6a25..1b77419d2 100644 --- a/util/set/smallset.go +++ b/util/set/smallset.go @@ -50,6 +50,15 @@ func (s SmallSet[T]) Contains(e T) bool { return e != zero && s.one == e } +// SoleElement returns the single value in the set, if the set has exactly one +// element. +// +// If the set is empty or has more than one element, ok will be false and e will +// be the zero value of T. +func (s SmallSet[T]) SoleElement() (e T, ok bool) { + return s.one, s.Len() == 1 +} + // Add adds e to the set. // // When storing a SmallSet in a map as a value type, it is important to @@ -61,10 +70,15 @@ func (s *SmallSet[T]) Add(e T) { s.m.Add(e) return } - // Size zero to one non-zero element. - if s.one == zero && e != zero { - s.one = e - return + // Non-zero elements can go into s.one. + if e != zero { + if s.one == zero { + s.one = e // Len 0 to Len 1 + return + } + if s.one == e { + return // dup + } } // Need to make a multi map, either // because we now have two items, or @@ -73,7 +87,7 @@ func (s *SmallSet[T]) Add(e T) { if s.one != zero { s.m.Add(s.one) // move single item to multi } - s.m.Add(e) // add new item + s.m.Add(e) // add new item, possibly zero s.one = zero } diff --git a/util/set/smallset_test.go b/util/set/smallset_test.go index 2635bc893..d6f446df0 100644 --- a/util/set/smallset_test.go +++ b/util/set/smallset_test.go @@ -84,8 +84,43 @@ func TestSmallSet(t *testing.T) { t.Errorf("contains(%v) mismatch after ops %s: normal=%v, small=%v", e, name(), normal.Contains(e), small.Contains(e)) } } + + if err := small.checkInvariants(); err != nil { + t.Errorf("checkInvariants failed after ops %s: %v", name(), err) + } + + if !t.Failed() { + sole, ok := small.SoleElement() + if ok != (small.Len() == 1) { + t.Errorf("SoleElement ok mismatch after ops %s: SoleElement ok=%v, want=%v", name(), ok, !ok) + } + if ok && sole != smallEle[0] { + t.Errorf("SoleElement value mismatch after ops %s: SoleElement=%v, want=%v", name(), sole, smallEle[0]) + t.Errorf("Internals: %+v", small) + } + } } } } } } + +func (s *SmallSet[T]) checkInvariants() error { + var zero T + if s.m != nil && s.one != zero { + return fmt.Errorf("both m and one are non-zero") + } + if s.m != nil { + switch len(s.m) { + case 0: + return fmt.Errorf("m is non-nil but empty") + case 1: + for k := range s.m { + if k != zero { + return fmt.Errorf("m contains exactly 1 non-zero element, %v", k) + } + } + } + } + return nil +}