net/netmon: remove usage of direct callbacks from netmon (#17292)

The callback itself is not removed as it is used in other repos, making
it simpler for those to slowly transition to the eventbus.

Updates #15160

Signed-off-by: Claus Lensbøl <claus@tailscale.com>
This commit is contained in:
Claus Lensbøl
2025-10-01 14:59:38 -04:00
committed by GitHub
parent 6f7ce5eb5d
commit ce752b8a88
28 changed files with 217 additions and 48 deletions

View File

@@ -30,6 +30,7 @@ import (
"tailscale.com/types/logger"
"tailscale.com/util/clientmetric"
"tailscale.com/util/dnsname"
"tailscale.com/util/eventbus"
"tailscale.com/util/slicesx"
"tailscale.com/util/syspolicy/policyclient"
)
@@ -600,7 +601,7 @@ func (m *Manager) FlushCaches() error {
// No other state needs to be instantiated before this runs.
//
// health must not be nil
func CleanUp(logf logger.Logf, netMon *netmon.Monitor, health *health.Tracker, interfaceName string) {
func CleanUp(logf logger.Logf, netMon *netmon.Monitor, bus *eventbus.Bus, health *health.Tracker, interfaceName string) {
if !buildfeatures.HasDNS {
return
}
@@ -611,6 +612,7 @@ func CleanUp(logf logger.Logf, netMon *netmon.Monitor, health *health.Tracker, i
}
d := &tsdial.Dialer{Logf: logf}
d.SetNetMon(netMon)
d.SetBus(bus)
dns := NewManager(logf, oscfg, health, d, nil, nil, runtime.GOOS)
if err := dns.Down(); err != nil {
logf("dns down: %v", err)

View File

@@ -90,7 +90,10 @@ func TestDNSOverTCP(t *testing.T) {
SearchDomains: fqdns("coffee.shop"),
},
}
m := NewManager(t.Logf, &f, health.NewTracker(eventbustest.NewBus(t)), tsdial.NewDialer(netmon.NewStatic()), nil, nil, "")
bus := eventbustest.NewBus(t)
dialer := tsdial.NewDialer(netmon.NewStatic())
dialer.SetBus(bus)
m := NewManager(t.Logf, &f, health.NewTracker(bus), dialer, nil, nil, "")
m.resolver.TestOnlySetHook(f.SetResolver)
m.Set(Config{
Hosts: hosts(
@@ -175,7 +178,10 @@ func TestDNSOverTCP_TooLarge(t *testing.T) {
SearchDomains: fqdns("coffee.shop"),
},
}
m := NewManager(log, &f, health.NewTracker(eventbustest.NewBus(t)), tsdial.NewDialer(netmon.NewStatic()), nil, nil, "")
bus := eventbustest.NewBus(t)
dialer := tsdial.NewDialer(netmon.NewStatic())
dialer.SetBus(bus)
m := NewManager(log, &f, health.NewTracker(bus), dialer, nil, nil, "")
m.resolver.TestOnlySetHook(f.SetResolver)
m.Set(Config{
Hosts: hosts("andrew.ts.com.", "1.2.3.4"),

View File

@@ -933,7 +933,10 @@ func TestManager(t *testing.T) {
goos = "linux"
}
knobs := &controlknobs.Knobs{}
m := NewManager(t.Logf, &f, health.NewTracker(eventbustest.NewBus(t)), tsdial.NewDialer(netmon.NewStatic()), nil, knobs, goos)
bus := eventbustest.NewBus(t)
dialer := tsdial.NewDialer(netmon.NewStatic())
dialer.SetBus(bus)
m := NewManager(t.Logf, &f, health.NewTracker(bus), dialer, nil, knobs, goos)
m.resolver.TestOnlySetHook(f.SetResolver)
if err := m.Set(test.in); err != nil {
@@ -1039,7 +1042,10 @@ func TestConfigRecompilation(t *testing.T) {
SearchDomains: fqdns("foo.ts.net"),
}
m := NewManager(t.Logf, f, health.NewTracker(eventbustest.NewBus(t)), tsdial.NewDialer(netmon.NewStatic()), nil, nil, "darwin")
bus := eventbustest.NewBus(t)
dialer := tsdial.NewDialer(netmon.NewStatic())
dialer.SetBus(bus)
m := NewManager(t.Logf, f, health.NewTracker(bus), dialer, nil, nil, "darwin")
var managerConfig *resolver.Config
m.resolver.TestOnlySetHook(func(cfg resolver.Config) {

View File

@@ -122,7 +122,6 @@ func TestResolversWithDelays(t *testing.T) {
}
})
}
}
func TestGetRCode(t *testing.T) {
@@ -454,6 +453,7 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports
var dialer tsdial.Dialer
dialer.SetNetMon(netMon)
dialer.SetBus(bus)
fwd := newForwarder(logf, netMon, nil, &dialer, health.NewTracker(bus), nil)
if modify != nil {

View File

@@ -353,10 +353,13 @@ func TestRDNSNameToIPv6(t *testing.T) {
}
func newResolver(t testing.TB) *Resolver {
bus := eventbustest.NewBus(t)
dialer := tsdial.NewDialer(netmon.NewStatic())
dialer.SetBus(bus)
return New(t.Logf,
nil, // no link selector
tsdial.NewDialer(netmon.NewStatic()),
health.NewTracker(eventbustest.NewBus(t)),
dialer,
health.NewTracker(bus),
nil, // no control knobs
)
}

View File

@@ -8,6 +8,7 @@ import (
"sync"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
// LinkChangeLogLimiter returns a new [logger.Logf] that logs each unique
@@ -17,13 +18,12 @@ import (
// done.
func LinkChangeLogLimiter(ctx context.Context, logf logger.Logf, nm *Monitor) logger.Logf {
var formatSeen sync.Map // map[string]bool
unregister := nm.RegisterChangeCallback(func(cd *ChangeDelta) {
nm.b.Monitor(nm.changeDeltaWatcher(nm.b, ctx, func(cd ChangeDelta) {
// If we're in a major change or a time jump, clear the seen map.
if cd.Major || cd.TimeJumped {
formatSeen.Clear()
}
})
context.AfterFunc(ctx, unregister)
}))
return func(format string, args ...any) {
// We only store 'true' in the map, so if it's present then it
@@ -42,3 +42,19 @@ func LinkChangeLogLimiter(ctx context.Context, logf logger.Logf, nm *Monitor) lo
logf(format, args...)
}
}
// changeDeltaWatcher subscribes to [ChangeDelta] events on ec and returns a
// function, intended to be passed to the client's Monitor method, that calls
// fn for every event received.
//
// The subscription is created when changeDeltaWatcher is called, not when
// the returned function starts running. The returned function loops until
// ctx is done or the subscriber reports that it is done. When it stops
// because ctx is done it also closes the subscriber, so that the
// subscription does not stay registered on ec with nothing reading from it.
func (nm *Monitor) changeDeltaWatcher(ec *eventbus.Client, ctx context.Context, fn func(ChangeDelta)) func(*eventbus.Client) {
	sub := eventbus.Subscribe[ChangeDelta](ec)
	return func(*eventbus.Client) {
		for {
			select {
			case <-ctx.Done():
				// ec can outlive ctx, so drop the subscription explicitly
				// instead of leaving it attached and unread.
				sub.Close()
				return
			case <-sub.Done():
				// The subscriber is already closed; nothing to clean up.
				return
			case change := <-sub.Events():
				fn(change)
			}
		}
	}
}

View File

@@ -11,6 +11,7 @@ import (
"testing/synctest"
"tailscale.com/util/eventbus"
"tailscale.com/util/eventbus/eventbustest"
)
func TestLinkChangeLogLimiter(t *testing.T) { synctest.Test(t, syncTestLinkChangeLogLimiter) }
@@ -61,21 +62,15 @@ func syncTestLinkChangeLogLimiter(t *testing.T) {
// string cache and allow the next log to write to our log buffer.
//
// InjectEvent doesn't work because it's not a major event, so we
// instead reach into the netmon and grab the callback, and then call
// it ourselves.
mon.mu.Lock()
var cb func(*ChangeDelta)
for _, c := range mon.cbs {
cb = c
break
}
mon.mu.Unlock()
cb(&ChangeDelta{Major: true})
// instead inject the event ourselves.
injector := eventbustest.NewInjector(t, bus)
eventbustest.Inject(injector, ChangeDelta{Major: true})
synctest.Wait()
logf("hello %s", "world")
if got := logBuffer.String(); got != "hello world\nother message\nhello world\n" {
t.Errorf("unexpected log buffer contents: %q", got)
want := "hello world\nother message\nhello world\n"
if got := logBuffer.String(); got != want {
t.Errorf("unexpected log buffer contents, got: %q, want, %q", got, want)
}
// Canceling the context we passed to LinkChangeLogLimiter should

View File

@@ -28,6 +28,7 @@ import (
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
"tailscale.com/util/mak"
"tailscale.com/util/testenv"
"tailscale.com/version"
@@ -86,6 +87,8 @@ type Dialer struct {
dnsCache *dnscache.MessageCache // nil until first non-empty SetExitDNSDoH
nextSysConnID int
activeSysConns map[int]net.Conn // active connections not yet closed
eventClient *eventbus.Client
eventBusSubs eventbus.Monitor
}
// sysConn wraps a net.Conn that was created using d.SystemDial.
@@ -158,6 +161,9 @@ func (d *Dialer) SetRoutes(routes, localRoutes []netip.Prefix) {
}
func (d *Dialer) Close() error {
if d.eventClient != nil {
d.eventBusSubs.Close()
}
d.mu.Lock()
defer d.mu.Unlock()
d.closed = true
@@ -186,6 +192,14 @@ func (d *Dialer) SetNetMon(netMon *netmon.Monitor) {
d.netMonUnregister = nil
}
d.netMon = netMon
// Having multiple watchers could lead to problems,
// so remove the eventClient if it exists.
// This should really not happen, but better checking for it than not.
// TODO(cmol): Should this just be a panic?
if d.eventClient != nil {
d.eventBusSubs.Close()
d.eventClient = nil
}
d.netMonUnregister = d.netMon.RegisterChangeCallback(d.linkChanged)
}
@@ -197,6 +211,35 @@ func (d *Dialer) NetMon() *netmon.Monitor {
return d.netMon
}
// SetBus attaches d to bus: it creates an event bus client named
// "tsdial.Dialer" and starts a watcher that passes each
// [netmon.ChangeDelta] event to d.linkChanged.
//
// If a netmon change callback is currently registered, it is unregistered
// first so that link changes are not handled by two watchers at once.
//
// SetBus panics if d already has an event bus client set.
func (d *Dialer) SetBus(bus *eventbus.Bus) {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.eventClient != nil {
		panic("eventbus has already been set")
	}
	// Having multiple watchers could lead to problems,
	// so unregister the callback if it exists.
	if d.netMonUnregister != nil {
		d.netMonUnregister()
		// Clear the field, as SetNetMon does, so the same unregister
		// function is never invoked a second time later on.
		d.netMonUnregister = nil
	}
	d.eventClient = bus.Client("tsdial.Dialer")
	d.eventBusSubs = d.eventClient.Monitor(d.linkChangeWatcher(d.eventClient))
}
// linkChangeWatcher subscribes to [netmon.ChangeDelta] events on ec and
// returns the loop that hands each received delta to d.linkChanged. The
// returned function exits once the client passed to it is done.
func (d *Dialer) linkChangeWatcher(ec *eventbus.Client) func(*eventbus.Client) {
	deltas := eventbus.Subscribe[netmon.ChangeDelta](ec)
	forward := func(client *eventbus.Client) {
		for {
			select {
			case delta := <-deltas.Events():
				d.linkChanged(&delta)
			case <-client.Done():
				return
			}
		}
	}
	return forward
}
var (
metricLinkChangeConnClosed = clientmetric.NewCounter("tsdial_linkchange_closes")
metricChangeDeltaNoDefaultRoute = clientmetric.NewCounter("tsdial_changedelta_no_default_route")