util/topk: prevent duplicate elements

Previously, we would insert duplicate elements into the TopK
datastructure while it wasn't fully populated, which resulted in
incorrect data due to a single "top" entry being split among multiple
elements.

Updates tailscale/corp#25479

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: Ibf118bd41534dc53e0ce9be68cfcbda267b4de77
This commit is contained in:
Andrew Dunham 2025-01-10 17:43:59 -05:00
parent 2af255790d
commit 943f930bd8
2 changed files with 82 additions and 12 deletions

View File

@ -16,11 +16,12 @@ import (
// TopK is a probabilistic counter of the top K items, using a count-min sketch
// to keep track of item counts and a heap to track the top K of them.
type TopK[T any] struct {
heap minHeap[T]
k int
sf SerializeFunc[T]
cms CountMinSketch
type TopK[T comparable] struct {
heap minHeap[T]
positions map[T]int
k int
sf SerializeFunc[T]
cms CountMinSketch
}
// HashFunc is responsible for providing a []byte serialization of a value,
@ -31,18 +32,19 @@ type SerializeFunc[T any] func([]byte, T) []byte
// New creates a new TopK that stores k values. Parameters for the underlying
// count-min sketch are chosen for a 0.1% error rate and a 0.1% probability of
// error.
func New[T any](k int, sf SerializeFunc[T]) *TopK[T] {
func New[T comparable](k int, sf SerializeFunc[T]) *TopK[T] {
hashes, buckets := PickParams(0.001, 0.001)
return NewWithParams(k, sf, hashes, buckets)
}
// NewWithParams creates a new TopK that stores k values, and additionally
// allows customizing the parameters for the underlying count-min sketch.
func NewWithParams[T any](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] {
func NewWithParams[T comparable](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] {
ret := &TopK[T]{
heap: make(minHeap[T], 0, k),
k: k,
sf: sf,
heap: make(minHeap[T], 0, k),
positions: make(map[T]int, k),
k: k,
sf: sf,
}
ret.cms.init(numHashes, numCols)
return ret
@ -69,21 +71,38 @@ func (tk *TopK[T]) AddN(val T, count uint64) uint64 {
vcount := tk.cms.AddN(ser, count)
// If we don't have a full heap, just push it.
// Check if this item is already in the heap; if so, we can just update
// the count and fix the heap.
if pos, exists := tk.positions[val]; exists {
tk.heap[pos].count = vcount
heap.Fix(&tk.heap, pos)
return vcount
}
// If we don't have a full heap, we add this item to the heap and
// return without checking the heap minimum.
if len(tk.heap) < tk.k {
pos := len(tk.heap)
heap.Push(&tk.heap, mhValue[T]{
count: vcount,
val: val,
})
tk.positions[val] = pos
return vcount
}
// If this item's count surpasses the heap's minimum, update the heap.
// If this item's count surpasses the heap's minimum, replace the
// minimum value with this item.
if vcount > tk.heap[0].count {
// Remove old item from positions map
delete(tk.positions, tk.heap[0].val)
// Update heap
tk.heap[0] = mhValue[T]{
count: vcount,
val: val,
}
tk.positions[val] = 0
heap.Fix(&tk.heap, 0)
}
return vcount

View File

@ -68,6 +68,57 @@ func TestTopK(t *testing.T) {
t.Errorf("top K mismatch\ngot: %v\nwant: %v", got, want)
}
func TestTopKNoDuplicates(t *testing.T) {
// Create a TopK that tracks top 5 elements
topk := New[string](5, func(in []byte, val string) []byte {
return append(in, []byte(val)...)
})
// Add a single element many times
const commonElement = "very-common"
for i := 0; i < 500; i++ {
topk.Add(commonElement)
}
// We should only have a single "top" element here, despite having
// added the same element 500 times.
if n := len(topk.Top()); n != 1 {
t.Errorf("expected only one element, got %d", n)
}
// Add some less frequent elements
for i := 0; i < 5; i++ {
topk.Add(fmt.Sprintf("less-common-%d", i))
}
// Add common element again
for i := 0; i < 500; i++ {
topk.Add(commonElement)
}
// Get the top elements
results := topk.Top()
// Count occurrences of the common element
commonCount := 0
for _, res := range results {
if res == commonElement {
commonCount++
}
}
if commonCount > 1 {
t.Errorf("common element appeared %d times in results, want 1", commonCount)
} else if commonCount == 0 {
t.Error("common element did not appear in results")
}
// We expect that the common element is last (i.e. "top") in the returned list.
if idx := len(results) - 1; results[idx] != commonElement {
t.Errorf("common element not last in results: %q", results[idx])
}
}
func TestPickParams(t *testing.T) {
hashes, buckets := PickParams(
0.001, // 0.1% error rate