diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 7fd4c4b21..416265188 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -815,8 +815,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/internal/noiseconn from tailscale.com/control/controlclient tailscale.com/ipn from tailscale.com/client/local+ tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ - πŸ’£ tailscale.com/ipn/desktop from tailscale.com/ipn/ipnlocal+ πŸ’£ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ tailscale.com/ipn/ipnstate from tailscale.com/client/local+ tailscale.com/ipn/localapi from tailscale.com/tsnet+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 394056295..9cdebbae1 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -273,8 +273,9 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/ipn from tailscale.com/client/local+ W tailscale.com/ipn/auditlog from tailscale.com/cmd/tailscaled tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ - πŸ’£ tailscale.com/ipn/desktop from tailscale.com/cmd/tailscaled+ + W πŸ’£ tailscale.com/ipn/desktop from tailscale.com/cmd/tailscaled πŸ’£ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/auditlog+ tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled tailscale.com/ipn/ipnstate from tailscale.com/client/local+ diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index dfe53ef61..54ff2af14 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -45,7 +45,7 @@ import ( "tailscale.com/drive/driveimpl" "tailscale.com/envknob" _ "tailscale.com/ipn/auditlog" - "tailscale.com/ipn/desktop" + _ "tailscale.com/ipn/desktop" "tailscale.com/logpolicy" "tailscale.com/logtail/backoff" "tailscale.com/net/dns" @@ -337,13 +337,6 @@ func beWindowsSubprocess() bool { sys.Set(driveimpl.NewFileSystemForRemote(log.Printf)) - if sessionManager, err := desktop.NewSessionManager(log.Printf); err == nil { - sys.Set(sessionManager) - } else { - // Errors creating the session manager are unexpected, but not fatal. - log.Printf("[unexpected]: error creating a desktop session manager: %v", err) - } - publicLogID, _ := logid.ParsePublicID(logID) err = startIPNServer(ctx, log.Printf, publicLogID, sys) if err != nil { diff --git a/ipn/auditlog/extension.go b/ipn/auditlog/extension.go index 8be7dfb66..6bbe37398 100644 --- a/ipn/auditlog/extension.go +++ b/ipn/auditlog/extension.go @@ -14,19 +14,23 @@ import ( "tailscale.com/feature" "tailscale.com/ipn" "tailscale.com/ipn/ipnauth" - "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/types/lazy" "tailscale.com/types/logger" ) +// featureName is the name of the feature implemented by this package. +// It is also the the [extension] name and the log prefix. +const featureName = "auditlog" + func init() { - feature.Register("auditlog") - ipnlocal.RegisterExtension("auditlog", newExtension) + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newExtension) } -// extension is an [ipnlocal.Extension] managing audit logging +// extension is an [ipnext.Extension] managing audit logging // on platforms that import this package. // As of 2025-03-27, that's only Windows and macOS. type extension struct { @@ -48,19 +52,24 @@ type extension struct { logger *Logger } -// newExtension is an [ipnlocal.NewExtensionFn] that creates a new audit log extension. -// It is registered with [ipnlocal.RegisterExtension] if the package is imported. -func newExtension(logf logger.Logf, _ *tsd.System) (ipnlocal.Extension, error) { - return &extension{logf: logger.WithPrefix(logf, "auditlog: ")}, nil +// newExtension is an [ipnext.NewExtensionFn] that creates a new audit log extension. +// It is registered with [ipnext.RegisterExtension] if the package is imported. +func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) { + return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil } -// Init implements [ipnlocal.Extension] by registering callbacks and providers +// Name implements [ipnext.Extension]. +func (e *extension) Name() string { + return featureName +} + +// Init implements [ipnext.Extension] by registering callbacks and providers // for the duration of the extension's lifetime. -func (e *extension) Init(lb *ipnlocal.LocalBackend) error { +func (e *extension) Init(h ipnext.Host) error { e.cleanup = []func(){ - lb.RegisterControlClientCallback(e.controlClientChanged), - lb.RegisterProfileChangeCallback(e.profileChanged, false), - lb.RegisterAuditLogProvider(e.getCurrentLogger), + h.RegisterControlClientCallback(e.controlClientChanged), + h.Profiles().RegisterProfileChangeCallback(e.profileChanged), + h.RegisterAuditLogProvider(e.getCurrentLogger), } return nil } @@ -165,8 +174,8 @@ func noCurrentLogger(_ tailcfg.ClientAuditAction, _ string) error { return errNoLogger } -// getCurrentLogger is an [ipnlocal.AuditLogProvider] registered with [ipnlocal.LocalBackend]. -// It is called when [ipnlocal.LocalBackend] needs to audit an action. +// getCurrentLogger is an [ipnext.AuditLogProvider] registered with [ipnext.Host]. +// It is called when [ipnlocal.LocalBackend] or an extension needs to audit an action. // // It returns a function that enqueues the audit log for the current profile, // or [noCurrentLogger] if the logger is unavailable. diff --git a/ipn/ipnlocal/desktop_sessions.go b/ipn/desktop/extension.go similarity index 62% rename from ipn/ipnlocal/desktop_sessions.go rename to ipn/desktop/extension.go index 29cb196c7..86ae96f5b 100644 --- a/ipn/ipnlocal/desktop_sessions.go +++ b/ipn/desktop/extension.go @@ -7,29 +7,32 @@ //go:build windows && !ts_omit_desktop_sessions -package ipnlocal +package desktop import ( "cmp" - "errors" "fmt" "sync" "tailscale.com/feature" "tailscale.com/ipn" - "tailscale.com/ipn/desktop" + "tailscale.com/ipn/ipnext" "tailscale.com/tsd" "tailscale.com/types/logger" "tailscale.com/util/syspolicy" ) +// featureName is the name of the feature implemented by this package. +// It is also the the [desktopSessionsExt] name and the log prefix. +const featureName = "desktop-sessions" + func init() { - feature.Register("desktop-sessions") - RegisterExtension("desktop-sessions", newDesktopSessionsExt) + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newDesktopSessionsExt) } -// desktopSessionsExt implements [Extension]. -var _ Extension = (*desktopSessionsExt)(nil) +// [desktopSessionsExt] implements [ipnext.Extension]. +var _ ipnext.Extension = (*desktopSessionsExt)(nil) // desktopSessionsExt extends [LocalBackend] with desktop session management. // It keeps Tailscale running in the background if Always-On mode is enabled, @@ -37,32 +40,41 @@ var _ Extension = (*desktopSessionsExt)(nil) // locks their screen, or disconnects a remote session. type desktopSessionsExt struct { logf logger.Logf - sm desktop.SessionManager + sm SessionManager - *LocalBackend // or nil, until Init is called - cleanup []func() // cleanup functions to call on shutdown + host ipnext.Host // or nil, until Init is called + cleanup []func() // cleanup functions to call on shutdown // mu protects all following fields. - // When both mu and [LocalBackend.mu] need to be taken, - // [LocalBackend.mu] must be taken before mu. - mu sync.Mutex - id2sess map[desktop.SessionID]*desktop.Session + mu sync.Mutex + sessByID map[SessionID]*Session } // newDesktopSessionsExt returns a new [desktopSessionsExt], -// or an error if [desktop.SessionManager] is not available. -func newDesktopSessionsExt(logf logger.Logf, sys *tsd.System) (Extension, error) { - sm, ok := sys.SessionManager.GetOK() - if !ok { - return nil, errors.New("session manager is not available") +// or an error if a [SessionManager] cannot be created. +// It is registered with [ipnext.RegisterExtension] if the package is imported. +func newDesktopSessionsExt(logf logger.Logf, sys *tsd.System) (ipnext.Extension, error) { + logf = logger.WithPrefix(logf, featureName+": ") + sm, err := NewSessionManager(logf) + if err != nil { + return nil, fmt.Errorf("%w: session manager is not available: %w", ipnext.SkipExtension, err) } - return &desktopSessionsExt{logf: logf, sm: sm, id2sess: make(map[desktop.SessionID]*desktop.Session)}, nil + return &desktopSessionsExt{ + logf: logf, + sm: sm, + sessByID: make(map[SessionID]*Session), + }, nil } -// Init implements [localBackendExtension]. -func (e *desktopSessionsExt) Init(lb *LocalBackend) (err error) { - e.LocalBackend = lb - unregisterResolver := lb.RegisterBackgroundProfileResolver(e.getBackgroundProfile) +// Name implements [ipnext.Extension]. +func (e *desktopSessionsExt) Name() string { + return featureName +} + +// Init implements [ipnext.Extension]. +func (e *desktopSessionsExt) Init(host ipnext.Host) (err error) { + e.host = host + unregisterResolver := host.Profiles().RegisterBackgroundProfileResolver(e.getBackgroundProfile) unregisterSessionCb, err := e.sm.RegisterStateCallback(e.updateDesktopSessionState) if err != nil { unregisterResolver() @@ -72,30 +84,30 @@ func (e *desktopSessionsExt) Init(lb *LocalBackend) (err error) { return nil } -// updateDesktopSessionState is a [desktop.SessionStateCallback] -// invoked by [desktop.SessionManager] once for each existing session +// updateDesktopSessionState is a [SessionStateCallback] +// invoked by [SessionManager] once for each existing session // and whenever the session state changes. It updates the session map // and switches to the best profile if necessary. -func (e *desktopSessionsExt) updateDesktopSessionState(session *desktop.Session) { +func (e *desktopSessionsExt) updateDesktopSessionState(session *Session) { e.mu.Lock() - if session.Status != desktop.ClosedSession { - e.id2sess[session.ID] = session + if session.Status != ClosedSession { + e.sessByID[session.ID] = session } else { - delete(e.id2sess, session.ID) + delete(e.sessByID, session.ID) } e.mu.Unlock() var action string switch session.Status { - case desktop.ForegroundSession: + case ForegroundSession: // The user has either signed in or unlocked their session. // For remote sessions, this may also mean the user has connected. // The distinction isn't important for our purposes, // so let's always say "signed in". action = "signed in to" - case desktop.BackgroundSession: + case BackgroundSession: action = "locked" - case desktop.ClosedSession: + case ClosedSession: action = "signed out from" default: panic("unreachable") @@ -104,10 +116,10 @@ func (e *desktopSessionsExt) updateDesktopSessionState(session *desktop.Session) userIdentifier := cmp.Or(maybeUsername, string(session.User.UserID()), "user") reason := fmt.Sprintf("%s %s session %v", userIdentifier, action, session.ID) - e.SwitchToBestProfile(reason) + e.host.Profiles().SwitchToBestProfileAsync(reason) } -// getBackgroundProfile is a [profileResolver] that works as follows: +// getBackgroundProfile is a [ipnext.ProfileResolver] that works as follows: // // If Always-On mode is disabled, it returns no profile. // @@ -121,9 +133,7 @@ func (e *desktopSessionsExt) updateDesktopSessionState(session *desktop.Session) // disconnects without signing out. // // In all other cases, it returns no profile. -// -// It is called with [LocalBackend.mu] locked. -func (e *desktopSessionsExt) getBackgroundProfile() ipn.LoginProfileView { +func (e *desktopSessionsExt) getBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { e.mu.Lock() defer e.mu.Unlock() @@ -135,16 +145,16 @@ func (e *desktopSessionsExt) getBackgroundProfile() ipn.LoginProfileView { isCurrentProfileOwnerSignedIn := false var foregroundUIDs []ipn.WindowsUserID - for _, s := range e.id2sess { + for _, s := range e.sessByID { switch uid := s.User.UserID(); uid { - case e.pm.CurrentProfile().LocalUserID(): + case profiles.CurrentProfile().LocalUserID(): isCurrentProfileOwnerSignedIn = true - if s.Status == desktop.ForegroundSession { + if s.Status == ForegroundSession { // Keep the current profile if the user has a foreground session. - return e.pm.CurrentProfile() + return profiles.CurrentProfile() } default: - if s.Status == desktop.ForegroundSession { + if s.Status == ForegroundSession { foregroundUIDs = append(foregroundUIDs, uid) } } @@ -154,7 +164,7 @@ func (e *desktopSessionsExt) getBackgroundProfile() ipn.LoginProfileView { // or if the current profile's owner has no foreground session, switch to the default profile // of the first user with a foreground session, if any. for _, uid := range foregroundUIDs { - if profile := e.pm.DefaultUserProfile(uid); profile.ID() != "" { + if profile := profiles.DefaultUserProfile(uid); profile.ID() != "" { return profile } } @@ -163,19 +173,19 @@ func (e *desktopSessionsExt) getBackgroundProfile() ipn.LoginProfileView { // keep the current profile even if the session is not in the foreground, // such as when the screen is locked or a remote session is disconnected. if len(foregroundUIDs) == 0 && isCurrentProfileOwnerSignedIn { - return e.pm.CurrentProfile() + return profiles.CurrentProfile() } // Otherwise, there's no background profile. return ipn.LoginProfileView{} } -// Shutdown implements [localBackendExtension]. +// Shutdown implements [ipnext.Extension]. func (e *desktopSessionsExt) Shutdown() error { for _, f := range e.cleanup { f() } e.cleanup = nil - e.LocalBackend = nil - return nil + e.host = nil + return e.sm.Close() } diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go new file mode 100644 index 000000000..af870b53a --- /dev/null +++ b/ipn/ipnext/ipnext.go @@ -0,0 +1,284 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ipnext defines types and interfaces used for extending the core LocalBackend +// functionality with additional features and services. +package ipnext + +import ( + "errors" + "fmt" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/tsd" + "tailscale.com/types/logger" + "tailscale.com/types/views" + "tailscale.com/util/mak" +) + +// Extension augments LocalBackend with additional functionality. +// +// An extension uses the provided [Host] to register callbacks +// and interact with the backend in a controlled, well-defined +// and thread-safe manner. +// +// Extensions are registered using [RegisterExtension]. +// +// They must be safe for concurrent use. +type Extension interface { + // Name is a unique name of the extension. + // It must be the same as the name used to register the extension. + Name() string + + // Init is called to initialize the extension when LocalBackend is initialized. + // If the extension cannot be initialized, it must return an error, + // and its Shutdown method will not be called on the host's shutdown. + // Returned errors are not fatal; they are used for logging. + // A [SkipExtension] error indicates an intentional decision rather than a failure. + Init(Host) error + + // Shutdown is called when LocalBackend is shutting down, + // provided the extension was initialized. For multiple extensions, + // Shutdown is called in the reverse order of Init. + // Returned errors are not fatal; they are used for logging. + Shutdown() error +} + +// NewExtensionFn is a function that instantiates an [Extension]. +// If a registered extension cannot be instantiated, the function must return an error. +// If the extension should be skipped at runtime, it must return either [SkipExtension] +// or a wrapped [SkipExtension]. Any other error returned is fatal and will prevent +// the LocalBackend from starting. +type NewExtensionFn func(logger.Logf, *tsd.System) (Extension, error) + +// SkipExtension is an error returned by [NewExtensionFn] to indicate that the extension +// should be skipped rather than prevent the LocalBackend from starting. +// +// Skipping an extension should be reserved for cases where the extension is not supported +// on the current platform or configuration, or depends on a feature that is not available, +// or otherwise should be disabled permanently rather than temporarily. +// +// Specifically, it must not be returned if the extension is not required right now +// based on user preferences, policy settings, the current tailnet, or other factors +// that may change throughout the LocalBackend's lifetime. +var SkipExtension = errors.New("skipping extension") + +// Definition describes a registered [Extension]. +type Definition struct { + name string // name under which the extension is registered + newFn NewExtensionFn // function that creates a new instance of the extension +} + +// Name returns the name of the extension. +func (d *Definition) Name() string { + return d.name +} + +// MakeExtension instantiates the extension. +func (d *Definition) MakeExtension(logf logger.Logf, sys *tsd.System) (Extension, error) { + ext, err := d.newFn(logf, sys) + if err != nil { + return nil, err + } + if ext.Name() != d.name { + return nil, fmt.Errorf("extension name mismatch: registered %q; actual %q", d.name, ext.Name()) + } + return ext, nil +} + +// extensionsByName is a map of registered extensions, +// where the key is the name of the extension. +var extensionsByName map[string]*Definition + +// extensionsByOrder is a slice of registered extensions, +// in the order they were registered. +var extensionsByOrder []*Definition + +// RegisterExtension registers a function that instantiates an [Extension]. +// The name must be the same as returned by the extension's [Extension.Name]. +// +// It must be called on the main goroutine before LocalBackend is created, +// such as from an init function of the package implementing the extension. +// +// It panics if newExt is nil or if an extension with the same name +// has already been registered. +func RegisterExtension(name string, newExt NewExtensionFn) { + if newExt == nil { + panic(fmt.Sprintf("ipnext: newExt is nil: %q", name)) + } + if _, ok := extensionsByName[name]; ok { + panic(fmt.Sprintf("ipnext: duplicate extensions: %q", name)) + } + ext := &Definition{name, newExt} + mak.Set(&extensionsByName, name, ext) + extensionsByOrder = append(extensionsByOrder, ext) +} + +// Extensions returns a read-only view of the extensions +// registered via [RegisterExtension]. It preserves the order +// in which the extensions were registered. +func Extensions() views.Slice[*Definition] { + return views.SliceOf(extensionsByOrder) +} + +// DefinitionForTest returns a [Definition] for the specified [Extension]. +// It is primarily used for testing where the test code needs to instantiate +// and use an extension without registering it. +func DefinitionForTest(ext Extension) *Definition { + return &Definition{ + name: ext.Name(), + newFn: func(logger.Logf, *tsd.System) (Extension, error) { return ext, nil }, + } +} + +// DefinitionWithErrForTest returns a [Definition] with the specified extension name +// whose [Definition.MakeExtension] method returns the specified error. +// It is used for testing. +func DefinitionWithErrForTest(name string, err error) *Definition { + return &Definition{ + name: name, + newFn: func(logger.Logf, *tsd.System) (Extension, error) { return nil, err }, + } +} + +// Host is the API surface used by [Extension]s to interact with LocalBackend +// in a controlled manner. +// +// Extensions can register callbacks, request information, or perform actions +// via the [Host] interface. +// +// Typically, the host invokes registered callbacks when one of the following occurs: +// - LocalBackend notifies it of an event or state change that may be +// of interest to extensions, such as when switching [ipn.LoginProfile]. +// - LocalBackend needs to consult extensions for information, for example, +// determining the most appropriate profile for the current state of the system. +// - LocalBackend performs an extensible action, such as logging an auditable event, +// and delegates its execution to the extension. +// +// The callbacks are invoked synchronously, and the LocalBackend's state +// remains unchanged while callbacks execute. +// +// In contrast, actions initiated by extensions are generally asynchronous, +// as indicated by the "Async" suffix in their names. +// Performing actions may result in callbacks being invoked as described above. +// +// To prevent conflicts between extensions competing for shared state, +// such as the current profile or prefs, the host must not expose methods +// that directly modify that state. For example, instead of allowing extensions +// to switch profiles at-will, the host's [ProfileServices] provides a method +// to switch to the "best" profile. The host can then consult extensions +// to determine the appropriate profile to use and resolve any conflicts +// in a controlled manner. +// +// A host must be safe for concurrent use. +type Host interface { + // Profiles returns the host's [ProfileServices]. + Profiles() ProfileServices + + // RegisterAuditLogProvider registers an audit log provider, + // which returns a function to be called when an auditable action + // is about to be performed. The returned function unregisters the provider. + // It is a runtime error to register a nil provider. + RegisterAuditLogProvider(AuditLogProvider) (unregister func()) + + // AuditLogger returns a function that calls all currently registered audit loggers. + // The function fails if any logger returns an error, indicating that the action + // cannot be logged and must not be performed. + // + // The returned function captures the current state (e.g., the current profile) at + // the time of the call and must not be persisted. + AuditLogger() ipnauth.AuditLogFunc + + // RegisterControlClientCallback registers a function to be called every time a new + // control client is created. The returned function unregisters the callback. + // It is a runtime error to register a nil callback. + RegisterControlClientCallback(NewControlClientCallback) (unregister func()) +} + +// ProfileServices provides access to the [Host]'s profile management services, +// such as switching profiles and registering profile change callbacks. +type ProfileServices interface { + // SwitchToBestProfileAsync asynchronously selects the best profile to use + // and switches to it, unless it is already the current profile. + // + // If an extension needs to know when a profile switch occurs, + // it must use [ProfileServices.RegisterProfileChangeCallback] + // to register a [ProfileChangeCallback]. + // + // The reason indicates why the profile is being switched, such as due + // to a client connecting or disconnecting or a change in the desktop + // session state. It is used for logging. + SwitchToBestProfileAsync(reason string) + + // RegisterBackgroundProfileResolver registers a function to be used when + // resolving the background profile. The returned function unregisters the resolver. + // It is a runtime error to register a nil resolver. + // + // TODO(nickkhyl): allow specifying some kind of priority/altitude for the resolver. + // TODO(nickkhyl): make it a "profile resolver" instead of a "background profile resolver". + // The concepts of the "current user", "foreground profile" and "background profile" + // only exist on Windows, and we're moving away from them anyway. + RegisterBackgroundProfileResolver(ProfileResolver) (unregister func()) + + // RegisterProfileChangeCallback registers a function to be called when the current + // [ipn.LoginProfile] changes. The returned function unregisters the callback. + // It is a runtime error to register a nil callback. + RegisterProfileChangeCallback(ProfileChangeCallback) (unregister func()) +} + +// ProfileStore provides read-only access to available login profiles and their preferences. +// It is not safe for concurrent use and can only be used from the callback it is passed to. +type ProfileStore interface { + // CurrentUserID returns the current user ID. It is only non-empty on + // Windows where we have a multi-user system. + // + // Deprecated: this method exists for compatibility with the current (as of 2024-08-27) + // permission model and will be removed as we progress on tailscale/corp#18342. + CurrentUserID() ipn.WindowsUserID + + // CurrentProfile returns a read-only [ipn.LoginProfileView] of the current profile. + // The returned view is always valid, but the profile's [ipn.LoginProfileView.ID] + // returns "" if the profile is new and has not been persisted yet. + CurrentProfile() ipn.LoginProfileView + + // CurrentPrefs returns a read-only view of the current prefs. + // The returned view is always valid. + CurrentPrefs() ipn.PrefsView + + // DefaultUserProfile returns a read-only view of the default (last used) profile for the specified user. + // It returns a read-only view of a new, non-persisted profile if the specified user does not have a default profile. + DefaultUserProfile(uid ipn.WindowsUserID) ipn.LoginProfileView +} + +// AuditLogProvider is a function that returns an [ipnauth.AuditLogFunc] for +// logging auditable actions. +type AuditLogProvider func() ipnauth.AuditLogFunc + +// ProfileResolver is a function that returns a read-only view of a login profile. +// An invalid view indicates no profile. A valid profile view with an empty [ipn.ProfileID] +// indicates that the profile is new and has not been persisted yet. +// The provided [ProfileStore] can only be used for the duration of the callback. +type ProfileResolver func(ProfileStore) ipn.LoginProfileView + +// ProfileChangeCallback is a function to be called when the current login profile changes. +// The sameNode parameter indicates whether the profile represents the same node as before, +// such as when only the profile metadata is updated but the node ID remains the same, +// or when a new profile is persisted and assigned an [ipn.ProfileID] for the first time. +// The subscribers can use this information to decide whether to reset their state. +// +// The profile and prefs are always valid, but the profile's [ipn.LoginProfileView.ID] +// returns "" if the profile is new and has not been persisted yet. +type ProfileChangeCallback func(_ ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) + +// NewControlClientCallback is a function to be called when a new [controlclient.Client] +// is created and before it is first used. The login profile and prefs represent +// the profile for which the cc is created and are always valid; however, the +// profile's [ipn.LoginProfileView.ID] returns "" if the profile is new +// and has not been persisted yet. If the [controlclient.Client] is created +// due to a profile switch, any registered [ProfileChangeCallback]s are called first. +// +// It returns a function to be called when the cc is being shut down, +// or nil if no cleanup is needed. +type NewControlClientCallback func(controlclient.Client, ipn.LoginProfileView, ipn.PrefsView) (cleanup func()) diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go new file mode 100644 index 000000000..4a617ed72 --- /dev/null +++ b/ipn/ipnlocal/extension_host.go @@ -0,0 +1,537 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "errors" + "fmt" + "iter" + "maps" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/types/logger" + "tailscale.com/util/execqueue" + "tailscale.com/util/set" + "tailscale.com/util/testenv" +) + +// ExtensionHost is a bridge between the [LocalBackend] and the registered [ipnext.Extension]s. +// It implements [ipnext.Host] and is safe for concurrent use. +// +// A nil pointer to [ExtensionHost] is a valid, no-op extension host which is primarily used in tests +// that instantiate [LocalBackend] directly without using [NewExtensionHost]. +// +// The [LocalBackend] is not required to hold its mutex when calling the host's methods, +// but it typically does so either to prevent changes to its state (for example, the current profile) +// while callbacks are executing, or because it calls the host's methods as part of a larger operation +// that requires the mutex to be held. +// +// Extensions might invoke the host's methods either from callbacks triggered by the [LocalBackend], +// or in a response to external events. Some methods can be called by both the extensions and the backend. +// +// As a general rule, the host cannot assume anything about the current state of the [LocalBackend]'s +// internal mutex on entry to its methods, and therefore cannot safely call [LocalBackend] methods directly. +// +// The following are typical and supported patterns: +// - LocalBackend notifies the host about an event, such as a change in the current profile. +// The host invokes callbacks registered by Extensions, forwarding the event arguments to them. +// If necessary, the host can also update its own state for future use. +// - LocalBackend requests information from the host, such as the effective [ipnauth.AuditLogFunc] +// or the [ipn.LoginProfile] to use when no GUI/CLI client is connected. Typically, [LocalBackend] +// provides the required context to the host, and the host returns the result to [LocalBackend] +// after forwarding the request to the extensions. +// - Extension invokes the host's method to perform an action, such as switching to the "best" profile +// in response to a change in the device's state. Since the host does not know whether the [LocalBackend]'s +// internal mutex is held, it cannot invoke any methods on the [LocalBackend] directly and must instead +// do so asynchronously, such as by using [ExtensionHost.enqueueBackendOperation]. +// - Extension requests information from the host, such as the effective [ipnauth.AuditLogFunc] +// or the current [ipn.LoginProfile]. Since the host cannot invoke any methods on the [LocalBackend] directly, +// it should maintain its own view of the current state, updating it when the [LocalBackend] notifies it +// about a change or event. +// +// To safeguard against adopting incorrect or risky patterns, the host does not store [LocalBackend] in its fields +// and instead provides [ExtensionHost.enqueueBackendOperation]. Additionally, to make it easier to test extensions +// and to further reduce the risk of accessing unexported methods or fields of [LocalBackend], the host interacts +// with it via the [Backend] interface. +type ExtensionHost struct { + logf logger.Logf // prefixed with "ipnext:" + + // allExtensions holds the extensions in the order they were registered, + // including those that have not yet attempted initialization or have failed to initialize. + allExtensions []ipnext.Extension + + // initOnce is used to ensure that the extensions are initialized only once, + // even if [extensionHost.Init] is called multiple times. + initOnce sync.Once + // shutdownOnce is like initOnce, but for [ExtensionHost.Shutdown]. + shutdownOnce sync.Once + + // workQueue maintains execution order for asynchronous operations requested by extensions. + // It is always an [execqueue.ExecQueue] except in some tests. + workQueue execQueue + // doEnqueueBackendOperation adds an asynchronous [LocalBackend] operation to the workQueue. + doEnqueueBackendOperation func(func(Backend)) + + // mu protects the following fields. + // It must not be held when calling [LocalBackend] methods + // or when invoking callbacks registered by extensions. + mu sync.Mutex + // initialized is whether the host and extensions have been fully initialized. + initialized atomic.Bool + // activeExtensions is a subset of allExtensions that have been initialized and are ready to use. + activeExtensions []ipnext.Extension + // extensionsByName are the activeExtensions indexed by their names. + extensionsByName map[string]ipnext.Extension + // postInitWorkQueue is a queue of functions to be executed + // by the workQueue after all extensions have been initialized. + postInitWorkQueue []func(Backend) + + // auditLoggers are registered [AuditLogProvider]s. + // Each provider is called to get an [ipnauth.AuditLogFunc] when an auditable action + // is about to be performed. If an audit logger returns an error, the action is denied. + auditLoggers set.HandleSet[ipnext.AuditLogProvider] + // backgroundProfileResolvers are registered background profile resolvers. + // They're used to determine the profile to use when no GUI/CLI client is connected. + backgroundProfileResolvers set.HandleSet[ipnext.ProfileResolver] + // newControlClientCbs are the functions to be called when a new control client is created. + newControlClientCbs set.HandleSet[ipnext.NewControlClientCallback] + // profileChangeCbs are the callbacks to be invoked when the current login profile changes, + // either because of a profile switch, or because the profile information was updated + // by [LocalBackend.SetControlClientStatus], including when the profile is first populated + // and persisted. + profileChangeCbs set.HandleSet[ipnext.ProfileChangeCallback] +} + +// Backend is a subset of [LocalBackend] methods that are used by [ExtensionHost]. +// It is primarily used for testing. +type Backend interface { + // SwitchToBestProfile switches to the best profile for the current state of the system. + // The reason indicates why the profile is being switched. + SwitchToBestProfile(reason string) +} + +// NewExtensionHost returns a new [ExtensionHost] which manages registered extensions for the given backend. +// The extensions are instantiated, but are not initialized until [ExtensionHost.Init] is called. +// It returns an error if instantiating any extension fails. +// +// If overrideExts is non-nil, the registered extensions are ignored and the provided extensions are used instead. +// Overriding extensions is primarily used for testing. +func NewExtensionHost(logf logger.Logf, sys *tsd.System, b Backend, overrideExts ...*ipnext.Definition) (_ *ExtensionHost, err error) { + host := &ExtensionHost{ + logf: logger.WithPrefix(logf, "ipnext: "), + workQueue: &execqueue.ExecQueue{}, + } + + // All operations on the backend must be executed asynchronously by the work queue. + // DO NOT retain a direct reference to the backend in the host. + // See the docstring for [ExtensionHost] for more details. + host.doEnqueueBackendOperation = func(f func(Backend)) { + if f == nil { + panic("nil backend operation") + } + host.workQueue.Add(func() { f(b) }) + } + + var numExts int + var exts iter.Seq2[int, *ipnext.Definition] + if overrideExts == nil { + // Use registered extensions. + exts = ipnext.Extensions().All() + numExts = ipnext.Extensions().Len() + } else { + // Use the provided, potentially empty, overrideExts + // instead of the registered ones. + exts = slices.All(overrideExts) + numExts = len(overrideExts) + } + + host.allExtensions = make([]ipnext.Extension, 0, numExts) + for _, d := range exts { + ext, err := d.MakeExtension(logf, sys) + if errors.Is(err, ipnext.SkipExtension) { + // The extension wants to be skipped. + host.logf("%q: %v", d.Name(), err) + continue + } else if err != nil { + return nil, fmt.Errorf("failed to create %q extension: %v", d.Name(), err) + } + host.allExtensions = append(host.allExtensions, ext) + } + return host, nil +} + +// Init initializes the host and the extensions it manages. +func (h *ExtensionHost) Init() { + if h != nil { + h.initOnce.Do(h.init) + } +} + +func (h *ExtensionHost) init() { + // Initialize the extensions in the order they were registered. + h.mu.Lock() + h.activeExtensions = make([]ipnext.Extension, 0, len(h.allExtensions)) + h.extensionsByName = make(map[string]ipnext.Extension, len(h.allExtensions)) + h.mu.Unlock() + for _, ext := range h.allExtensions { + // Do not hold the lock while calling [ipnext.Extension.Init]. + // Extensions call back into the host to register their callbacks, + // and that would cause a deadlock if the h.mu is already held. + if err := ext.Init(h); err != nil { + // As per the [ipnext.Extension] interface, failures to initialize + // an extension are never fatal. The extension is simply skipped. + // + // But we handle [ipnext.SkipExtension] differently for nicer logging + // if the extension wants to be skipped and not actually failing. + if errors.Is(err, ipnext.SkipExtension) { + h.logf("%q: %v", ext.Name(), err) + } else { + h.logf("%q init failed: %v", ext.Name(), err) + } + continue + } + // Update the initialized extensions lists as soon as the extension is initialized. + // We'd like to make them visible to other extensions that are initialized later. + h.mu.Lock() + h.activeExtensions = append(h.activeExtensions, ext) + h.extensionsByName[ext.Name()] = ext + h.mu.Unlock() + } + + // Report active extensions to the log. + // TODO(nickkhyl): update client metrics to include the active/failed/skipped extensions. + h.mu.Lock() + extensionNames := slices.Collect(maps.Keys(h.extensionsByName)) + h.mu.Unlock() + h.logf("active extensions: %v", strings.Join(extensionNames, ", ")) + + // Additional init steps that need to be performed after all extensions have been initialized. + h.mu.Lock() + wq := h.postInitWorkQueue + h.postInitWorkQueue = nil + h.initialized.Store(true) + h.mu.Unlock() + + // Enqueue work that was requested and deferred during initialization. + h.doEnqueueBackendOperation(func(b Backend) { + for _, f := range wq { + f(b) + } + }) + +} + +// Profiles implements [ipnext.Host]. +func (h *ExtensionHost) Profiles() ipnext.ProfileServices { + // Currently, [ExtensionHost] implements [ipnext.ProfileServices] directly. + // We might want to extract it to a separate type in the future. + return h +} + +// SwitchToBestProfileAsync implements [ipnext.ProfileServices]. +func (h *ExtensionHost) SwitchToBestProfileAsync(reason string) { + if h == nil { + return + } + h.enqueueBackendOperation(func(b Backend) { + b.SwitchToBestProfile(reason) + }) +} + +// RegisterProfileChangeCallback implements [ipnext.ProfileServices]. +func (h *ExtensionHost) RegisterProfileChangeCallback(cb ipnext.ProfileChangeCallback) (unregister func()) { + if h == nil { + return func() {} + } + if cb == nil { + panic("nil profile change callback") + } + h.mu.Lock() + defer h.mu.Unlock() + handle := h.profileChangeCbs.Add(cb) + return func() { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.profileChangeCbs, handle) + } +} + +// NotifyProfileChange invokes registered profile change callbacks. +// It strips private keys from the [ipn.Prefs] before passing it to the callbacks. +func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + if h == nil { + return + } + h.mu.Lock() + cbs := collectValues(h.profileChangeCbs) + h.mu.Unlock() + if cbs != nil { + // Strip private keys from the prefs before passing it to the callbacks. + // Extensions should not need it (unless proven otherwise in the future), + // and this is a good way to ensure that they won't accidentally leak them. + prefs = stripKeysFromPrefs(prefs) + for _, cb := range cbs { + cb(profile, prefs, sameNode) + } + } +} + +// RegisterBackgroundProfileResolver implements [ipnext.ProfileServices]. +func (h *ExtensionHost) RegisterBackgroundProfileResolver(resolver ipnext.ProfileResolver) (unregister func()) { + if h == nil { + return func() {} + } + h.mu.Lock() + defer h.mu.Unlock() + handle := h.backgroundProfileResolvers.Add(resolver) + return func() { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.backgroundProfileResolvers, handle) + } +} + +// DetermineBackgroundProfile returns a read-only view of the profile +// used when no GUI/CLI client is connected, using background profile +// resolvers registered by extensions. +// +// It returns an invalid view if Tailscale should not run in the background +// and instead disconnect until a GUI/CLI client connects. +// +// As of 2025-02-07, this is only used on Windows. +func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { + if h == nil { + return ipn.LoginProfileView{} + } + // TODO(nickkhyl): check if the returned profile is allowed on the device, + // such as when [syspolicy.Tailnet] policy setting requires a specific Tailnet. + // See tailscale/corp#26249. + + // Attempt to resolve the background profile using the registered + // background profile resolvers (e.g., [ipn/desktop.desktopSessionsExt] on Windows). + h.mu.Lock() + resolvers := collectValues(h.backgroundProfileResolvers) + h.mu.Unlock() + for _, resolver := range resolvers { + if profile := resolver(profiles); profile.Valid() { + return profile + } + } + + // Otherwise, switch to an empty profile and disconnect Tailscale + // until a GUI or CLI client connects. + return ipn.LoginProfileView{} +} + +// RegisterControlClientCallback implements [ipnext.Host]. +func (h *ExtensionHost) RegisterControlClientCallback(cb ipnext.NewControlClientCallback) (unregister func()) { + if h == nil { + return func() {} + } + if cb == nil { + panic("nil control client callback") + } + h.mu.Lock() + defer h.mu.Unlock() + handle := h.newControlClientCbs.Add(cb) + return func() { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.newControlClientCbs, handle) + } +} + +// NotifyNewControlClient invokes all registered control client callbacks. +// It returns callbacks to be executed when the control client shuts down. +func (h *ExtensionHost) NotifyNewControlClient(cc controlclient.Client, profile ipn.LoginProfileView, prefs ipn.PrefsView) (ccShutdownCbs []func()) { + if h == nil { + return nil + } + h.mu.Lock() + cbs := collectValues(h.newControlClientCbs) + h.mu.Unlock() + if len(cbs) > 0 { + ccShutdownCbs = make([]func(), 0, len(cbs)) + for _, cb := range cbs { + if shutdown := cb(cc, profile, prefs); shutdown != nil { + ccShutdownCbs = append(ccShutdownCbs, shutdown) + } + } + } + return ccShutdownCbs +} + +// RegisterAuditLogProvider implements [ipnext.Host]. +func (h *ExtensionHost) RegisterAuditLogProvider(provider ipnext.AuditLogProvider) (unregister func()) { + if h == nil { + return func() {} + } + if provider == nil { + panic("nil audit log provider") + } + h.mu.Lock() + defer h.mu.Unlock() + handle := h.auditLoggers.Add(provider) + return func() { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.auditLoggers, handle) + } +} + +// AuditLogger returns a function that reports an auditable action +// to all registered audit loggers. It fails if any of them returns an error, +// indicating that the action cannot be logged and must not be performed. +// +// It implements [ipnext.Host], but is also used by the [LocalBackend]. +// +// The returned function closes over the current state of the host and extensions, +// which typically includes the current profile and the audit loggers registered by extensions. +// It must not be persisted outside of the auditable action context. +func (h *ExtensionHost) AuditLogger() ipnauth.AuditLogFunc { + if h == nil { + return func(tailcfg.ClientAuditAction, string) error { return nil } + } + + h.mu.Lock() + providers := collectValues(h.auditLoggers) + h.mu.Unlock() + + var loggers []ipnauth.AuditLogFunc + if len(providers) > 0 { + loggers = make([]ipnauth.AuditLogFunc, len(providers)) + for i, provider := range providers { + loggers[i] = provider() + } + } + return func(action tailcfg.ClientAuditAction, details string) error { + // Log auditable actions to the host's log regardless of whether + // the audit loggers are available or not. + h.logf("auditlog: %v: %v", action, details) + + // Invoke all registered audit loggers and collect errors. + // If any of them returns an error, the action is denied. + var errs []error + for _, logger := range loggers { + if err := logger(action, details); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) + } +} + +// Shutdown shuts down the extension host and all initialized extensions. +func (h *ExtensionHost) Shutdown() { + if h == nil { + return + } + // Ensure that the init function has completed before shutting down, + // or prevent any further init calls from happening. + h.initOnce.Do(func() {}) + h.shutdownOnce.Do(h.shutdown) +} + +func (h *ExtensionHost) shutdown() { + // Prevent any queued but not yet started operations from running, + // block new operations from being enqueued, and wait for the + // currently executing operation (if any) to finish. + h.shutdownWorkQueue() + // Invoke shutdown callbacks registered by extensions. + h.shutdownExtensions() +} + +func (h *ExtensionHost) shutdownWorkQueue() { + h.workQueue.Shutdown() + var ctx context.Context + if testenv.InTest() { + // In tests, we'd like to wait indefinitely for the current operation to finish, + // mostly to help avoid flaky tests. Test runners can be pretty slow. + ctx = context.Background() + } else { + // In prod, however, we want to avoid blocking indefinitely. + // The 5s timeout is somewhat arbitrary; LocalBackend operations + // should not take that long. + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + // Since callbacks are invoked synchronously, this will also wait + // for in-flight callbacks associated with those operations to finish. + if err := h.workQueue.Wait(ctx); err != nil { + h.logf("work queue shutdown failed: %v", err) + } +} + +func (h *ExtensionHost) shutdownExtensions() { + h.mu.Lock() + extensions := h.activeExtensions + h.mu.Unlock() + + // h.mu must not be held while shutting down extensions. + // Extensions might call back into the host and that would cause + // a deadlock if the h.mu is already held. + // + // Shutdown is called in the reverse order of Init. + for _, ext := range slices.Backward(extensions) { + if err := ext.Shutdown(); err != nil { + // Extension shutdown errors are never fatal, but we log them for debugging purposes. + h.logf("%q: shutdown callback failed: %v", ext.Name(), err) + } + } +} + +// enqueueBackendOperation enqueues a function to perform an operation on the [Backend]. +// If the host has not yet been initialized (e.g., when called from an extension's Init method), +// the operation is deferred until after the host and all extensions have completed initialization. +// It panics if the f is nil. +func (h *ExtensionHost) enqueueBackendOperation(f func(Backend)) { + if h == nil { + return + } + if f == nil { + panic("nil backend operation") + } + h.mu.Lock() // protects h.initialized and h.postInitWorkQueue + defer h.mu.Unlock() + if h.initialized.Load() { + h.doEnqueueBackendOperation(f) + } else { + h.postInitWorkQueue = append(h.postInitWorkQueue, f) + } +} + +// execQueue is an ordered asynchronous queue for executing functions. +// It is implemented by [execqueue.ExecQueue]. The interface is used +// to allow testing with a mock implementation. +type execQueue interface { + Add(func()) + Shutdown() + Wait(context.Context) error +} + +// collectValues is like [slices.Collect] of [maps.Values], +// but pre-allocates the slice to avoid reallocations. +// It returns nil if the map is empty. +func collectValues[K comparable, V any](m map[K]V) []V { + if len(m) == 0 { + return nil + } + s := make([]V, 0, len(m)) + for _, v := range m { + s = append(s, v) + } + return s +} diff --git a/ipn/ipnlocal/extension_host_test.go b/ipn/ipnlocal/extension_host_test.go new file mode 100644 index 000000000..1e03abaa1 --- /dev/null +++ b/ipn/ipnlocal/extension_host_test.go @@ -0,0 +1,1139 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "cmp" + "context" + "errors" + "net/netip" + "reflect" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + + deepcmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/types/key" + "tailscale.com/types/persist" + "tailscale.com/util/must" +) + +// TestExtensionInitShutdown tests that [ExtensionHost] correctly initializes +// and shuts down extensions. +func TestExtensionInitShutdown(t *testing.T) { + t.Parallel() + + // As of 2025-04-08, [ipn.Host.Init] and [ipn.Host.Shutdown] do not return errors + // as extension initialization and shutdown errors are not fatal. + // If these methods are updated to return errors, this test should also be updated. + // The conversions below will fail to compile if their signatures change, reminding us to update the test. + _ = (func(*ExtensionHost))((*ExtensionHost).Init) + _ = (func(*ExtensionHost))((*ExtensionHost).Shutdown) + + tests := []struct { + name string + nilHost bool + exts []*testExtension + wantInit []string + wantShutdown []string + skipInit bool + }{ + { + name: "nil-host", + nilHost: true, + exts: []*testExtension{}, + wantInit: []string{}, + wantShutdown: []string{}, + }, + { + name: "empty-extensions", + exts: []*testExtension{}, + wantInit: []string{}, + wantShutdown: []string{}, + }, + { + name: "single-extension", + exts: []*testExtension{{name: "A"}}, + wantInit: []string{"A"}, + wantShutdown: []string{"A"}, + }, + { + name: "multiple-extensions/all-ok", + exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "B", "A"}, + }, + { + name: "multiple-extensions/no-init-no-shutdown", + exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}}, + wantInit: []string{}, + wantShutdown: []string{}, + skipInit: true, + }, + { + name: "multiple-extensions/init-failed/first", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "B", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "B"}, + }, + { + name: "multiple-extensions/init-failed/second", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "A"}, + }, + { + name: "multiple-extensions/init-failed/third", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "C", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"B", "A"}, + }, + { + name: "multiple-extensions/init-failed/all", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "B", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "C", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{}, + }, + { + name: "multiple-extensions/init-skipped", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return ipnext.SkipExtension }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "A"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Configure all extensions to append their names + // to the gotInit and gotShutdown slices + // during initialization and shutdown, + // so we can check that they are called in the right order + // and that shutdown is not unless init succeeded. + var gotInit, gotShutdown []string + for _, ext := range tt.exts { + oldInitHook := ext.InitHook + ext.InitHook = func(e *testExtension) error { + gotInit = append(gotInit, e.name) + if oldInitHook == nil { + return nil + } + return oldInitHook(e) + } + ext.ShutdownHook = func(e *testExtension) error { + gotShutdown = append(gotShutdown, e.name) + return nil + } + } + + var h *ExtensionHost + if !tt.nilHost { + h = newExtensionHostForTest(t, &testBackend{}, false, tt.exts...) + } + + if !tt.skipInit { + h.Init() + } + + // Check that the extensions were initialized in the right order. + if !slices.Equal(gotInit, tt.wantInit) { + t.Errorf("Init extensions: got %v; want %v", gotInit, tt.wantInit) + } + + // Calling Init again on the host should be a no-op. + // The [testExtension.Init] method fails the test if called more than once, + // regardless of which test is running, so we don't need to check it here. + // Similarly, calling Shutdown again on the host should be a no-op as well. + // It is verified by the [testExtension.Shutdown] method itself. + if !tt.skipInit { + h.Init() + } + + // Extensions should not be shut down before the host is shut down, + // even if they are not initialized successfully. + for _, ext := range tt.exts { + if gotShutdown := ext.ShutdownCalled(); gotShutdown { + t.Errorf("%q: Extension shutdown called before host shutdown", ext.name) + } + } + + h.Shutdown() + // Check that the extensions were shut down in the right order, + // and that they were not shut down if they were not initialized successfully. + if !slices.Equal(gotShutdown, tt.wantShutdown) { + t.Errorf("Shutdown extensions: got %v; want %v", gotShutdown, tt.wantShutdown) + } + + }) + } +} + +// TestNewExtensionHost tests that [NewExtensionHost] correctly creates +// an [ExtensionHost], instantiates the extensions and handles errors +// if an extension cannot be created. +func TestNewExtensionHost(t *testing.T) { + t.Parallel() + tests := []struct { + name string + defs []*ipnext.Definition + wantErr bool + wantExts []string + }{ + { + name: "no-exts", + defs: []*ipnext.Definition{}, + wantErr: false, + wantExts: []string{}, + }, + { + name: "exts-ok", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionForTest(&testExtension{name: "B"}), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: false, + wantExts: []string{"A", "B", "C"}, + }, + { + name: "exts-skipped", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionWithErrForTest("B", ipnext.SkipExtension), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: false, // extension B is skipped, that's ok + wantExts: []string{"A", "C"}, + }, + { + name: "exts-fail", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionWithErrForTest("B", errors.New("failed creating Ext-2")), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: true, // extension B failed to create, that's not ok + wantExts: []string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + logf := tstest.WhileTestRunningLogger(t) + h, err := NewExtensionHost(logf, &tsd.System{}, &testBackend{}, tt.defs...) + if gotErr := err != nil; gotErr != tt.wantErr { + t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr) + } + if err != nil { + return + } + + var gotExts []string + for _, ext := range h.allExtensions { + gotExts = append(gotExts, ext.Name()) + } + + if !slices.Equal(gotExts, tt.wantExts) { + t.Errorf("Shutdown extensions: got %v; want %v", gotExts, tt.wantExts) + } + }) + } +} + +// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues +// backend operations and executes them asynchronously in the order they were received. +// It also checks that operations requested before the host and all extensions are initialized +// are not executed immediately but rather after the host and extensions are initialized. +func TestExtensionHostEnqueueBackendOperation(t *testing.T) { + t.Parallel() + tests := []struct { + name string + preInitCalls []string // before host init + extInitCalls []string // from [Extension.Init]; "" means no call + wantInitCalls []string // what we expect to be called after host init + postInitCalls []string // after host init + }{ + { + name: "no-calls", + preInitCalls: []string{}, + extInitCalls: []string{}, + wantInitCalls: []string{}, + postInitCalls: []string{}, + }, + { + name: "pre-init-calls", + preInitCalls: []string{"pre-init-1", "pre-init-2"}, + extInitCalls: []string{}, + wantInitCalls: []string{"pre-init-1", "pre-init-2"}, + postInitCalls: []string{}, + }, + { + name: "init-calls", + preInitCalls: []string{}, + extInitCalls: []string{"init-1", "init-2"}, + wantInitCalls: []string{"init-1", "init-2"}, + postInitCalls: []string{}, + }, + { + name: "post-init-calls", + preInitCalls: []string{}, + extInitCalls: []string{}, + wantInitCalls: []string{}, + postInitCalls: []string{"post-init-1", "post-init-2"}, + }, + { + name: "mixed-calls", + preInitCalls: []string{"pre-init-1", "pre-init-2"}, + extInitCalls: []string{"init-1", "", "init-2"}, + wantInitCalls: []string{"pre-init-1", "pre-init-2", "init-1", "init-2"}, + postInitCalls: []string{"post-init-1", "post-init-2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var gotCalls []string + var h *ExtensionHost + b := &testBackend{ + switchToBestProfileHook: func(reason string) { + gotCalls = append(gotCalls, reason) + }, + } + + exts := make([]*testExtension, len(tt.extInitCalls)) + for i, reason := range tt.extInitCalls { + exts[i] = &testExtension{} + if reason != "" { + exts[i].InitHook = func(e *testExtension) error { + e.host.Profiles().SwitchToBestProfileAsync(reason) + return nil + } + } + } + + h = newExtensionHostForTest(t, b, false, exts...) + wq := h.SetWorkQueueForTest(t) // use a test queue instead of [execqueue.ExecQueue]. + + // Issue some pre-init calls. They should be deferred and not + // added to the queue until the host is initialized. + for _, call := range tt.preInitCalls { + h.Profiles().SwitchToBestProfileAsync(call) + } + + // The queue should be empty before the host is initialized. + wq.Drain() + if len(gotCalls) != 0 { + t.Errorf("Pre-init calls: got %v; want (none)", gotCalls) + } + gotCalls = nil + + // Initialize the host and all extensions. + // The extensions will make their calls during initialization. + h.Init() + + // Calls made before or during initialization should now be enqueued and running. + wq.Drain() + if diff := deepcmp.Diff(tt.wantInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Init calls: (+got -want): %v", diff) + } + gotCalls = nil + + // Let's make some more calls, as if extensions were making them in a response + // to external events. + for _, call := range tt.postInitCalls { + h.Profiles().SwitchToBestProfileAsync(call) + } + + // Any calls made after initialization should be enqueued and running. + wq.Drain() + if diff := deepcmp.Diff(tt.postInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Init calls: (+got -want): %v", diff) + } + gotCalls = nil + }) + } +} + +// TestExtensionHostProfileChangeCallback verifies that [ExtensionHost] correctly handles the registration, +// invocation, and unregistration of profile change callbacks. It also checks that the callbacks are called +// with the correct arguments and that any private keys are stripped from [ipn.Prefs] before being passed to the callback. +func TestExtensionHostProfileChangeCallback(t *testing.T) { + t.Parallel() + + type profileChange struct { + Profile *ipn.LoginProfile + Prefs *ipn.Prefs + SameNode bool + } + // newProfileChange creates a new profile change with deep copies of the profile and prefs. + newProfileChange := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) profileChange { + return profileChange{ + Profile: profile.AsStruct(), + Prefs: prefs.AsStruct(), + SameNode: sameNode, + } + } + // makeProfileChangeAppender returns a callback that appends profile changes to the extension's state. + makeProfileChangeAppender := func(e *testExtension) ipnext.ProfileChangeCallback { + return func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + UpdateExtState(e, "changes", func(changes []profileChange) []profileChange { + return append(changes, newProfileChange(profile, prefs, sameNode)) + }) + } + } + // getProfileChanges returns the profile changes stored in the extension's state. + getProfileChanges := func(e *testExtension) []profileChange { + changes, _ := GetExtStateOk[[]profileChange](e, "changes") + return changes + } + + tests := []struct { + name string + ext *testExtension + calls []profileChange + wantCalls []profileChange + }{ + { + // Register the callback for the lifetime of the extension. + name: "Register/Lifetime", + ext: &testExtension{}, + calls: []profileChange{ + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, + }, + wantCalls: []profileChange{ // all calls are received by the callback + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, + }, + }, + { + // Override the default InitHook used in the test to unregister the callback + // after the first call. + name: "Register/Once", + ext: &testExtension{ + InitHook: func(e *testExtension) error { + var unregister func() + handler := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + makeProfileChangeAppender(e)(profile, prefs, sameNode) + unregister() + } + unregister = e.host.Profiles().RegisterProfileChangeCallback(handler) + return nil + }, + }, + calls: []profileChange{ + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + }, + wantCalls: []profileChange{ // only the first call is received by the callback + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + }, + }, + { + // Ensure that ipn.Prefs are passed to the callback. + name: "CheckPrefs", + ext: &testExtension{}, + calls: []profileChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + AdvertiseRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + }, + }, + }}, + wantCalls: []profileChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + AdvertiseRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + }, + }, + }}, + }, + { + // Ensure that private keys are stripped from persist.Persist shared with extensions. + name: "StripPrivateKeys", + ext: &testExtension{}, + calls: []profileChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NewNode(), + OldPrivateNodeKey: key.NewNode(), + NetworkLockKey: key.NewNLPrivate(), + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }}, + wantCalls: []profileChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NodePrivate{}, // stripped + OldPrivateNodeKey: key.NodePrivate{}, // stripped + NetworkLockKey: key.NLPrivate{}, // stripped + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Use the default InitHook if not provided by the test. + if tt.ext.InitHook == nil { + tt.ext.InitHook = func(e *testExtension) error { + // Create and register the callback on init. + handler := makeProfileChangeAppender(e) + e.Cleanup(e.host.Profiles().RegisterProfileChangeCallback(handler)) + return nil + } + } + + h := newExtensionHostForTest(t, &testBackend{}, true, tt.ext) + for _, call := range tt.calls { + h.NotifyProfileChange(call.Profile.View(), call.Prefs.View(), call.SameNode) + } + opts := []deepcmp.Option{ + cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}), + } + if diff := deepcmp.Diff(tt.wantCalls, getProfileChanges(tt.ext), opts...); diff != "" { + t.Errorf("ProfileChange callbacks: (-want +got): %v", diff) + } + }) + } +} + +// TestBackgroundProfileResolver tests that the background profile resolvers +// are correctly registered, unregistered and invoked by the [ExtensionHost]. +func TestBackgroundProfileResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + profiles []ipn.LoginProfile // the first one is the current profile + resolvers []ipnext.ProfileResolver + wantProfile *ipn.LoginProfile + }{ + { + name: "No-Profiles/No-Resolvers", + profiles: nil, + resolvers: nil, + wantProfile: nil, + }, + { + // TODO(nickkhyl): update this test as we change "background profile resolvers" + // to just "profile resolvers". The wantProfile should be the current profile by default. + name: "Has-Profiles/No-Resolvers", + profiles: []ipn.LoginProfile{{ID: "profile-1"}}, + resolvers: nil, + wantProfile: nil, + }, + { + name: "Has-Profiles/Single-Resolver", + profiles: []ipn.LoginProfile{{ID: "profile-1"}}, + resolvers: []ipnext.ProfileResolver{ + func(ps ipnext.ProfileStore) ipn.LoginProfileView { + return ps.CurrentProfile() + }, + }, + wantProfile: &ipn.LoginProfile{ID: "profile-1"}, + }, + // TODO(nickkhyl): add more tests for multiple resolvers and different profiles + // once we change "background profile resolvers" to just "profile resolvers" + // and add proper conflict resolution logic. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create a new profile manager and add the profiles to it. + // We expose the profile manager to the extensions via the read-only [ipnext.ProfileStore] interface. + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) + for i, p := range tt.profiles { + // Generate a unique ID and key for each profile, + // unless the profile already has them set + // or is an empty, unnamed profile. + if p.Name != "" { + if p.ID == "" { + p.ID = ipn.ProfileID("profile-" + strconv.Itoa(i)) + } + if p.Key == "" { + p.Key = "key-" + ipn.StateKey(p.ID) + } + } + pv := p.View() + pm.knownProfiles[p.ID] = pv + if i == 0 { + // Set the first profile as the current one. + // A profileManager starts with an empty profile, + // so it's okay if the list of profiles is empty. + pm.SwitchToProfile(pv) + } + } + + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true) + + // Register the resolvers with the host. + // This is typically done by the extensions themselves, + // but we do it here for testing purposes. + for _, r := range tt.resolvers { + t.Cleanup(h.Profiles().RegisterBackgroundProfileResolver(r)) + } + + // Call the resolver to get the profile. + gotProfile := h.DetermineBackgroundProfile(pm) + if !gotProfile.Equals(tt.wantProfile.View()) { + t.Errorf("Resolved profile: got %v; want %v", gotProfile, tt.wantProfile) + } + }) + } +} + +// TestAuditLogProviders tests that the [ExtensionHost] correctly handles +// the registration and invocation of audit log providers. It verifies that +// the audit loggers are called with the correct actions and details, +// and that any errors returned by the providers are properly propagated. +func TestAuditLogProviders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + auditLoggers []ipnauth.AuditLogFunc // each represents an extension + actions []tailcfg.ClientAuditAction + wantErr bool + }{ + { + name: "No-Providers", + auditLoggers: nil, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Single-Provider/Ok", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { return nil }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Single-Provider/Err", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { + return errors.New("failed to log") + }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: true, + }, + { + name: "Many-Providers/Ok", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { return nil }, + func(tailcfg.ClientAuditAction, string) error { return nil }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Many-Providers/Err", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { + return errors.New("failed to log") + }, + func(tailcfg.ClientAuditAction, string) error { + return nil // all good + }, + func(tailcfg.ClientAuditAction, string) error { + return errors.New("also failed to log") + }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: true, // some providers failed to log, so that's an error + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create extensions that register the audit log providers. + // Each extension/provider will append auditable actions to its state, + // then call the test's auditLogger function. + var exts []*testExtension + for _, auditLogger := range tt.auditLoggers { + ext := &testExtension{} + provider := func() ipnauth.AuditLogFunc { + return func(action tailcfg.ClientAuditAction, details string) error { + UpdateExtState(ext, "actions", func(actions []tailcfg.ClientAuditAction) []tailcfg.ClientAuditAction { + return append(actions, action) + }) + return auditLogger(action, details) + } + } + ext.InitHook = func(e *testExtension) error { + e.Cleanup(e.host.RegisterAuditLogProvider(provider)) + return nil + } + exts = append(exts, ext) + } + + // Initialize the host and the extensions. + h := newExtensionHostForTest(t, &testBackend{}, true, exts...) + + // Use [ExtensionHost.AuditLogger] to log actions. + for _, action := range tt.actions { + err := h.AuditLogger()(action, "Test details") + if gotErr := err != nil; gotErr != tt.wantErr { + t.Errorf("AuditLogger: gotErr %v (%v); wantErr %v", gotErr, err, tt.wantErr) + } + } + + // Check that the actions were logged correctly by each provider. + for _, ext := range exts { + gotActions := GetExtState[[]tailcfg.ClientAuditAction](ext, "actions") + if !slices.Equal(gotActions, tt.actions) { + t.Errorf("Actions: got %v; want %v", gotActions, tt.actions) + } + } + }) + } +} + +// TestNilExtensionHostMethodCall tests that calling exported methods +// on a nil [ExtensionHost] does not panic. We should treat it as a valid +// value since it's used in various tests that instantiate [LocalBackend] +// manually without calling [NewLocalBackend]. It also verifies that if +// a method returns a single func value (e.g., a cleanup function), +// it should not be nil. This is a basic sanity check to ensure that +// typical method calls on a nil receiver work as expected. +// It does not replace the need for more thorough testing of specific methods. +func TestNilExtensionHostMethodCall(t *testing.T) { + t.Parallel() + + var h *ExtensionHost + typ := reflect.TypeOf(h) + for i := range typ.NumMethod() { + m := typ.Method(i) + if strings.HasSuffix(m.Name, "ForTest") { + // Skip methods that are only for testing. + continue + } + + t.Run(m.Name, func(t *testing.T) { + t.Parallel() + // Calling the method on the nil receiver should not panic. + ret := checkMethodCallWithZeroArgs(t, m, h) + if len(ret) == 1 && ret[0].Kind() == reflect.Func { + // If the method returns a single func, such as a cleanup function, + // it should not be nil. + fn := ret[0] + if fn.IsNil() { + t.Fatalf("(%T).%s returned a nil func", h, m.Name) + } + // We expect it to be a no-op and calling it should not panic. + args := makeZeroArgsFor(fn) + func() { + defer func() { + if e := recover(); e != nil { + t.Fatalf("panic calling the func returned by (%T).%s: %v", e, m.Name, e) + } + }() + fn.Call(args) + }() + } + }) + } +} + +// checkMethodCallWithZeroArgs calls the method m on the receiver r +// with zero values for all its arguments, except the receiver itself. +// It returns the result of the method call, or fails the test if the call panics. +func checkMethodCallWithZeroArgs[T any](t *testing.T, m reflect.Method, r T) []reflect.Value { + t.Helper() + args := makeZeroArgsFor(m.Func) + // The first arg is the receiver. + args[0] = reflect.ValueOf(r) + // Calling the method should not panic. + defer func() { + if e := recover(); e != nil { + t.Fatalf("panic calling (%T).%s: %v", r, m.Name, e) + } + }() + return m.Func.Call(args) +} + +func makeZeroArgsFor(fn reflect.Value) []reflect.Value { + args := make([]reflect.Value, fn.Type().NumIn()) + for i := range args { + args[i] = reflect.Zero(fn.Type().In(i)) + } + return args +} + +// newExtensionHostForTest creates an [ExtensionHost] with the given backend and extensions. +// It associates each extension that either is or embeds a [testExtension] with the test +// and assigns a name if one isn’t already set. +// +// If the host cannot be created, it fails the test. +// +// The host is initialized if the initialize parameter is true. +// It is shut down automatically when the test ends. +func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initialize bool, exts ...T) *ExtensionHost { + t.Helper() + + // testExtensionIface is a subset of the methods implemented by [testExtension] that are used here. + // We use testExtensionIface in type assertions instead of using the [testExtension] type directly, + // which supports scenarios where an extension type embeds a [testExtension]. + type testExtensionIface interface { + Name() string + setName(string) + setT(*testing.T) + checkShutdown() + } + + logf := tstest.WhileTestRunningLogger(t) + defs := make([]*ipnext.Definition, len(exts)) + for i, ext := range exts { + if ext, ok := any(ext).(testExtensionIface); ok { + ext.setName(cmp.Or(ext.Name(), "Ext-"+strconv.Itoa(i))) + ext.setT(t) + } + defs[i] = ipnext.DefinitionForTest(ext) + } + h, err := NewExtensionHost(logf, &tsd.System{}, b, defs...) + if err != nil { + t.Fatalf("NewExtensionHost: %v", err) + } + // Replace doEnqueueBackendOperation with the one that's marked as a helper, + // so that we'll have better output if [testExecQueue.Add] fails a test. + h.doEnqueueBackendOperation = func(f func(Backend)) { + t.Helper() + h.workQueue.Add(func() { f(b) }) + } + for _, ext := range exts { + if ext, ok := any(ext).(testExtensionIface); ok { + t.Cleanup(ext.checkShutdown) + } + } + t.Cleanup(h.Shutdown) + if initialize { + h.Init() + } + return h +} + +// testExtension is an [ipnext.Extension] that: +// - Calls the provided init and shutdown callbacks +// when [Init] and [Shutdown] are called. +// - Ensures that [Init] and [Shutdown] are called at most once, +// that [Shutdown] is called after [Init], but is not called if [Init] fails +// and is called before the test ends if [Init] succeeds. +// +// Typically, [testExtension]s are created and passed to [newExtensionHostForTest] +// when creating an [ExtensionHost] for testing. +type testExtension struct { + t *testing.T // test that created the extension + name string // name of the extension, used for logging + + host ipnext.Host // or nil if not initialized + + // InitHook and ShutdownHook are optional hooks that can be set by tests. + InitHook, ShutdownHook func(*testExtension) error + + // initCnt, initOkCnt and shutdownCnt are used to verify that Init and Shutdown + // are called at most once and in the correct order. + initCnt, initOkCnt, shutdownCnt atomic.Int32 + + // mu protects the following fields. + mu sync.Mutex + // state is the optional state used by tests. + // It can be accessed by tests using [setTestExtensionState], + // [getTestExtensionStateOk] and [getTestExtensionState]. + state map[string]any + // cleanup are functions to be called on shutdown. + cleanup []func() +} + +var _ ipnext.Extension = (*testExtension)(nil) + +func (e *testExtension) setT(t *testing.T) { + e.t = t +} + +func (e *testExtension) setName(name string) { + e.name = name +} + +// Name implements [ipnext.Extension]. +func (e *testExtension) Name() string { + return e.name +} + +// Init implements [ipnext.Extension]. +func (e *testExtension) Init(host ipnext.Host) (err error) { + e.t.Helper() + e.host = host + if e.initCnt.Add(1) == 1 { + e.mu.Lock() + e.state = make(map[string]any) + e.mu.Unlock() + } else { + e.t.Errorf("%q: Init called more than once", e.name) + } + if e.InitHook != nil { + err = e.InitHook(e) + } + if err == nil { + e.initOkCnt.Add(1) + } + return err // may be nil or non-nil +} + +// InitCalled reports whether the Init method was called on the receiver. +func (e *testExtension) InitCalled() bool { + return e.initCnt.Load() != 0 +} + +func (e *testExtension) Cleanup(f func()) { + e.mu.Lock() + e.cleanup = append(e.cleanup, f) + e.mu.Unlock() +} + +// Shutdown implements [ipnext.Extension]. +func (e *testExtension) Shutdown() (err error) { + e.t.Helper() + e.mu.Lock() + cleanup := e.cleanup + e.cleanup = nil + e.mu.Unlock() + for _, f := range cleanup { + f() + } + if e.ShutdownHook != nil { + err = e.ShutdownHook(e) + } + if e.shutdownCnt.Add(1) != 1 { + e.t.Errorf("%q: Shutdown called more than once", e.name) + } + if e.initCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown called without Init", e.name) + } else if e.initOkCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown called despite failed Init", e.name) + } + e.host = nil + return err // may be nil or non-nil +} + +func (e *testExtension) checkShutdown() { + e.t.Helper() + if e.initOkCnt.Load() != 0 && e.shutdownCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown has not been called before test end", e.name) + } +} + +// ShutdownCalled reports whether the Shutdown method was called on the receiver. +func (e *testExtension) ShutdownCalled() bool { + return e.shutdownCnt.Load() != 0 +} + +// SetExtState sets a keyed state on [testExtension] to the given value. +// Tests use it to propagate test-specific state throughout the extension lifecycle +// (e.g., between [testExtension.Init], [testExtension.Shutdown], and registered callbacks) +func SetExtState[T any](e *testExtension, key string, value T) { + e.mu.Lock() + defer e.mu.Unlock() + e.state[key] = value +} + +// UpdateExtState updates a keyed state of the extension using the provided update function. +func UpdateExtState[T any](e *testExtension, key string, update func(T) T) { + e.mu.Lock() + defer e.mu.Unlock() + old, _ := e.state[key].(T) + new := update(old) + e.state[key] = new +} + +// GetExtState returns the value of the keyed state of the extension. +// It returns a zero value of T if the state is not set or is of a different type. +func GetExtState[T any](e *testExtension, key string) T { + v, _ := GetExtStateOk[T](e, key) + return v +} + +// GetExtStateOk is like [getExtState], but also reports whether the state +// with the given key exists and is of the expected type. +func GetExtStateOk[T any](e *testExtension, key string) (_ T, ok bool) { + e.mu.Lock() + defer e.mu.Unlock() + v, ok := e.state[key].(T) + return v, ok +} + +// testExecQueue is a test implementation of [execQueue] +// that defers execution of the enqueued funcs until +// [testExecQueue.Drain] is called, and fails the test if +// if [execQueue.Add] is called before the host is initialized. +// +// It is typically used by calling [ExtensionHost.SetWorkQueueForTest]. +type testExecQueue struct { + t *testing.T // test that created the queue + h *ExtensionHost // host to own the queue + + mu sync.Mutex + queue []func() +} + +var _ execQueue = (*testExecQueue)(nil) + +// SetWorkQueueForTest is a helper function that creates a new [testExecQueue] +// and sets it as the work queue for the specified [ExtensionHost], +// returning the new queue. +// +// It fails the test if the host is already initialized. +func (h *ExtensionHost) SetWorkQueueForTest(t *testing.T) *testExecQueue { + t.Helper() + if h.initialized.Load() { + t.Fatalf("UseTestWorkQueue: host is already initialized") + return nil + } + q := &testExecQueue{t: t, h: h} + h.workQueue = q + return q +} + +// Add implements [execQueue]. +func (q *testExecQueue) Add(f func()) { + q.t.Helper() + + if !q.h.initialized.Load() { + q.t.Fatal("ExecQueue.Add must not be called until the host is initialized") + return + } + + q.mu.Lock() + q.queue = append(q.queue, f) + q.mu.Unlock() +} + +// Drain executes all queued functions in the order they were added. +func (q *testExecQueue) Drain() { + q.mu.Lock() + queue := q.queue + q.queue = nil + q.mu.Unlock() + + for _, f := range queue { + f() + } +} + +// Shutdown implements [execQueue]. +func (q *testExecQueue) Shutdown() {} + +// Wait implements [execQueue]. +func (q *testExecQueue) Wait(context.Context) error { return nil } + +// testBackend implements [ipnext.Backend] for testing purposes +// by calling the provided hooks when its methods are called. +type testBackend struct { + switchToBestProfileHook func(reason string) + + // mu protects the backend state. + // It is acquired on entry to the exported methods of the backend + // and released on exit, mimicking the behavior of the [LocalBackend]. + mu sync.Mutex +} + +func (b *testBackend) SwitchToBestProfile(reason string) { + b.mu.Lock() + defer b.mu.Unlock() + if b.switchToBestProfileHook != nil { + b.switchToBestProfileHook(reason) + } +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index a99d67cda..0f3ea1fbb 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -169,78 +169,6 @@ type watchSession struct { cancel context.CancelFunc // to shut down the session } -// Extension extends [LocalBackend] with additional functionality. -type Extension interface { - // Init is called to initialize the extension when the [LocalBackend] is created - // and before it starts running. If the extension cannot be initialized, - // it must return an error, and the Shutdown method will not be called. - // Any returned errors are not fatal; they are used for logging. - // TODO(nickkhyl): should we allow returning a fatal error? - Init(*LocalBackend) error - - // Shutdown is called when the [LocalBackend] is shutting down, - // if the extension was initialized. Any returned errors are not fatal; - // they are used for logging. - Shutdown() error -} - -// NewExtensionFn is a function that instantiates an [Extension]. -type NewExtensionFn func(logger.Logf, *tsd.System) (Extension, error) - -// registeredExtensions is a map of registered local backend extensions, -// where the key is the name of the extension and the value is the function -// that instantiates the extension. -var registeredExtensions map[string]NewExtensionFn - -// RegisterExtension registers a function that creates a [localBackendExtension]. -// It panics if newExt is nil or if an extension with the same name has already been registered. -func RegisterExtension(name string, newExt NewExtensionFn) { - if newExt == nil { - panic(fmt.Sprintf("lb: newExt is nil: %q", name)) - } - if _, ok := registeredExtensions[name]; ok { - panic(fmt.Sprintf("lb: duplicate extensions: %q", name)) - } - mak.Set(®isteredExtensions, name, newExt) -} - -// profileResolver is any function that returns a read-only view of a login profile. -// An invalid view indicates no profile. A valid profile view with an empty [ipn.ProfileID] -// indicates that the profile is new and has not been persisted yet. -// -// It is called with [LocalBackend.mu] held. -type profileResolver func() ipn.LoginProfileView - -// NewControlClientCallback is a function to be called when a new [controlclient.Client] -// is created and before it is first used. The login profile and prefs represent -// the profile for which the cc is created and are always valid; however, the -// profile's [ipn.LoginProfileView.ID] returns a zero [ipn.ProfileID] if the profile -// is new and has not been persisted yet. -// -// The callback is called with [LocalBackend.mu] held and must not call -// any [LocalBackend] methods. -// -// It returns a function to be called when the cc is being shut down, -// or nil if no cleanup is needed. -type NewControlClientCallback func(controlclient.Client, ipn.LoginProfileView, ipn.PrefsView) (cleanup func()) - -// ProfileChangeCallback is a function to be called when the current login profile changes. -// The sameNode parameter indicates whether the profile represents the same node as before, -// such as when only the profile metadata is updated but the node ID remains the same, -// or when a new profile is persisted and assigned an [ipn.ProfileID] for the first time. -// The subscribers can use this information to decide whether to reset their state. -// -// The profile and prefs are always valid, but the profile's [ipn.LoginProfileView.ID] -// returns a zero [ipn.ProfileID] if the profile is new and has not been persisted yet. -// -// The callback is called with [LocalBackend.mu] held and must not call -// any [LocalBackend] methods. -type ProfileChangeCallback func(_ ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) - -// AuditLogProvider is a function that returns an [ipnauth.AuditLogFunc] for -// logging auditable actions. -type AuditLogProvider func() ipnauth.AuditLogFunc - // LocalBackend is the glue between the major pieces of the Tailscale // network software: the cloud control plane (via controlclient), the // network data plane (via wgengine), and the user-facing UIs and CLIs @@ -311,6 +239,13 @@ type LocalBackend struct { // for testing and graceful shutdown purposes. goTracker goroutines.Tracker + // extHost is the bridge between [LocalBackend] and the registered [ipnext.Extension]s. + // It may be nil in tests that use direct composite literal initialization of [LocalBackend] + // instead of calling [NewLocalBackend]. A nil pointer is a valid, no-op host. + // It can be used with or without b.mu held, but is typically used with it held + // to prevent state changes while invoking callbacks. + extHost *ExtensionHost + // The mutex protects the following elements. mu sync.Mutex conf *conffile.Config // latest parsed config, or nil if not in declarative mode @@ -378,9 +313,6 @@ type LocalBackend struct { c2nUpdateStatus updateStatus currentUser ipnauth.Actor - // backgroundProfileResolvers are optional background profile resolvers. - backgroundProfileResolvers set.HandleSet[profileResolver] - selfUpdateProgress []ipnstate.UpdateProgress lastSelfUpdateState ipnstate.SelfUpdateStatus // capForcedNetfilter is the netfilter that control instructs Linux clients @@ -481,25 +413,6 @@ type LocalBackend struct { // reconnectTimer is used to schedule a reconnect by setting [ipn.Prefs.WantRunning] // to true after a delay, or nil if no reconnect is scheduled. reconnectTimer tstime.TimerController - - // shutdownCbs are the callbacks to be called when the backend is shutting down. - // Each callback is called exactly once in unspecified order and without b.mu held. - // Returned errors are logged but otherwise ignored and do not affect the shutdown process. - shutdownCbs set.HandleSet[func() error] - - // newControlClientCbs are the functions to be called when a new control client is created. - newControlClientCbs set.HandleSet[NewControlClientCallback] - - // profileChangeCbs are the callbacks to be called when the current login profile changes, - // either because of a profile switch, or because the profile information was updated - // by [LocalBackend.SetControlClientStatus], including when the profile is first populated - // and persisted. - profileChangeCbs set.HandleSet[ProfileChangeCallback] - - // auditLoggers is a collection of registered audit log providers. - // Each [AuditLogProvider] is called to get an [ipnauth.AuditLogFunc] when an auditable action - // is about to be performed. If an audit logger returns an error, the action is denied. - auditLoggers set.HandleSet[AuditLogProvider] } // HealthTracker returns the health tracker for the backend. @@ -614,6 +527,10 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } + if b.extHost, err = NewExtensionHost(logf, sys, b); err != nil { + return nil, fmt.Errorf("failed to create extension host: %w", err) + } + if b.unregisterSysPolicyWatch, err = b.registerSysPolicyWatch(); err != nil { return nil, err } @@ -668,19 +585,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } - for name, newFn := range registeredExtensions { - ext, err := newFn(logf, sys) - if err != nil { - b.logf("lb: failed to create %q extension: %v", name, err) - continue - } - if err := ext.Init(b); err != nil { - b.logf("lb: failed to initialize %q extension: %v", name, err) - continue - } - b.shutdownCbs.Add(ext.Shutdown) - } - + b.extHost.Init() return b, nil } @@ -1143,17 +1048,11 @@ func (b *LocalBackend) Shutdown() { if b.notifyCancel != nil { b.notifyCancel() } - shutdownCbs := slices.Collect(maps.Values(b.shutdownCbs)) - b.shutdownCbs = nil + extHost := b.extHost + b.extHost = nil b.mu.Unlock() b.webClientShutdown() - for _, cb := range shutdownCbs { - if err := cb(); err != nil { - b.logf("shutdown callback failed: %v", err) - } - } - if b.sockstatLogger != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -1170,6 +1069,7 @@ func (b *LocalBackend) Shutdown() { if cc != nil { cc.Shutdown() } + extHost.Shutdown() b.ctxCancel() b.e.Close() <-b.e.Done() @@ -1743,7 +1643,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control // If the profile ID was empty before SetPrefs, it's a new profile // and the user has just completed a login for the first time. sameNode := profile.ID() == "" || profile.ID() == cp.ID() - b.notifyProfileChangeLocked(profile, prefs.View(), sameNode) + b.extHost.NotifyProfileChange(profile, prefs.View(), sameNode) } } @@ -2492,11 +2392,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if err != nil { return err } - for _, cb := range b.newControlClientCbs { - if cleanup := cb(cc, b.pm.CurrentProfile(), prefs); cleanup != nil { - ccShutdownCbs = append(ccShutdownCbs, cleanup) - } - } + ccShutdownCbs = b.extHost.NotifyNewControlClient(cc, b.pm.CurrentProfile(), prefs) b.setControlClientLocked(cc) endpoints := b.endpoints @@ -4060,6 +3956,10 @@ func (b *LocalBackend) switchToBestProfileLockedOnEntry(reason string, unlock un // // b.mu must be held. func (b *LocalBackend) resolveBestProfileLocked() (_ ipn.LoginProfileView, isBackground bool) { + // TODO(nickkhyl): delegate all of this to the extensions and remove the distinction + // between "foreground" and "background" profiles as we migrate away from the concept + // of a single "current user" on Windows. See tailscale/corp#18342. + // // If a GUI/CLI client is connected, use the connected user's profile, which means // either the current profile if owned by the user, or their default profile. if b.currentUser != nil { @@ -4079,7 +3979,12 @@ func (b *LocalBackend) resolveBestProfileLocked() (_ ipn.LoginProfileView, isBac // If the returned background profileID is "", Tailscale will disconnect // and remain idle until a GUI or CLI client connects. if goos := envknob.GOOS(); goos == "windows" { - profile := b.getBackgroundProfileLocked() + // If Unattended Mode is enabled for the current profile, keep using it. + if b.pm.CurrentPrefs().ForceDaemon() { + return b.pm.CurrentProfile(), true + } + // Otherwise, use the profile returned by the extension. + profile := b.extHost.DetermineBackgroundProfile(b.pm) return profile, true } @@ -4092,47 +3997,6 @@ func (b *LocalBackend) resolveBestProfileLocked() (_ ipn.LoginProfileView, isBac return b.pm.CurrentProfile(), false } -// RegisterBackgroundProfileResolver registers a function to be used when -// resolving the background profile, until the returned unregister function is called. -func (b *LocalBackend) RegisterBackgroundProfileResolver(resolver profileResolver) (unregister func()) { - // TODO(nickkhyl): should we allow specifying some kind of priority/altitude for the resolver? - b.mu.Lock() - defer b.mu.Unlock() - handle := b.backgroundProfileResolvers.Add(resolver) - return func() { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.backgroundProfileResolvers, handle) - } -} - -// getBackgroundProfileLocked returns a read-only view of the profile to use -// when no GUI/CLI client is connected. If Tailscale should not run in the background -// and should disconnect until a GUI/CLI client connects, the returned view is not valid. -// As of 2025-02-07, it is only used on Windows. -func (b *LocalBackend) getBackgroundProfileLocked() ipn.LoginProfileView { - // TODO(nickkhyl): check if the returned profile is allowed on the device, - // such as when [syspolicy.Tailnet] policy setting requires a specific Tailnet. - // See tailscale/corp#26249. - - // If Unattended Mode is enabled for the current profile, keep using it. - if b.pm.CurrentPrefs().ForceDaemon() { - return b.pm.CurrentProfile() - } - - // Otherwise, attempt to resolve the background profile using the background - // profile resolvers available on the current platform. - for _, resolver := range b.backgroundProfileResolvers { - if profile := resolver(); profile.Valid() { - return profile - } - } - - // Otherwise, switch to an empty profile and disconnect Tailscale - // until a GUI or CLI client connects. - return ipn.LoginProfileView{} -} - // CurrentUserForTest returns the current user and the associated WindowsUserID. // It is used for testing only, and will be removed along with the rest of the // "current user" functionality as we progress on the multi-user improvements (tailscale/corp#18342). @@ -4351,47 +4215,6 @@ func (b *LocalBackend) MaybeClearAppConnector(mp *ipn.MaskedPrefs) error { return err } -// RegisterAuditLogProvider registers an audit log provider, which returns a function -// to be called when an auditable action is about to be performed. -// The returned function unregisters the provider. -// It panics if the provider is nil. -func (b *LocalBackend) RegisterAuditLogProvider(provider AuditLogProvider) (unregister func()) { - if provider == nil { - panic("nil audit log provider") - } - b.mu.Lock() - defer b.mu.Unlock() - handle := b.auditLoggers.Add(provider) - return func() { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.auditLoggers, handle) - } -} - -// getAuditLoggerLocked returns a function that calls all currently registered -// audit loggers, failing as soon as any of them returns an error. -// -// b.mu must be held. -func (b *LocalBackend) getAuditLoggerLocked() ipnauth.AuditLogFunc { - var loggers []ipnauth.AuditLogFunc - if len(b.auditLoggers) != 0 { - loggers = make([]ipnauth.AuditLogFunc, 0, len(b.auditLoggers)) - for _, getLogger := range b.auditLoggers { - loggers = append(loggers, getLogger()) - } - } - return func(action tailcfg.ClientAuditAction, details string) error { - b.logf("auditlog: %v: %v", action, details) - for _, logger := range loggers { - if err := logger(action, details); err != nil { - return err - } - } - return nil - } -} - // EditPrefs applies the changes in mp to the current prefs, // acting as the tailscaled itself rather than a specific user. func (b *LocalBackend) EditPrefs(mp *ipn.MaskedPrefs) (ipn.PrefsView, error) { @@ -4417,7 +4240,7 @@ func (b *LocalBackend) EditPrefsAs(mp *ipn.MaskedPrefs, actor ipnauth.Actor) (ip unlock := b.lockAndGetUnlock() defer unlock() if mp.WantRunningSet && !mp.WantRunning && b.pm.CurrentPrefs().WantRunning() { - if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, b.getAuditLoggerLocked()); err != nil { + if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, b.extHost.AuditLogger()); err != nil { b.logf("check profile access failed: %v", err) return ipn.PrefsView{}, err } @@ -6031,23 +5854,6 @@ func (b *LocalBackend) requestEngineStatusAndWait() { b.logf("requestEngineStatusAndWait: got status update.") } -// RegisterControlClientCallback registers a function to be called every time a new -// control client is created, until the returned unregister function is called. -// It panics if the cb is nil. -func (b *LocalBackend) RegisterControlClientCallback(cb NewControlClientCallback) (unregister func()) { - if cb == nil { - panic("nil control client callback") - } - b.mu.Lock() - defer b.mu.Unlock() - handle := b.newControlClientCbs.Add(cb) - return func() { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.newControlClientCbs, handle) - } -} - // setControlClientLocked sets the control client to cc, // which may be nil. // @@ -7633,37 +7439,6 @@ func (b *LocalBackend) resetDialPlan() { } } -// RegisterProfileChangeCallback registers a function to be called when the current [ipn.LoginProfile] changes. -// If includeCurrent is true, the callback is called immediately with the current profile. -// The returned function unregisters the callback. -// It panics if the cb is nil. -func (b *LocalBackend) RegisterProfileChangeCallback(cb ProfileChangeCallback, includeCurrent bool) (unregister func()) { - if cb == nil { - panic("nil profile change callback") - } - b.mu.Lock() - defer b.mu.Unlock() - handle := b.profileChangeCbs.Add(cb) - if includeCurrent { - cb(b.pm.CurrentProfile(), stripKeysFromPrefs(b.pm.CurrentPrefs()), false) - } - return func() { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.profileChangeCbs, handle) - } -} - -// notifyProfileChangeLocked invokes all registered profile change callbacks. -// -// b.mu must be held. -func (b *LocalBackend) notifyProfileChangeLocked(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { - prefs = stripKeysFromPrefs(prefs) - for _, cb := range b.profileChangeCbs { - cb(profile, prefs, sameNode) - } -} - // getHardwareAddrs returns the hardware addresses for the machine. If the list // of hardware addresses is empty, it will return the previously known hardware // addresses. Both the current, and previously known hardware addresses might be @@ -7711,7 +7486,7 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err b.lastSuggestedExitNode = "" b.keyExpired = false b.resetAlwaysOnOverrideLocked() - b.notifyProfileChangeLocked(b.pm.CurrentProfile(), b.pm.CurrentPrefs(), false) + b.extHost.NotifyProfileChange(b.pm.CurrentProfile(), b.pm.CurrentPrefs(), false) b.setAtomicValuesFromPrefsLocked(b.pm.CurrentPrefs()) b.enterStateLockedOnEntry(ipn.NoState, unlock) // Reset state; releases b.mu b.health.SetLocalLogConfigHealth(nil) diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index 901a4a899..057fe2aae 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -17,6 +17,7 @@ import ( "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" @@ -24,6 +25,9 @@ import ( var debug = envknob.RegisterBool("TS_DEBUG_PROFILES") +// [profileManager] implements [ipnext.ProfileStore]. +var _ ipnext.ProfileStore = (*profileManager)(nil) + // profileManager is a wrapper around an [ipn.StateStore] that manages // multiple profiles and the current profile. // diff --git a/tsd/tsd.go b/tsd/tsd.go index 1d1f35017..acd09560c 100644 --- a/tsd/tsd.go +++ b/tsd/tsd.go @@ -26,7 +26,6 @@ import ( "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/conffile" - "tailscale.com/ipn/desktop" "tailscale.com/net/dns" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" @@ -53,7 +52,6 @@ type System struct { Netstack SubSystem[NetstackImpl] // actually a *netstack.Impl DriveForLocal SubSystem[drive.FileSystemForLocal] DriveForRemote SubSystem[drive.FileSystemForRemote] - SessionManager SubSystem[desktop.SessionManager] // InitialConfig is initial server config, if any. // It is nil if the node is not in declarative mode. @@ -112,8 +110,6 @@ func (s *System) Set(v any) { s.DriveForLocal.Set(v) case drive.FileSystemForRemote: s.DriveForRemote.Set(v) - case desktop.SessionManager: - s.SessionManager.Set(v) default: panic(fmt.Sprintf("unknown type %T", v)) }