util/singleflight: release empty Group.m with 32+ capacity

Updates tailscale/corp#30942

Signed-off-by: Paul Scott <408401+icio@users.noreply.github.com>
This commit is contained in:
Paul Scott
2025-08-12 13:04:43 +01:00
parent cde65dba16
commit c794771d85
2 changed files with 74 additions and 4 deletions

View File

@@ -81,8 +81,9 @@ type call[V any] struct {
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group[K comparable, V any] struct {
mu sync.Mutex // protects m
m map[K]*call[V] // lazily initialized
mu sync.Mutex // protects m
m map[K]*call[V] // lazily initialized
maxLen int // maximum len(m) when doCall completes, used for cleanup
}
// Result holds the results of Do, so they can be passed
@@ -254,7 +255,7 @@ func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) {
defer g.mu.Unlock()
c.wg.Done()
if g.m[key] == c {
delete(g.m, key)
g.deleteLocked(key)
}
if e, ok := c.err.(*panicError); ok {
@@ -301,11 +302,29 @@ func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) {
}
}
func (g *Group[K, B]) deleteLocked(key K) {
delete(g.m, key)
n := len(g.m)
if n > g.maxLen {
g.maxLen = n + 1
return
}
if n > 0 || g.maxLen < 32 {
return
}
// Release g.m for its memory to be reclaimed.
g.maxLen = 0
g.m = nil
}
// Forget tells the singleflight to forget about a key. Future calls
// to Do for this key will call the function rather than waiting for
// an earlier call to complete.
func (g *Group[K, V]) Forget(key K) {
g.mu.Lock()
delete(g.m, key)
g.deleteLocked(key)
g.mu.Unlock()
}

View File

@@ -21,6 +21,8 @@ import (
"sync/atomic"
"testing"
"time"
"golang.org/x/sync/errgroup"
)
func TestDo(t *testing.T) {
@@ -474,3 +476,52 @@ func assertOKResult[V comparable](t testing.TB, res Result[V], want V) {
t.Fatalf("unexpected value; got %v, want %v", res.Val, want)
}
}
func TestRelease(t *testing.T) {
var sg Group[int, int]
var wg errgroup.Group
var startup sync.WaitGroup
release := make(chan struct{})
// Start 50 singleflight goroutines.
for key := range 50 {
startup.Add(1)
wg.Go(func() error {
keyRet, err, shared := sg.Do(key, func() (int, error) {
startup.Done()
<-release
return key, nil
})
if err != nil {
return fmt.Errorf("Do(%d) return error: %s", key, err)
}
if shared {
return fmt.Errorf("Do(%d) returned shared result, expected unshared", key)
}
if key != keyRet {
return fmt.Errorf("Do(%d) = %d, want %d", key, keyRet, key)
}
return nil
})
}
// Wait for all sg.Do goroutines to be executing their function.
// sg.m will point to all of them.
startup.Wait()
if got, want := len(sg.m), 50; got != want {
t.Fatalf("len(sg.m) = %d, want %d", got, want)
}
// Let the sg.Do goroutines return from their function.
close(release)
err := wg.Wait()
if err != nil {
t.Fatalf("error from worker: %s", err)
}
// Test for cleanup.
if sg.m != nil {
t.Fatal("sg.m != nil, want nil - cleanup didn't happen")
}
}