net/netmon, wgengine/magicsock: simplify LinkChangeLogLimiter signature

Remove the need for the caller to hold on to and call an unregister
function. Both two callers (one real, one test) already have a context
they can use. Use context.AfterFunc instead. There are no observable
side effects from scheduling too late if the goroutine doesn't run sync.

Updates #17148

Change-Id: Ie697dae0e797494fa8ef27fbafa193bfe5ceb307
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2025-09-15 15:49:56 -07:00
committed by Brad Fitzpatrick
parent 5c24f0ed80
commit 8b48f3847d
3 changed files with 25 additions and 18 deletions

View File

@@ -4,6 +4,7 @@
package netmon package netmon
import ( import (
"context"
"sync" "sync"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@@ -12,16 +13,17 @@ import (
// LinkChangeLogLimiter returns a new [logger.Logf] that logs each unique // LinkChangeLogLimiter returns a new [logger.Logf] that logs each unique
// format string to the underlying logger only once per major LinkChange event. // format string to the underlying logger only once per major LinkChange event.
// //
// The returned function should be called when the logger is no longer needed, // The logger stops tracking seen format strings when the provided context is
// to release resources from the Monitor. // done.
func LinkChangeLogLimiter(logf logger.Logf, nm *Monitor) (_ logger.Logf, unregister func()) { func LinkChangeLogLimiter(ctx context.Context, logf logger.Logf, nm *Monitor) logger.Logf {
var formatSeen sync.Map // map[string]bool var formatSeen sync.Map // map[string]bool
unregister = nm.RegisterChangeCallback(func(cd *ChangeDelta) { unregister := nm.RegisterChangeCallback(func(cd *ChangeDelta) {
// If we're in a major change or a time jump, clear the seen map. // If we're in a major change or a time jump, clear the seen map.
if cd.Major || cd.TimeJumped { if cd.Major || cd.TimeJumped {
formatSeen.Clear() formatSeen.Clear()
} }
}) })
context.AfterFunc(ctx, unregister)
return func(format string, args ...any) { return func(format string, args ...any) {
// We only store 'true' in the map, so if it's present then it // We only store 'true' in the map, so if it's present then it
@@ -38,5 +40,5 @@ func LinkChangeLogLimiter(logf logger.Logf, nm *Monitor) (_ logger.Logf, unregis
} }
logf(format, args...) logf(format, args...)
}, unregister }
} }

View File

@@ -5,13 +5,17 @@ package netmon
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"testing" "testing"
"testing/synctest"
"tailscale.com/util/eventbus" "tailscale.com/util/eventbus"
) )
func TestLinkChangeLogLimiter(t *testing.T) { func TestLinkChangeLogLimiter(t *testing.T) { synctest.Test(t, syncTestLinkChangeLogLimiter) }
func syncTestLinkChangeLogLimiter(t *testing.T) {
bus := eventbus.New() bus := eventbus.New()
defer bus.Close() defer bus.Close()
mon, err := New(bus, t.Logf) mon, err := New(bus, t.Logf)
@@ -30,8 +34,10 @@ func TestLinkChangeLogLimiter(t *testing.T) {
fmt.Fprintf(&logBuffer, format, args...) fmt.Fprintf(&logBuffer, format, args...)
} }
logf, unregister := LinkChangeLogLimiter(logf, mon) ctx, cancel := context.WithCancel(t.Context())
defer unregister() defer cancel()
logf = LinkChangeLogLimiter(ctx, logf, mon)
// Log once, which should write to our log buffer. // Log once, which should write to our log buffer.
logf("hello %s", "world") logf("hello %s", "world")
@@ -72,8 +78,11 @@ func TestLinkChangeLogLimiter(t *testing.T) {
t.Errorf("unexpected log buffer contents: %q", got) t.Errorf("unexpected log buffer contents: %q", got)
} }
// Unregistering the callback should clear our 'cbs' set. // Canceling the context we passed to LinkChangeLogLimiter should
unregister() // unregister the callback from the netmon.
cancel()
synctest.Wait()
mon.mu.Lock() mon.mu.Lock()
if len(mon.cbs) != 0 { if len(mon.cbs) != 0 {
t.Errorf("expected no callbacks, got %v", mon.cbs) t.Errorf("expected no callbacks, got %v", mon.cbs)

View File

@@ -209,10 +209,6 @@ type Conn struct {
// port mappings from NAT devices. // port mappings from NAT devices.
portMapper *portmapper.Client portMapper *portmapper.Client
// portMapperLogfUnregister is the function to call to unregister
// the portmapper log limiter.
portMapperLogfUnregister func()
// derpRecvCh is used by receiveDERP to read DERP messages. // derpRecvCh is used by receiveDERP to read DERP messages.
// It must have buffer size > 0; see issue 3736. // It must have buffer size > 0; see issue 3736.
derpRecvCh chan derpReadResult derpRecvCh chan derpReadResult
@@ -748,10 +744,13 @@ func NewConn(opts Options) (*Conn, error) {
c.subsDoneCh = make(chan struct{}) c.subsDoneCh = make(chan struct{})
go c.consumeEventbusTopics() go c.consumeEventbusTopics()
c.connCtx, c.connCtxCancel = context.WithCancel(context.Background())
c.donec = c.connCtx.Done()
// Don't log the same log messages possibly every few seconds in our // Don't log the same log messages possibly every few seconds in our
// portmapper. // portmapper.
portmapperLogf := logger.WithPrefix(c.logf, "portmapper: ") portmapperLogf := logger.WithPrefix(c.logf, "portmapper: ")
portmapperLogf, c.portMapperLogfUnregister = netmon.LinkChangeLogLimiter(portmapperLogf, opts.NetMon) portmapperLogf = netmon.LinkChangeLogLimiter(c.connCtx, portmapperLogf, opts.NetMon)
portMapOpts := &portmapper.DebugKnobs{ portMapOpts := &portmapper.DebugKnobs{
DisableAll: func() bool { return opts.DisablePortMapper || c.onlyTCP443.Load() }, DisableAll: func() bool { return opts.DisablePortMapper || c.onlyTCP443.Load() },
} }
@@ -772,8 +771,6 @@ func NewConn(opts Options) (*Conn, error) {
return nil, err return nil, err
} }
c.connCtx, c.connCtxCancel = context.WithCancel(context.Background())
c.donec = c.connCtx.Done()
c.netChecker = &netcheck.Client{ c.netChecker = &netcheck.Client{
Logf: logger.WithPrefix(c.logf, "netcheck: "), Logf: logger.WithPrefix(c.logf, "netcheck: "),
NetMon: c.netMon, NetMon: c.netMon,
@@ -3330,7 +3327,6 @@ func (c *Conn) Close() error {
} }
c.stopPeriodicReSTUNTimerLocked() c.stopPeriodicReSTUNTimerLocked()
c.portMapper.Close() c.portMapper.Close()
c.portMapperLogfUnregister()
c.peerMap.forEachEndpoint(func(ep *endpoint) { c.peerMap.forEachEndpoint(func(ep *endpoint) {
ep.stopAndReset() ep.stopAndReset()