diff --git a/ipn/auditlog/extension.go b/ipn/auditlog/extension.go index 3b561b2e5..90014b72e 100644 --- a/ipn/auditlog/extension.go +++ b/ipn/auditlog/extension.go @@ -36,8 +36,6 @@ func init() { type extension struct { logf logger.Logf - // cleanup are functions to call on shutdown. - cleanup []func() // store is the log store shared by all loggers. // It is created when the first logger is started. store lazy.SyncValue[LogStore] @@ -66,11 +64,9 @@ func (e *extension) Name() string { // Init implements [ipnext.Extension] by registering callbacks and providers // for the duration of the extension's lifetime. func (e *extension) Init(h ipnext.Host) error { - e.cleanup = []func(){ - h.RegisterControlClientCallback(e.controlClientChanged), - h.Profiles().RegisterProfileStateChangeCallback(e.profileChanged), - h.RegisterAuditLogProvider(e.getCurrentLogger), - } + h.RegisterControlClientCallback(e.controlClientChanged) + h.Profiles().RegisterProfileStateChangeCallback(e.profileChanged) + h.RegisterAuditLogProvider(e.getCurrentLogger) return nil } @@ -190,9 +186,5 @@ func (e *extension) getCurrentLogger() ipnauth.AuditLogFunc { // Shutdown implements [ipnlocal.Extension]. func (e *extension) Shutdown() error { - for _, f := range e.cleanup { - f() - } - e.cleanup = nil return nil } diff --git a/ipn/desktop/extension.go b/ipn/desktop/extension.go index 86ae96f5b..057b4cfe6 100644 --- a/ipn/desktop/extension.go +++ b/ipn/desktop/extension.go @@ -74,13 +74,12 @@ func (e *desktopSessionsExt) Name() string { // Init implements [ipnext.Extension]. func (e *desktopSessionsExt) Init(host ipnext.Host) (err error) { e.host = host - unregisterResolver := host.Profiles().RegisterBackgroundProfileResolver(e.getBackgroundProfile) unregisterSessionCb, err := e.sm.RegisterStateCallback(e.updateDesktopSessionState) if err != nil { - unregisterResolver() return fmt.Errorf("session callback registration failed: %w", err) } - e.cleanup = []func(){unregisterResolver, unregisterSessionCb} + host.Profiles().RegisterBackgroundProfileResolver(e.getBackgroundProfile) + e.cleanup = []func(){unregisterSessionCb} return nil } diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go index 4c7e978e5..5c35192e4 100644 --- a/ipn/ipnext/ipnext.go +++ b/ipn/ipnext/ipnext.go @@ -43,6 +43,7 @@ type Extension interface { // provided the extension was initialized. For multiple extensions, // Shutdown is called in the reverse order of Init. // Returned errors are not fatal; they are used for logging. + // After a call to Shutdown, the extension will not be called again. Shutdown() error } @@ -182,9 +183,11 @@ type Host interface { // RegisterAuditLogProvider registers an audit log provider, // which returns a function to be called when an auditable action - // is about to be performed. The returned function unregisters the provider. - // It is a runtime error to register a nil provider. - RegisterAuditLogProvider(AuditLogProvider) (unregister func()) + // is about to be performed. + // + // It is a runtime error to register a nil provider or call after the host + // has been initialized. + RegisterAuditLogProvider(AuditLogProvider) // AuditLogger returns a function that calls all currently registered audit loggers. // The function fails if any logger returns an error, indicating that the action @@ -195,9 +198,11 @@ type Host interface { AuditLogger() ipnauth.AuditLogFunc // RegisterControlClientCallback registers a function to be called every time a new - // control client is created. The returned function unregisters the callback. - // It is a runtime error to register a nil callback. - RegisterControlClientCallback(NewControlClientCallback) (unregister func()) + // control client is created. + // + // It is a runtime error to register a nil provider or call after the host + // has been initialized. + RegisterControlClientCallback(NewControlClientCallback) } // ExtensionServices provides access to the [Host]'s extension management services, @@ -252,23 +257,26 @@ type ProfileServices interface { SwitchToBestProfileAsync(reason string) // RegisterBackgroundProfileResolver registers a function to be used when - // resolving the background profile. The returned function unregisters the resolver. - // It is a runtime error to register a nil resolver. + // resolving the background profile. + // + // It is a runtime error to register a nil provider or call after the host + // has been initialized. // // TODO(nickkhyl): allow specifying some kind of priority/altitude for the resolver. // TODO(nickkhyl): make it a "profile resolver" instead of a "background profile resolver". // The concepts of the "current user", "foreground profile" and "background profile" // only exist on Windows, and we're moving away from them anyway. - RegisterBackgroundProfileResolver(ProfileResolver) (unregister func()) + RegisterBackgroundProfileResolver(ProfileResolver) // RegisterProfileStateChangeCallback registers a function to be called when the current - // [ipn.LoginProfile] or its [ipn.Prefs] change. The returned function unregisters the callback. + // [ipn.LoginProfile] or its [ipn.Prefs] change. // // To get the initial profile or prefs, use [ProfileServices.CurrentProfileState] // or [ProfileServices.CurrentPrefs] from the extension's [Extension.Init]. // - // It is a runtime error to register a nil callback. - RegisterProfileStateChangeCallback(ProfileStateChangeCallback) (unregister func()) + // It is a runtime error to register a nil provider or call after the host + // has been initialized. + RegisterProfileStateChangeCallback(ProfileStateChangeCallback) } // ProfileStore provides read-only access to available login profiles and their preferences. diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go index aa56ad8ef..a7a764ebc 100644 --- a/ipn/ipnlocal/extension_host.go +++ b/ipn/ipnlocal/extension_host.go @@ -7,7 +7,6 @@ import ( "context" "errors" "fmt" - "iter" "maps" "reflect" "slices" @@ -24,8 +23,6 @@ import ( "tailscale.com/tsd" "tailscale.com/types/logger" "tailscale.com/util/execqueue" - "tailscale.com/util/set" - "tailscale.com/util/slicesx" "tailscale.com/util/testenv" ) @@ -78,6 +75,7 @@ type ExtensionHost struct { // initOnce is used to ensure that the extensions are initialized only once, // even if [extensionHost.Init] is called multiple times. initOnce sync.Once + initDone atomic.Bool // shutdownOnce is like initOnce, but for [ExtensionHost.Shutdown]. shutdownOnce sync.Once @@ -87,6 +85,24 @@ type ExtensionHost struct { // doEnqueueBackendOperation adds an asynchronous [LocalBackend] operation to the workQueue. doEnqueueBackendOperation func(func(Backend)) + // profileStateChangeCbs are callbacks that are invoked when the current login profile + // or its [ipn.Prefs] change, after those changes have been made. The current login profile + // may be changed either because of a profile switch, or because the profile information + // was updated by [LocalBackend.SetControlClientStatus], including when the profile + // is first populated and persisted. + profileStateChangeCbs []ipnext.ProfileStateChangeCallback + // backgroundProfileResolvers are registered background profile resolvers. + // They're used to determine the profile to use when no GUI/CLI client is connected. + backgroundProfileResolvers []ipnext.ProfileResolver + // auditLoggers are registered [AuditLogProvider]s. + // Each provider is called to get an [ipnauth.AuditLogFunc] when an auditable action + // is about to be performed. If an audit logger returns an error, the action is denied. + auditLoggers []ipnext.AuditLogProvider + // newControlClientCbs are the functions to be called when a new control client is created. + newControlClientCbs []ipnext.NewControlClientCallback + + shuttingDown atomic.Bool + // mu protects the following fields. // It must not be held when calling [LocalBackend] methods // or when invoking callbacks registered by extensions. @@ -107,22 +123,6 @@ type ExtensionHost struct { // currentPrefs is a read-only view of the current profile's [ipn.Prefs] // with any private keys stripped. It is always Valid. currentPrefs ipn.PrefsView - - // auditLoggers are registered [AuditLogProvider]s. - // Each provider is called to get an [ipnauth.AuditLogFunc] when an auditable action - // is about to be performed. If an audit logger returns an error, the action is denied. - auditLoggers set.HandleSet[ipnext.AuditLogProvider] - // backgroundProfileResolvers are registered background profile resolvers. - // They're used to determine the profile to use when no GUI/CLI client is connected. - backgroundProfileResolvers set.HandleSet[ipnext.ProfileResolver] - // newControlClientCbs are the functions to be called when a new control client is created. - newControlClientCbs set.HandleSet[ipnext.NewControlClientCallback] - // profileStateChangeCbs are callbacks that are invoked when the current login profile - // or its [ipn.Prefs] change, after those changes have been made. The current login profile - // may be changed either because of a profile switch, or because the profile information - // was updated by [LocalBackend.SetControlClientStatus], including when the profile - // is first populated and persisted. - profileStateChangeCbs set.HandleSet[ipnext.ProfileStateChangeCallback] } // Backend is a subset of [LocalBackend] methods that are used by [ExtensionHost]. @@ -160,13 +160,10 @@ func NewExtensionHost(logf logger.Logf, sys *tsd.System, b Backend, overrideExts host.workQueue.Add(func() { f(b) }) } - var numExts int - var exts iter.Seq2[int, *ipnext.Definition] - if overrideExts == nil { - // Use registered extensions. - exts = ipnext.Extensions().All() - numExts = ipnext.Extensions().Len() - } else { + // Use registered extensions. + exts := ipnext.Extensions().All() + numExts := ipnext.Extensions().Len() + if overrideExts != nil { // Use the provided, potentially empty, overrideExts // instead of the registered ones. exts = slices.All(overrideExts) @@ -196,6 +193,8 @@ func (h *ExtensionHost) Init() { } func (h *ExtensionHost) init() { + defer h.initDone.Store(true) + // Initialize the extensions in the order they were registered. h.mu.Lock() h.activeExtensions = make([]ipnext.Extension, 0, len(h.allExtensions)) @@ -343,21 +342,21 @@ func (h *ExtensionHost) Backend() Backend { return h.b } +// addFuncHook appends non-nil fn to hooks. +func addFuncHook[F any](h *ExtensionHost, hooks *[]F, fn F) { + if h.initDone.Load() { + panic("invalid callback register after init") + } + if reflect.ValueOf(fn).IsZero() { + panic("nil function hook") + } + *hooks = append(*hooks, fn) +} + // RegisterProfileStateChangeCallback implements [ipnext.ProfileServices]. -func (h *ExtensionHost) RegisterProfileStateChangeCallback(cb ipnext.ProfileStateChangeCallback) (unregister func()) { - if h == nil { - return func() {} - } - if cb == nil { - panic("nil profile change callback") - } - h.mu.Lock() - defer h.mu.Unlock() - handle := h.profileStateChangeCbs.Add(cb) - return func() { - h.mu.Lock() - defer h.mu.Unlock() - delete(h.profileStateChangeCbs, handle) +func (h *ExtensionHost) RegisterProfileStateChangeCallback(cb ipnext.ProfileStateChangeCallback) { + if h != nil { + addFuncHook(h, &h.profileStateChangeCbs, cb) } } @@ -366,7 +365,7 @@ func (h *ExtensionHost) RegisterProfileStateChangeCallback(cb ipnext.ProfileStat // It strips private keys from the [ipn.Prefs] before preserving // or passing them to the callbacks. func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { - if h == nil { + if !h.active() { return } h.mu.Lock() @@ -378,10 +377,9 @@ func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs // so we can provide them to the extensions later if they ask. h.currentPrefs = prefs h.currentProfile = profile - // Get the callbacks to be invoked. - cbs := slicesx.MapValues(h.profileStateChangeCbs) h.mu.Unlock() - for _, cb := range cbs { + + for _, cb := range h.profileStateChangeCbs { cb(profile, prefs, sameNode) } } @@ -390,7 +388,7 @@ func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs // and updates the current profile and prefs in the host. // It strips private keys from the [ipn.Prefs] before preserving or using them. func (h *ExtensionHost) NotifyProfilePrefsChanged(profile ipn.LoginProfileView, oldPrefs, newPrefs ipn.PrefsView) { - if h == nil { + if !h.active() { return } h.mu.Lock() @@ -403,28 +401,24 @@ func (h *ExtensionHost) NotifyProfilePrefsChanged(profile ipn.LoginProfileView, h.currentPrefs = newPrefs h.currentProfile = profile // Get the callbacks to be invoked. - stateCbs := slicesx.MapValues(h.profileStateChangeCbs) h.mu.Unlock() - for _, cb := range stateCbs { + + for _, cb := range h.profileStateChangeCbs { cb(profile, newPrefs, true) } } // RegisterBackgroundProfileResolver implements [ipnext.ProfileServices]. -func (h *ExtensionHost) RegisterBackgroundProfileResolver(resolver ipnext.ProfileResolver) (unregister func()) { - if h == nil { - return func() {} - } - h.mu.Lock() - defer h.mu.Unlock() - handle := h.backgroundProfileResolvers.Add(resolver) - return func() { - h.mu.Lock() - defer h.mu.Unlock() - delete(h.backgroundProfileResolvers, handle) +func (h *ExtensionHost) RegisterBackgroundProfileResolver(resolver ipnext.ProfileResolver) { + if h != nil { + addFuncHook(h, &h.backgroundProfileResolvers, resolver) } } +func (h *ExtensionHost) active() bool { + return h != nil && !h.shuttingDown.Load() +} + // DetermineBackgroundProfile returns a read-only view of the profile // used when no GUI/CLI client is connected, using background profile // resolvers registered by extensions. @@ -434,7 +428,7 @@ func (h *ExtensionHost) RegisterBackgroundProfileResolver(resolver ipnext.Profil // // As of 2025-02-07, this is only used on Windows. func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { - if h == nil { + if !h.active() { return ipn.LoginProfileView{} } // TODO(nickkhyl): check if the returned profile is allowed on the device, @@ -443,10 +437,7 @@ func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore) // Attempt to resolve the background profile using the registered // background profile resolvers (e.g., [ipn/desktop.desktopSessionsExt] on Windows). - h.mu.Lock() - resolvers := slicesx.MapValues(h.backgroundProfileResolvers) - h.mu.Unlock() - for _, resolver := range resolvers { + for _, resolver := range h.backgroundProfileResolvers { if profile := resolver(profiles); profile.Valid() { return profile } @@ -458,35 +449,21 @@ func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore) } // RegisterControlClientCallback implements [ipnext.Host]. -func (h *ExtensionHost) RegisterControlClientCallback(cb ipnext.NewControlClientCallback) (unregister func()) { - if h == nil { - return func() {} - } - if cb == nil { - panic("nil control client callback") - } - h.mu.Lock() - defer h.mu.Unlock() - handle := h.newControlClientCbs.Add(cb) - return func() { - h.mu.Lock() - defer h.mu.Unlock() - delete(h.newControlClientCbs, handle) +func (h *ExtensionHost) RegisterControlClientCallback(cb ipnext.NewControlClientCallback) { + if h != nil { + addFuncHook(h, &h.newControlClientCbs, cb) } } // NotifyNewControlClient invokes all registered control client callbacks. // It returns callbacks to be executed when the control client shuts down. func (h *ExtensionHost) NotifyNewControlClient(cc controlclient.Client, profile ipn.LoginProfileView) (ccShutdownCbs []func()) { - if h == nil { + if !h.active() { return nil } - h.mu.Lock() - cbs := slicesx.MapValues(h.newControlClientCbs) - h.mu.Unlock() - if len(cbs) > 0 { - ccShutdownCbs = make([]func(), 0, len(cbs)) - for _, cb := range cbs { + if len(h.newControlClientCbs) > 0 { + ccShutdownCbs = make([]func(), 0, len(h.newControlClientCbs)) + for _, cb := range h.newControlClientCbs { if shutdown := cb(cc, profile); shutdown != nil { ccShutdownCbs = append(ccShutdownCbs, shutdown) } @@ -496,20 +473,9 @@ func (h *ExtensionHost) NotifyNewControlClient(cc controlclient.Client, profile } // RegisterAuditLogProvider implements [ipnext.Host]. -func (h *ExtensionHost) RegisterAuditLogProvider(provider ipnext.AuditLogProvider) (unregister func()) { - if h == nil { - return func() {} - } - if provider == nil { - panic("nil audit log provider") - } - h.mu.Lock() - defer h.mu.Unlock() - handle := h.auditLoggers.Add(provider) - return func() { - h.mu.Lock() - defer h.mu.Unlock() - delete(h.auditLoggers, handle) +func (h *ExtensionHost) RegisterAuditLogProvider(provider ipnext.AuditLogProvider) { + if h != nil { + addFuncHook(h, &h.auditLoggers, provider) } } @@ -523,20 +489,12 @@ func (h *ExtensionHost) RegisterAuditLogProvider(provider ipnext.AuditLogProvide // which typically includes the current profile and the audit loggers registered by extensions. // It must not be persisted outside of the auditable action context. func (h *ExtensionHost) AuditLogger() ipnauth.AuditLogFunc { - if h == nil { + if !h.active() { return func(tailcfg.ClientAuditAction, string) error { return nil } } - - h.mu.Lock() - providers := slicesx.MapValues(h.auditLoggers) - h.mu.Unlock() - - var loggers []ipnauth.AuditLogFunc - if len(providers) > 0 { - loggers = make([]ipnauth.AuditLogFunc, len(providers)) - for i, provider := range providers { - loggers[i] = provider() - } + loggers := make([]ipnauth.AuditLogFunc, 0, len(h.auditLoggers)) + for _, provider := range h.auditLoggers { + loggers = append(loggers, provider()) } return func(action tailcfg.ClientAuditAction, details string) error { // Log auditable actions to the host's log regardless of whether @@ -567,6 +525,7 @@ func (h *ExtensionHost) Shutdown() { } func (h *ExtensionHost) shutdown() { + h.shuttingDown.Store(true) // Prevent any queued but not yet started operations from running, // block new operations from being enqueued, and wait for the // currently executing operation (if any) to finish. diff --git a/ipn/ipnlocal/extension_host_test.go b/ipn/ipnlocal/extension_host_test.go index 4c497dd99..01122073a 100644 --- a/ipn/ipnlocal/extension_host_test.go +++ b/ipn/ipnlocal/extension_host_test.go @@ -576,30 +576,6 @@ func TestExtensionHostProfileStateChangeCallback(t *testing.T) { {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, }, }, - { - // Override the default InitHook used in the test to unregister the callback - // after the first call. - name: "Register/Once", - ext: &testExtension{ - InitHook: func(e *testExtension) error { - var unregister func() - handler := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { - makeStateChangeAppender(e)(profile, prefs, sameNode) - unregister() - } - unregister = e.host.Profiles().RegisterProfileStateChangeCallback(handler) - return nil - }, - }, - stateCalls: []stateChange{ - {Profile: &ipn.LoginProfile{ID: "profile-1"}}, - {Profile: &ipn.LoginProfile{ID: "profile-2"}}, - {Profile: &ipn.LoginProfile{ID: "profile-3"}}, - }, - wantChanges: []stateChange{ // only the first call is received by the callback - {Profile: &ipn.LoginProfile{ID: "profile-1"}}, - }, - }, { // Ensure that ipn.Prefs are passed to the callback. name: "CheckPrefs", @@ -770,7 +746,7 @@ func TestExtensionHostProfileStateChangeCallback(t *testing.T) { tt.ext.InitHook = func(e *testExtension) error { // Create and register the callback on init. handler := makeStateChangeAppender(e) - e.Cleanup(e.host.Profiles().RegisterProfileStateChangeCallback(handler)) + e.host.Profiles().RegisterProfileStateChangeCallback(handler) return nil } } @@ -891,14 +867,15 @@ func TestBackgroundProfileResolver(t *testing.T) { } } - h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true) + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, false) // Register the resolvers with the host. // This is typically done by the extensions themselves, // but we do it here for testing purposes. for _, r := range tt.resolvers { - t.Cleanup(h.Profiles().RegisterBackgroundProfileResolver(r)) + h.Profiles().RegisterBackgroundProfileResolver(r) } + h.Init() // Call the resolver to get the profile. gotProfile := h.DetermineBackgroundProfile(pm) @@ -989,7 +966,7 @@ func TestAuditLogProviders(t *testing.T) { } } ext.InitHook = func(e *testExtension) error { - e.Cleanup(e.host.RegisterAuditLogProvider(provider)) + e.host.RegisterAuditLogProvider(provider) return nil } exts = append(exts, ext) @@ -1168,8 +1145,6 @@ type testExtension struct { // It can be accessed by tests using [setTestExtensionState], // [getTestExtensionStateOk] and [getTestExtensionState]. state map[string]any - // cleanup are functions to be called on shutdown. - cleanup []func() } var _ ipnext.Extension = (*testExtension)(nil) @@ -1212,22 +1187,11 @@ func (e *testExtension) InitCalled() bool { return e.initCnt.Load() != 0 } -func (e *testExtension) Cleanup(f func()) { - e.mu.Lock() - e.cleanup = append(e.cleanup, f) - e.mu.Unlock() -} - // Shutdown implements [ipnext.Extension]. func (e *testExtension) Shutdown() (err error) { e.t.Helper() e.mu.Lock() - cleanup := e.cleanup - e.cleanup = nil e.mu.Unlock() - for _, f := range cleanup { - f() - } if e.ShutdownHook != nil { err = e.ShutdownHook(e) }