diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 8e734bec9..f73689245 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -20,7 +20,6 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/net/udprelay" "tailscale.com/tailcfg" - "tailscale.com/tsd" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/ptr" @@ -40,7 +39,7 @@ func init() { // newExtension is an [ipnext.NewExtensionFn] that creates a new relay server // extension. It is registered with [ipnext.RegisterExtension] if the package is // imported. -func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) { +func newExtension(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil } diff --git a/feature/taildrop/ext.go b/feature/taildrop/ext.go index 5d22cfb9b..b7cfdec72 100644 --- a/feature/taildrop/ext.go +++ b/feature/taildrop/ext.go @@ -7,7 +7,6 @@ import ( "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnlocal" "tailscale.com/taildrop" - "tailscale.com/tsd" "tailscale.com/types/logger" ) @@ -15,7 +14,7 @@ func init() { ipnext.RegisterExtension("taildrop", newExtension) } -func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) { +func newExtension(logf logger.Logf, b ipnext.SafeBackend) (ipnext.Extension, error) { return &extension{ logf: logger.WithPrefix(logf, "taildrop: "), }, nil @@ -23,7 +22,7 @@ func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) { type extension struct { logf logger.Logf - lb *ipnlocal.LocalBackend + sb ipnext.SafeBackend mgr *taildrop.Manager } @@ -32,11 +31,6 @@ func (e *extension) Name() string { } func (e *extension) Init(h ipnext.Host) error { - type I interface { - Backend() ipnlocal.Backend - } - e.lb = h.(I).Backend().(*ipnlocal.LocalBackend) - // TODO(bradfitz): move init of taildrop.Manager from ipnlocal/peerapi.go to // here e.mgr = nil @@ -45,7 +39,11 @@ func (e *extension) Init(h ipnext.Host) error { } func (e *extension) Shutdown() error { - if mgr, err := e.lb.TaildropManager(); err == nil { + lb, ok := e.sb.(*ipnlocal.LocalBackend) + if !ok { + return nil + } + if mgr, err := lb.TaildropManager(); err == nil { mgr.Shutdown() } else { e.logf("taildrop: failed to shutdown taildrop manager: %v", err) diff --git a/ipn/auditlog/extension.go b/ipn/auditlog/extension.go index 90014b72e..509ab61a8 100644 --- a/ipn/auditlog/extension.go +++ b/ipn/auditlog/extension.go @@ -16,7 +16,6 @@ import ( "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" - "tailscale.com/tsd" "tailscale.com/types/lazy" "tailscale.com/types/logger" ) @@ -52,7 +51,7 @@ type extension struct { // newExtension is an [ipnext.NewExtensionFn] that creates a new audit log extension. // It is registered with [ipnext.RegisterExtension] if the package is imported. -func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) { +func newExtension(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil } diff --git a/ipn/desktop/extension.go b/ipn/desktop/extension.go index 057b4cfe6..6c59b1e5a 100644 --- a/ipn/desktop/extension.go +++ b/ipn/desktop/extension.go @@ -17,7 +17,6 @@ import ( "tailscale.com/feature" "tailscale.com/ipn" "tailscale.com/ipn/ipnext" - "tailscale.com/tsd" "tailscale.com/types/logger" "tailscale.com/util/syspolicy" ) @@ -53,7 +52,7 @@ type desktopSessionsExt struct { // newDesktopSessionsExt returns a new [desktopSessionsExt], // or an error if a [SessionManager] cannot be created. // It is registered with [ipnext.RegisterExtension] if the package is imported. -func newDesktopSessionsExt(logf logger.Logf, sys *tsd.System) (ipnext.Extension, error) { +func newDesktopSessionsExt(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { logf = logger.WithPrefix(logf, featureName+": ") sm, err := NewSessionManager(logf) if err != nil { diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go index 5c35192e4..b926ee23a 100644 --- a/ipn/ipnext/ipnext.go +++ b/ipn/ipnext/ipnext.go @@ -13,6 +13,7 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/ipnauth" "tailscale.com/tsd" + "tailscale.com/tstime" "tailscale.com/types/logger" "tailscale.com/types/views" "tailscale.com/util/mak" @@ -52,7 +53,7 @@ type Extension interface { // If the extension should be skipped at runtime, it must return either [SkipExtension] // or a wrapped [SkipExtension]. Any other error returned is fatal and will prevent // the LocalBackend from starting. -type NewExtensionFn func(logger.Logf, *tsd.System) (Extension, error) +type NewExtensionFn func(logger.Logf, SafeBackend) (Extension, error) // SkipExtension is an error returned by [NewExtensionFn] to indicate that the extension // should be skipped rather than prevent the LocalBackend from starting. @@ -78,8 +79,8 @@ func (d *Definition) Name() string { } // MakeExtension instantiates the extension. -func (d *Definition) MakeExtension(logf logger.Logf, sys *tsd.System) (Extension, error) { - ext, err := d.newFn(logf, sys) +func (d *Definition) MakeExtension(logf logger.Logf, sb SafeBackend) (Extension, error) { + ext, err := d.newFn(logf, sb) if err != nil { return nil, err } @@ -130,7 +131,7 @@ func Extensions() views.Slice[*Definition] { func DefinitionForTest(ext Extension) *Definition { return &Definition{ name: ext.Name(), - newFn: func(logger.Logf, *tsd.System) (Extension, error) { return ext, nil }, + newFn: func(logger.Logf, SafeBackend) (Extension, error) { return ext, nil }, } } @@ -140,7 +141,7 @@ func DefinitionForTest(ext Extension) *Definition { func DefinitionWithErrForTest(name string, err error) *Definition { return &Definition{ name: name, - newFn: func(logger.Logf, *tsd.System) (Extension, error) { return nil, err }, + newFn: func(logger.Logf, SafeBackend) (Extension, error) { return nil, err }, } } @@ -203,6 +204,19 @@ type Host interface { // It is a runtime error to register a nil provider or call after the host // has been initialized. RegisterControlClientCallback(NewControlClientCallback) + + // SendNotifyAsync sends a notification to the IPN bus, + // typically to the GUI client. + SendNotifyAsync(ipn.Notify) +} + +// SafeBackend is a subset of the [ipnlocal.LocalBackend] type's methods that +// are safe to call from extension hooks at any time (even hooks called while +// LocalBackend's internal mutex is held). +type SafeBackend interface { + Sys() *tsd.System + Clock() tstime.Clock + TailscaleVarRoot() string } // ExtensionServices provides access to the [Host]'s extension management services, diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go index a7a764ebc..85da27ab0 100644 --- a/ipn/ipnlocal/extension_host.go +++ b/ipn/ipnlocal/extension_host.go @@ -20,7 +20,6 @@ import ( "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" - "tailscale.com/tsd" "tailscale.com/types/logger" "tailscale.com/util/execqueue" "tailscale.com/util/testenv" @@ -131,15 +130,32 @@ type Backend interface { // SwitchToBestProfile switches to the best profile for the current state of the system. // The reason indicates why the profile is being switched. SwitchToBestProfile(reason string) + + SendNotify(ipn.Notify) + ipnext.SafeBackend } // NewExtensionHost returns a new [ExtensionHost] which manages registered extensions for the given backend. // The extensions are instantiated, but are not initialized until [ExtensionHost.Init] is called. // It returns an error if instantiating any extension fails. +func NewExtensionHost(logf logger.Logf, b Backend) (*ExtensionHost, error) { + return newExtensionHost(logf, b) +} + +func NewExtensionHostForTest(logf logger.Logf, b Backend, overrideExts ...*ipnext.Definition) (*ExtensionHost, error) { + if !testenv.InTest() { + panic("use outside of test") + } + return newExtensionHost(logf, b, overrideExts...) +} + +// newExtensionHost is the shared implementation of [NewExtensionHost] and +// [NewExtensionHostForTest]. // -// If overrideExts is non-nil, the registered extensions are ignored and the provided extensions are used instead. -// Overriding extensions is primarily used for testing. -func NewExtensionHost(logf logger.Logf, sys *tsd.System, b Backend, overrideExts ...*ipnext.Definition) (_ *ExtensionHost, err error) { +// If overrideExts is non-nil, the registered extensions are ignored and the +// provided extensions are used instead. Overriding extensions is primarily used +// for testing. +func newExtensionHost(logf logger.Logf, b Backend, overrideExts ...*ipnext.Definition) (_ *ExtensionHost, err error) { host := &ExtensionHost{ b: b, logf: logger.WithPrefix(logf, "ipnext: "), @@ -172,7 +188,7 @@ func NewExtensionHost(logf logger.Logf, sys *tsd.System, b Backend, overrideExts host.allExtensions = make([]ipnext.Extension, 0, numExts) for _, d := range exts { - ext, err := d.MakeExtension(logf, sys) + ext, err := d.MakeExtension(logf, b) if errors.Is(err, ipnext.SkipExtension) { // The extension wants to be skipped. host.logf("%q: %v", d.Name(), err) @@ -334,12 +350,14 @@ func (h *ExtensionHost) SwitchToBestProfileAsync(reason string) { }) } -// Backend returns the [Backend] used by the extension host. -func (h *ExtensionHost) Backend() Backend { +// SendNotifyAsync implements [ipnext.Host]. +func (h *ExtensionHost) SendNotifyAsync(n ipn.Notify) { if h == nil { - return nil + return } - return h.b + h.enqueueBackendOperation(func(b Backend) { + b.SendNotify(n) + }) } // addFuncHook appends non-nil fn to hooks. diff --git a/ipn/ipnlocal/extension_host_test.go b/ipn/ipnlocal/extension_host_test.go index 01122073a..31b38196a 100644 --- a/ipn/ipnlocal/extension_host_test.go +++ b/ipn/ipnlocal/extension_host_test.go @@ -27,7 +27,9 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/tstest" + "tailscale.com/tstime" "tailscale.com/types/key" + "tailscale.com/types/lazy" "tailscale.com/types/persist" "tailscale.com/util/must" ) @@ -284,7 +286,7 @@ func TestNewExtensionHost(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() logf := tstest.WhileTestRunningLogger(t) - h, err := NewExtensionHost(logf, tsd.NewSystem(), &testBackend{}, tt.defs...) + h, err := NewExtensionHostForTest(logf, &testBackend{}, tt.defs...) if gotErr := err != nil; gotErr != tt.wantErr { t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr) } @@ -1095,7 +1097,7 @@ func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initia } defs[i] = ipnext.DefinitionForTest(ext) } - h, err := NewExtensionHost(logf, tsd.NewSystem(), b, defs...) + h, err := NewExtensionHostForTest(logf, b, defs...) if err != nil { t.Fatalf("NewExtensionHost: %v", err) } @@ -1320,6 +1322,7 @@ func (q *testExecQueue) Wait(context.Context) error { return nil } // testBackend implements [ipnext.Backend] for testing purposes // by calling the provided hooks when its methods are called. type testBackend struct { + lazySys lazy.SyncValue[*tsd.System] switchToBestProfileHook func(reason string) // mu protects the backend state. @@ -1328,6 +1331,13 @@ type testBackend struct { mu sync.Mutex } +func (b *testBackend) Clock() tstime.Clock { return tstime.StdClock{} } +func (b *testBackend) Sys() *tsd.System { + return b.lazySys.Get(tsd.NewSystem) +} +func (b *testBackend) SendNotify(ipn.Notify) { panic("not implemented") } +func (b *testBackend) TailscaleVarRoot() string { panic("not implemented") } + func (b *testBackend) SwitchToBestProfile(reason string) { b.mu.Lock() defer b.mu.Unlock() diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index ef5ec267f..d60f05b11 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -525,7 +525,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } - if b.extHost, err = NewExtensionHost(logf, sys, b); err != nil { + if b.extHost, err = NewExtensionHost(logf, b); err != nil { return nil, fmt.Errorf("failed to create extension host: %w", err) } b.pm.SetExtensionHost(b.extHost) @@ -589,6 +589,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } func (b *LocalBackend) Clock() tstime.Clock { return b.clock } +func (b *LocalBackend) Sys() *tsd.System { return b.sys } // FindExtensionByName returns an active extension with the given name, // or nil if no such extension exists. @@ -3187,6 +3188,12 @@ func (b *LocalBackend) send(n ipn.Notify) { b.sendTo(n, allClients) } +// SendNotify sends a notification to the IPN bus, +// typically to the GUI client. +func (b *LocalBackend) SendNotify(n ipn.Notify) { + b.send(n) +} + // notificationTarget describes a notification recipient. // A zero value is valid and indicate that the notification // should be broadcast to all active [watchSession]s.