tailscale/ipn/ipnlocal/extension_host_test.go
Nick Khyl f28c8d0ec0 ipn/ipn{ext,local}: allow extension lookup by name or type
In this PR, we add two methods to facilitate extension lookup by both extensions
and non-extensions (e.g., PeerAPI or LocalAPI handlers):
 - FindExtensionByName returns an extension with the specified name.
   It can then be type asserted to a given type.
 - FindMatchingExtension is like errors.As, but for extensions.
   It returns the first extension that matches the target type (either a specific extension type
   or an interface).

Updates tailscale/corp#27645
Updates tailscale/corp#27502

Signed-off-by: Nick Khyl <nickk@tailscale.com>
2025-04-11 18:34:46 -05:00

1234 lines
38 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnlocal
import (
"cmp"
"context"
"errors"
"net/netip"
"reflect"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
deepcmp "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/health"
"tailscale.com/ipn"
"tailscale.com/ipn/ipnauth"
"tailscale.com/ipn/ipnext"
"tailscale.com/ipn/store/mem"
"tailscale.com/tailcfg"
"tailscale.com/tsd"
"tailscale.com/tstest"
"tailscale.com/types/key"
"tailscale.com/types/persist"
"tailscale.com/util/must"
)
// TestExtensionInitShutdown tests that [ExtensionHost] correctly initializes
// and shuts down extensions: extensions are initialized in registration order,
// shut down in reverse order, and only shut down if their Init succeeded.
func TestExtensionInitShutdown(t *testing.T) {
	t.Parallel()

	// As of 2025-04-08, [ExtensionHost.Init] and [ExtensionHost.Shutdown] do not return errors
	// as extension initialization and shutdown errors are not fatal.
	// If these methods are updated to return errors, this test should also be updated.
	// The conversions below will fail to compile if their signatures change, reminding us to update the test.
	_ = (func(*ExtensionHost))((*ExtensionHost).Init)
	_ = (func(*ExtensionHost))((*ExtensionHost).Shutdown)

	tests := []struct {
		name         string
		nilHost      bool
		exts         []*testExtension
		wantInit     []string
		wantShutdown []string
		skipInit     bool
	}{
		{
			name:         "nil-host",
			nilHost:      true,
			exts:         []*testExtension{},
			wantInit:     []string{},
			wantShutdown: []string{},
		},
		{
			name:         "empty-extensions",
			exts:         []*testExtension{},
			wantInit:     []string{},
			wantShutdown: []string{},
		},
		{
			name:         "single-extension",
			exts:         []*testExtension{{name: "A"}},
			wantInit:     []string{"A"},
			wantShutdown: []string{"A"},
		},
		{
			name:         "multiple-extensions/all-ok",
			exts:         []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}},
			wantInit:     []string{"A", "B", "C"},
			wantShutdown: []string{"C", "B", "A"},
		},
		{
			name:         "multiple-extensions/no-init-no-shutdown",
			exts:         []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}},
			wantInit:     []string{},
			wantShutdown: []string{},
			skipInit:     true,
		},
		{
			name: "multiple-extensions/init-failed/first",
			exts: []*testExtension{{
				name:     "A",
				InitHook: func(*testExtension) error { return errors.New("init failed") },
			}, {
				name:     "B",
				InitHook: func(*testExtension) error { return nil },
			}, {
				name:     "C",
				InitHook: func(*testExtension) error { return nil },
			}},
			wantInit:     []string{"A", "B", "C"},
			wantShutdown: []string{"C", "B"},
		},
		{
			name: "multiple-extensions/init-failed/second",
			exts: []*testExtension{{
				name:     "A",
				InitHook: func(*testExtension) error { return nil },
			}, {
				name:     "B",
				InitHook: func(*testExtension) error { return errors.New("init failed") },
			}, {
				name:     "C",
				InitHook: func(*testExtension) error { return nil },
			}},
			wantInit:     []string{"A", "B", "C"},
			wantShutdown: []string{"C", "A"},
		},
		{
			name: "multiple-extensions/init-failed/third",
			exts: []*testExtension{{
				name:     "A",
				InitHook: func(*testExtension) error { return nil },
			}, {
				name:     "B",
				InitHook: func(*testExtension) error { return nil },
			}, {
				name:     "C",
				InitHook: func(*testExtension) error { return errors.New("init failed") },
			}},
			wantInit:     []string{"A", "B", "C"},
			wantShutdown: []string{"B", "A"},
		},
		{
			name: "multiple-extensions/init-failed/all",
			exts: []*testExtension{{
				name:     "A",
				InitHook: func(*testExtension) error { return errors.New("init failed") },
			}, {
				name:     "B",
				InitHook: func(*testExtension) error { return errors.New("init failed") },
			}, {
				name:     "C",
				InitHook: func(*testExtension) error { return errors.New("init failed") },
			}},
			wantInit:     []string{"A", "B", "C"},
			wantShutdown: []string{},
		},
		{
			name: "multiple-extensions/init-skipped",
			exts: []*testExtension{{
				name:     "A",
				InitHook: func(*testExtension) error { return nil },
			}, {
				name:     "B",
				InitHook: func(*testExtension) error { return ipnext.SkipExtension },
			}, {
				name:     "C",
				InitHook: func(*testExtension) error { return nil },
			}},
			wantInit:     []string{"A", "B", "C"},
			wantShutdown: []string{"C", "A"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Configure all extensions to append their names
			// to the gotInit and gotShutdown slices
			// during initialization and shutdown,
			// so we can check that they are called in the right order
			// and that shutdown is not called unless init succeeded.
			var gotInit, gotShutdown []string
			for _, ext := range tt.exts {
				oldInitHook := ext.InitHook
				ext.InitHook = func(e *testExtension) error {
					gotInit = append(gotInit, e.name)
					if oldInitHook == nil {
						return nil
					}
					return oldInitHook(e)
				}
				ext.ShutdownHook = func(e *testExtension) error {
					gotShutdown = append(gotShutdown, e.name)
					return nil
				}
			}

			// A nil host is a valid receiver; its methods must be no-ops.
			var h *ExtensionHost
			if !tt.nilHost {
				h = newExtensionHostForTest(t, &testBackend{}, false, tt.exts...)
			}

			if !tt.skipInit {
				h.Init()
			}

			// Check that the extensions were initialized in the right order.
			if !slices.Equal(gotInit, tt.wantInit) {
				t.Errorf("Init extensions: got %v; want %v", gotInit, tt.wantInit)
			}

			// Calling Init again on the host should be a no-op.
			// [testExtension.Init] itself fails the test if it is called more than once,
			// so we don't need to check it here.
			// Similarly, newExtensionHostForTest registers h.Shutdown as a test cleanup,
			// so the host is shut down a second time when the test ends;
			// [testExtension.Shutdown] fails the test if that reaches an extension again.
			if !tt.skipInit {
				h.Init()
			}

			// Extensions should not be shut down before the host is shut down,
			// even if they are not initialized successfully.
			for _, ext := range tt.exts {
				if called := ext.ShutdownCalled(); called {
					t.Errorf("%q: Extension shutdown called before host shutdown", ext.name)
				}
			}

			h.Shutdown()

			// Check that the extensions were shut down in the right order,
			// and that they were not shut down if they were not initialized successfully.
			if !slices.Equal(gotShutdown, tt.wantShutdown) {
				t.Errorf("Shutdown extensions: got %v; want %v", gotShutdown, tt.wantShutdown)
			}
		})
	}
}
// TestNewExtensionHost tests that [NewExtensionHost] correctly creates
// an [ExtensionHost], instantiates the extensions and handles errors
// if an extension cannot be created. Extensions whose definition reports
// [ipnext.SkipExtension] are omitted without failing host creation.
func TestNewExtensionHost(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		defs     []*ipnext.Definition
		wantErr  bool
		wantExts []string
	}{
		{
			name:     "no-exts",
			defs:     []*ipnext.Definition{},
			wantErr:  false,
			wantExts: []string{},
		},
		{
			name: "exts-ok",
			defs: []*ipnext.Definition{
				ipnext.DefinitionForTest(&testExtension{name: "A"}),
				ipnext.DefinitionForTest(&testExtension{name: "B"}),
				ipnext.DefinitionForTest(&testExtension{name: "C"}),
			},
			wantErr:  false,
			wantExts: []string{"A", "B", "C"},
		},
		{
			name: "exts-skipped",
			defs: []*ipnext.Definition{
				ipnext.DefinitionForTest(&testExtension{name: "A"}),
				ipnext.DefinitionWithErrForTest("B", ipnext.SkipExtension),
				ipnext.DefinitionForTest(&testExtension{name: "C"}),
			},
			wantErr:  false, // extension B is skipped, that's ok
			wantExts: []string{"A", "C"},
		},
		{
			name: "exts-fail",
			defs: []*ipnext.Definition{
				ipnext.DefinitionForTest(&testExtension{name: "A"}),
				ipnext.DefinitionWithErrForTest("B", errors.New("failed creating Ext-2")),
				ipnext.DefinitionForTest(&testExtension{name: "C"}),
			},
			wantErr:  true, // extension B failed to create, that's not ok
			wantExts: []string{},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			logf := tstest.WhileTestRunningLogger(t)
			h, err := NewExtensionHost(logf, &tsd.System{}, &testBackend{}, tt.defs...)
			if gotErr := err != nil; gotErr != tt.wantErr {
				t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr)
			}
			if err != nil {
				return
			}

			// The host must contain exactly the non-skipped extensions, in definition order.
			var gotExts []string
			for _, ext := range h.allExtensions {
				gotExts = append(gotExts, ext.Name())
			}
			if !slices.Equal(gotExts, tt.wantExts) {
				t.Errorf("Extensions: got %v; want %v", gotExts, tt.wantExts)
			}
		})
	}
}
// TestFindMatchingExtension tests that [ExtensionHost.FindMatchingExtension] correctly
// finds extensions by their type or interface.
//
// Failure messages print the type of the pointer actually passed to
// FindMatchingExtension, which stays informative even when the target is a
// nil interface (whose own %T would just be "<nil>").
func TestFindMatchingExtension(t *testing.T) {
	t.Parallel()

	// Define test extension types and a couple of interfaces
	type (
		extensionA struct {
			testExtension
		}
		extensionB struct {
			testExtension
		}
		extensionC struct {
			testExtension
		}
		supportedIface interface {
			Name() string
		}
		unsupportedIface interface {
			Unsupported()
		}
	)

	// Register extensions A and B, but not C.
	extA := &extensionA{testExtension: testExtension{name: "A"}}
	extB := &extensionB{testExtension: testExtension{name: "B"}}
	h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true, extA, extB)

	var gotA *extensionA
	if !h.FindMatchingExtension(&gotA) {
		t.Errorf("FindMatchingExtension(%T): not found", &gotA)
	} else if gotA != extA {
		t.Errorf("FindMatchingExtension(%T): got %v; want %v", &gotA, gotA, extA)
	}

	var gotB *extensionB
	if !h.FindMatchingExtension(&gotB) {
		t.Errorf("FindMatchingExtension(%T): not found", &gotB)
	} else if gotB != extB {
		t.Errorf("FindMatchingExtension(%T): got %v; want %v", &gotB, gotB, extB)
	}

	var gotC *extensionC
	if h.FindMatchingExtension(&gotC) {
		t.Errorf("FindMatchingExtension(%T): found, but it should not exist", &gotC)
	}

	// All extensions implement the supportedIface interface,
	// but FindMatchingExtension should only return the first one found,
	// which is extA.
	var gotSupportedIface supportedIface
	if !h.FindMatchingExtension(&gotSupportedIface) {
		t.Errorf("FindMatchingExtension(%T): not found", &gotSupportedIface)
	} else if gotName, wantName := gotSupportedIface.Name(), extA.Name(); gotName != wantName {
		t.Errorf("FindMatchingExtension(%T): name: got %v; want %v", &gotSupportedIface, gotName, wantName)
	} else if gotSupportedIface != extA {
		t.Errorf("FindMatchingExtension(%T): got %v; want %v", &gotSupportedIface, gotSupportedIface, extA)
	}

	// No registered extension implements unsupportedIface.
	var gotUnsupportedIface unsupportedIface
	if h.FindMatchingExtension(&gotUnsupportedIface) {
		t.Errorf("FindMatchingExtension(%T): found, but it should not exist", &gotUnsupportedIface)
	}
}
// TestFindExtensionByName tests that [ExtensionHost.FindExtensionByName] correctly
// finds extensions by their name, and reports no match for unregistered names.
func TestFindExtensionByName(t *testing.T) {
	// Like the other tests in this file, this test owns its host and
	// extensions, so it can run in parallel with them.
	t.Parallel()

	// Register extensions A and B, but not C.
	extA := &testExtension{name: "A"}
	extB := &testExtension{name: "B"}
	h := newExtensionHostForTest(t, &testBackend{}, true, extA, extB)

	gotA, ok := h.FindExtensionByName(extA.Name()).(*testExtension)
	if !ok {
		t.Errorf("FindExtensionByName(%q): not found", extA.Name())
	} else if gotA != extA {
		t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extA.Name(), gotA, extA)
	}

	gotB, ok := h.FindExtensionByName(extB.Name()).(*testExtension)
	if !ok {
		t.Errorf("FindExtensionByName(%q): not found", extB.Name())
	} else if gotB != extB {
		t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extB.Name(), gotB, extB)
	}

	// A lookup for an unregistered name must not yield a *testExtension.
	gotC, ok := h.FindExtensionByName("C").(*testExtension)
	if ok {
		t.Errorf(`FindExtensionByName("C"): found, but it should not exist: %v`, gotC)
	}
}
// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues
// backend operations and executes them asynchronously in the order they were received.
// It also checks that operations requested before the host and all extensions are initialized
// are not executed immediately but rather after the host and extensions are initialized.
func TestExtensionHostEnqueueBackendOperation(t *testing.T) {
t.Parallel()
tests := []struct {
name string
preInitCalls []string // before host init
extInitCalls []string // from [Extension.Init]; "" means no call
wantInitCalls []string // what we expect to be called after host init
postInitCalls []string // after host init
}{
{
name: "no-calls",
preInitCalls: []string{},
extInitCalls: []string{},
wantInitCalls: []string{},
postInitCalls: []string{},
},
{
name: "pre-init-calls",
preInitCalls: []string{"pre-init-1", "pre-init-2"},
extInitCalls: []string{},
wantInitCalls: []string{"pre-init-1", "pre-init-2"},
postInitCalls: []string{},
},
{
name: "init-calls",
preInitCalls: []string{},
extInitCalls: []string{"init-1", "init-2"},
wantInitCalls: []string{"init-1", "init-2"},
postInitCalls: []string{},
},
{
name: "post-init-calls",
preInitCalls: []string{},
extInitCalls: []string{},
wantInitCalls: []string{},
postInitCalls: []string{"post-init-1", "post-init-2"},
},
{
name: "mixed-calls",
preInitCalls: []string{"pre-init-1", "pre-init-2"},
extInitCalls: []string{"init-1", "", "init-2"},
wantInitCalls: []string{"pre-init-1", "pre-init-2", "init-1", "init-2"},
postInitCalls: []string{"post-init-1", "post-init-2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var gotCalls []string
var h *ExtensionHost
b := &testBackend{
switchToBestProfileHook: func(reason string) {
gotCalls = append(gotCalls, reason)
},
}
exts := make([]*testExtension, len(tt.extInitCalls))
for i, reason := range tt.extInitCalls {
exts[i] = &testExtension{}
if reason != "" {
exts[i].InitHook = func(e *testExtension) error {
e.host.Profiles().SwitchToBestProfileAsync(reason)
return nil
}
}
}
h = newExtensionHostForTest(t, b, false, exts...)
wq := h.SetWorkQueueForTest(t) // use a test queue instead of [execqueue.ExecQueue].
// Issue some pre-init calls. They should be deferred and not
// added to the queue until the host is initialized.
for _, call := range tt.preInitCalls {
h.Profiles().SwitchToBestProfileAsync(call)
}
// The queue should be empty before the host is initialized.
wq.Drain()
if len(gotCalls) != 0 {
t.Errorf("Pre-init calls: got %v; want (none)", gotCalls)
}
gotCalls = nil
// Initialize the host and all extensions.
// The extensions will make their calls during initialization.
h.Init()
// Calls made before or during initialization should now be enqueued and running.
wq.Drain()
if diff := deepcmp.Diff(tt.wantInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("Init calls: (+got -want): %v", diff)
}
gotCalls = nil
// Let's make some more calls, as if extensions were making them in a response
// to external events.
for _, call := range tt.postInitCalls {
h.Profiles().SwitchToBestProfileAsync(call)
}
// Any calls made after initialization should be enqueued and running.
wq.Drain()
if diff := deepcmp.Diff(tt.postInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("Init calls: (+got -want): %v", diff)
}
gotCalls = nil
})
}
}
// TestExtensionHostProfileChangeCallback verifies that [ExtensionHost] correctly handles the registration,
// invocation, and unregistration of profile change callbacks. It also checks that the callbacks are called
// with the correct arguments and that any private keys are stripped from [ipn.Prefs] before being passed to the callback.
func TestExtensionHostProfileChangeCallback(t *testing.T) {
	t.Parallel()

	// profileChange captures the arguments of a single profile change notification.
	type profileChange struct {
		Profile  *ipn.LoginProfile
		Prefs    *ipn.Prefs
		SameNode bool
	}
	// newProfileChange creates a new profile change with deep copies of the profile and prefs.
	newProfileChange := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) profileChange {
		return profileChange{
			Profile:  profile.AsStruct(),
			Prefs:    prefs.AsStruct(),
			SameNode: sameNode,
		}
	}
	// makeProfileChangeAppender returns a callback that appends profile changes to the extension's state.
	// The changes are accumulated under the "changes" key as a []profileChange.
	makeProfileChangeAppender := func(e *testExtension) ipnext.ProfileChangeCallback {
		return func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
			UpdateExtState(e, "changes", func(changes []profileChange) []profileChange {
				return append(changes, newProfileChange(profile, prefs, sameNode))
			})
		}
	}
	// getProfileChanges returns the profile changes stored in the extension's state
	// under the "changes" key.
	getProfileChanges := func(e *testExtension) []profileChange {
		changes, _ := GetExtStateOk[[]profileChange](e, "changes")
		return changes
	}

	tests := []struct {
		name      string
		ext       *testExtension
		calls     []profileChange // notifications sent to the host, in order
		wantCalls []profileChange // notifications the extension's callback should observe
	}{
		{
			// Register the callback for the lifetime of the extension.
			name: "Register/Lifetime",
			ext:  &testExtension{},
			calls: []profileChange{
				{Profile: &ipn.LoginProfile{ID: "profile-1"}},
				{Profile: &ipn.LoginProfile{ID: "profile-2"}},
				{Profile: &ipn.LoginProfile{ID: "profile-3"}},
				{Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
			},
			wantCalls: []profileChange{ // all calls are received by the callback
				{Profile: &ipn.LoginProfile{ID: "profile-1"}},
				{Profile: &ipn.LoginProfile{ID: "profile-2"}},
				{Profile: &ipn.LoginProfile{ID: "profile-3"}},
				{Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
			},
		},
		{
			// Override the default InitHook used in the test to unregister the callback
			// after the first call. This exercises unregistering a callback from within
			// the callback itself; the callback is not added to e.Cleanup, so the
			// in-callback unregister is the only thing that removes it.
			name: "Register/Once",
			ext: &testExtension{
				InitHook: func(e *testExtension) error {
					var unregister func()
					handler := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
						makeProfileChangeAppender(e)(profile, prefs, sameNode)
						unregister()
					}
					unregister = e.host.Profiles().RegisterProfileChangeCallback(handler)
					return nil
				},
			},
			calls: []profileChange{
				{Profile: &ipn.LoginProfile{ID: "profile-1"}},
				{Profile: &ipn.LoginProfile{ID: "profile-2"}},
				{Profile: &ipn.LoginProfile{ID: "profile-3"}},
			},
			wantCalls: []profileChange{ // only the first call is received by the callback
				{Profile: &ipn.LoginProfile{ID: "profile-1"}},
			},
		},
		{
			// Ensure that ipn.Prefs are passed to the callback.
			name: "CheckPrefs",
			ext:  &testExtension{},
			calls: []profileChange{{
				Profile: &ipn.LoginProfile{ID: "profile-1"},
				Prefs: &ipn.Prefs{
					WantRunning: true,
					LoggedOut:   false,
					AdvertiseRoutes: []netip.Prefix{
						netip.MustParsePrefix("192.168.1.0/24"),
						netip.MustParsePrefix("192.168.2.0/24"),
					},
				},
			}},
			wantCalls: []profileChange{{
				Profile: &ipn.LoginProfile{ID: "profile-1"},
				Prefs: &ipn.Prefs{
					WantRunning: true,
					LoggedOut:   false,
					AdvertiseRoutes: []netip.Prefix{
						netip.MustParsePrefix("192.168.1.0/24"),
						netip.MustParsePrefix("192.168.2.0/24"),
					},
				},
			}},
		},
		{
			// Ensure that private keys are stripped from persist.Persist shared with extensions.
			// Non-secret Persist fields (NodeID, UserProfile) must still be passed through.
			name: "StripPrivateKeys",
			ext:  &testExtension{},
			calls: []profileChange{{
				Profile: &ipn.LoginProfile{ID: "profile-1"},
				Prefs: &ipn.Prefs{
					Persist: &persist.Persist{
						NodeID:            "12345",
						PrivateNodeKey:    key.NewNode(),
						OldPrivateNodeKey: key.NewNode(),
						NetworkLockKey:    key.NewNLPrivate(),
						UserProfile: tailcfg.UserProfile{
							ID:            12345,
							LoginName:     "test@example.com",
							DisplayName:   "Test User",
							ProfilePicURL: "https://example.com/profile.png",
						},
					},
				},
			}},
			wantCalls: []profileChange{{
				Profile: &ipn.LoginProfile{ID: "profile-1"},
				Prefs: &ipn.Prefs{
					Persist: &persist.Persist{
						NodeID:            "12345",
						PrivateNodeKey:    key.NodePrivate{}, // stripped
						OldPrivateNodeKey: key.NodePrivate{}, // stripped
						NetworkLockKey:    key.NLPrivate{},   // stripped
						UserProfile: tailcfg.UserProfile{
							ID:            12345,
							LoginName:     "test@example.com",
							DisplayName:   "Test User",
							ProfilePicURL: "https://example.com/profile.png",
						},
					},
				},
			}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// Use the default InitHook if not provided by the test.
			if tt.ext.InitHook == nil {
				tt.ext.InitHook = func(e *testExtension) error {
					// Create and register the callback on init.
					// Unregistration is deferred to the extension's Shutdown via e.Cleanup.
					handler := makeProfileChangeAppender(e)
					e.Cleanup(e.host.Profiles().RegisterProfileChangeCallback(handler))
					return nil
				}
			}
			h := newExtensionHostForTest(t, &testBackend{}, true, tt.ext)
			// Notify the host of each change. The recorded changes are compared right
			// after this loop, so this presumably relies on NotifyProfileChange invoking
			// callbacks synchronously — NOTE(review): confirm against ExtensionHost.
			for _, call := range tt.calls {
				h.NotifyProfileChange(call.Profile.View(), call.Prefs.View(), call.SameNode)
			}
			// netip.Addr, netip.Prefix and key.NodePublic have unexported fields;
			// EquateComparable makes deepcmp compare them with == instead.
			opts := []deepcmp.Option{
				cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}),
			}
			if diff := deepcmp.Diff(tt.wantCalls, getProfileChanges(tt.ext), opts...); diff != "" {
				t.Errorf("ProfileChange callbacks: (-want +got): %v", diff)
			}
		})
	}
}
// TestBackgroundProfileResolver checks that background profile resolvers
// registered with an [ExtensionHost] are consulted by
// [ExtensionHost.DetermineBackgroundProfile]. Each resolver is unregistered
// through t.Cleanup once its subtest finishes.
func TestBackgroundProfileResolver(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name        string
		profiles    []ipn.LoginProfile // the first one is the current profile
		resolvers   []ipnext.ProfileResolver
		wantProfile *ipn.LoginProfile
	}{
		{
			name:        "No-Profiles/No-Resolvers",
			profiles:    nil,
			resolvers:   nil,
			wantProfile: nil,
		},
		{
			// TODO(nickkhyl): update this test as we change "background profile resolvers"
			// to just "profile resolvers". The wantProfile should be the current profile by default.
			name:        "Has-Profiles/No-Resolvers",
			profiles:    []ipn.LoginProfile{{ID: "profile-1"}},
			resolvers:   nil,
			wantProfile: nil,
		},
		{
			name:     "Has-Profiles/Single-Resolver",
			profiles: []ipn.LoginProfile{{ID: "profile-1"}},
			resolvers: []ipnext.ProfileResolver{
				func(ps ipnext.ProfileStore) ipn.LoginProfileView {
					return ps.CurrentProfile()
				},
			},
			wantProfile: &ipn.LoginProfile{ID: "profile-1"},
		},
		// TODO(nickkhyl): add more tests for multiple resolvers and different profiles
		// once we change "background profile resolvers" to just "profile resolvers"
		// and add proper conflict resolution logic.
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Seed a fresh profile manager. Extensions only ever see it through
			// the read-only [ipnext.ProfileStore] interface.
			store := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker)))
			for idx, prof := range tt.profiles {
				// Named profiles get a generated ID and state key when those are
				// unset; empty, unnamed profiles are stored exactly as given.
				if prof.Name != "" {
					prof.ID = cmp.Or(prof.ID, ipn.ProfileID("profile-"+strconv.Itoa(idx)))
					prof.Key = cmp.Or(prof.Key, "key-"+ipn.StateKey(prof.ID))
				}
				view := prof.View()
				store.knownProfiles[prof.ID] = view
				// The first profile becomes the current one. A profileManager
				// starts out with an empty profile, so an empty list is fine too.
				if idx == 0 {
					store.SwitchToProfile(view)
				}
			}

			host := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true)

			// Extensions normally register their own resolvers; here the test
			// registers them directly on the host.
			for _, resolve := range tt.resolvers {
				t.Cleanup(host.Profiles().RegisterBackgroundProfileResolver(resolve))
			}

			got := host.DetermineBackgroundProfile(store)
			if want := tt.wantProfile.View(); !got.Equals(want) {
				t.Errorf("Resolved profile: got %v; want %v", got, tt.wantProfile)
			}
		})
	}
}
// TestAuditLogProviders tests that the [ExtensionHost] correctly handles
// the registration and invocation of audit log providers. It verifies that
// the audit loggers are called with the correct actions and details,
// that every provider is called even if another one fails,
// and that any errors returned by the providers are properly propagated.
func TestAuditLogProviders(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name         string
		auditLoggers []ipnauth.AuditLogFunc // each represents an extension
		actions      []tailcfg.ClientAuditAction
		wantErr      bool
	}{
		{
			name:         "No-Providers",
			auditLoggers: nil,
			actions:      []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
			wantErr:      false,
		},
		{
			name: "Single-Provider/Ok",
			auditLoggers: []ipnauth.AuditLogFunc{
				func(tailcfg.ClientAuditAction, string) error { return nil },
			},
			actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
			wantErr: false,
		},
		{
			name: "Single-Provider/Err",
			auditLoggers: []ipnauth.AuditLogFunc{
				func(tailcfg.ClientAuditAction, string) error {
					return errors.New("failed to log")
				},
			},
			actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
			wantErr: true,
		},
		{
			name: "Many-Providers/Ok",
			auditLoggers: []ipnauth.AuditLogFunc{
				func(tailcfg.ClientAuditAction, string) error { return nil },
				func(tailcfg.ClientAuditAction, string) error { return nil },
			},
			actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
			wantErr: false,
		},
		{
			name: "Many-Providers/Err",
			auditLoggers: []ipnauth.AuditLogFunc{
				func(tailcfg.ClientAuditAction, string) error {
					return errors.New("failed to log")
				},
				func(tailcfg.ClientAuditAction, string) error {
					return nil // all good
				},
				func(tailcfg.ClientAuditAction, string) error {
					return errors.New("also failed to log")
				},
			},
			actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
			wantErr: true, // some providers failed to log, so that's an error
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Each subtest builds its own host and extensions, so subtests
			// can run in parallel like the other table-driven tests in this file.
			t.Parallel()

			// Create extensions that register the audit log providers.
			// Each extension/provider will append auditable actions to its state,
			// then call the test's auditLogger function.
			var exts []*testExtension
			for _, auditLogger := range tt.auditLoggers {
				ext := &testExtension{}
				provider := func() ipnauth.AuditLogFunc {
					return func(action tailcfg.ClientAuditAction, details string) error {
						UpdateExtState(ext, "actions", func(actions []tailcfg.ClientAuditAction) []tailcfg.ClientAuditAction {
							return append(actions, action)
						})
						return auditLogger(action, details)
					}
				}
				ext.InitHook = func(e *testExtension) error {
					e.Cleanup(e.host.RegisterAuditLogProvider(provider))
					return nil
				}
				exts = append(exts, ext)
			}

			// Initialize the host and the extensions.
			h := newExtensionHostForTest(t, &testBackend{}, true, exts...)

			// Use [ExtensionHost.AuditLogger] to log actions.
			for _, action := range tt.actions {
				err := h.AuditLogger()(action, "Test details")
				if gotErr := err != nil; gotErr != tt.wantErr {
					t.Errorf("AuditLogger(%q): gotErr %v (%v); wantErr %v", action, gotErr, err, tt.wantErr)
				}
			}

			// Check that the actions were logged correctly by each provider,
			// including providers that follow a failing one.
			for _, ext := range exts {
				gotActions := GetExtState[[]tailcfg.ClientAuditAction](ext, "actions")
				if !slices.Equal(gotActions, tt.actions) {
					t.Errorf("%q: actions: got %v; want %v", ext.Name(), gotActions, tt.actions)
				}
			}
		})
	}
}
// TestNilExtensionHostMethodCall tests that calling exported methods
// on a nil [ExtensionHost] does not panic. We should treat it as a valid
// value since it's used in various tests that instantiate [LocalBackend]
// manually without calling [NewLocalBackend]. It also verifies that if
// a method returns a single func value (e.g., a cleanup function),
// it should not be nil. This is a basic sanity check to ensure that
// typical method calls on a nil receiver work as expected.
// It does not replace the need for more thorough testing of specific methods.
func TestNilExtensionHostMethodCall(t *testing.T) {
	t.Parallel()

	var h *ExtensionHost
	typ := reflect.TypeOf(h)
	for i := range typ.NumMethod() {
		m := typ.Method(i)
		if strings.HasSuffix(m.Name, "ForTest") {
			// Skip methods that are only for testing.
			continue
		}

		t.Run(m.Name, func(t *testing.T) {
			t.Parallel()

			// Calling the method on the nil receiver should not panic.
			ret := checkMethodCallWithZeroArgs(t, m, h)
			if len(ret) != 1 || ret[0].Kind() != reflect.Func {
				return
			}

			// The method returns a single func, such as a cleanup function.
			// It should not be nil.
			fn := ret[0]
			if fn.IsNil() {
				t.Fatalf("(%T).%s returned a nil func", h, m.Name)
			}

			// We expect it to be a no-op and calling it should not panic.
			args := makeZeroArgsFor(fn)
			func() {
				defer func() {
					if e := recover(); e != nil {
						t.Fatalf("panic calling the func returned by (%T).%s: %v", h, m.Name, e)
					}
				}()
				// For a variadic func, the last zero arg is the (nil) variadic
				// slice itself, which must be passed via CallSlice; Call would
				// treat it as a single variadic element.
				if fn.Type().IsVariadic() {
					fn.CallSlice(args)
				} else {
					fn.Call(args)
				}
			}()
		})
	}
}
// checkMethodCallWithZeroArgs calls the method m on the receiver r
// with zero values for all its arguments, except the receiver itself.
// It returns the result of the method call, or fails the test if the call panics.
//
// m is expected to come from [reflect.Type.Method], so that m.Func takes the
// receiver as its first argument. For variadic methods, the variadic parameter
// receives a nil slice via [reflect.Value.CallSlice] rather than having the
// zero slice misinterpreted as a single variadic element.
func checkMethodCallWithZeroArgs[T any](t *testing.T, m reflect.Method, r T) []reflect.Value {
	t.Helper()
	args := makeZeroArgsFor(m.Func)
	// The first arg is the receiver.
	args[0] = reflect.ValueOf(r)
	// Calling the method should not panic.
	defer func() {
		if e := recover(); e != nil {
			t.Fatalf("panic calling (%T).%s: %v", r, m.Name, e)
		}
	}()
	if m.Func.Type().IsVariadic() {
		return m.Func.CallSlice(args)
	}
	return m.Func.Call(args)
}
// makeZeroArgsFor returns one zero [reflect.Value] per input parameter of the
// function value fn, in parameter order. For a variadic function, the last
// entry is the zero (nil) slice of the variadic parameter's type.
// fn must hold a func; otherwise reflect panics.
func makeZeroArgsFor(fn reflect.Value) []reflect.Value {
	ft := fn.Type()
	n := ft.NumIn()
	zeros := make([]reflect.Value, 0, n)
	for i := 0; i < n; i++ {
		zeros = append(zeros, reflect.Zero(ft.In(i)))
	}
	return zeros
}
// newExtensionHostForTest creates an [ExtensionHost] with the given backend and extensions.
// It associates each extension that either is or embeds a [testExtension] with the test
// and assigns it a name ("Ext-<index>") if one isn't already set.
//
// If the host cannot be created, it fails the test.
//
// The host is initialized if the initialize parameter is true.
// It is shut down automatically when the test ends. Because t.Cleanup runs
// functions in last-registered-first order, the host is shut down before each
// test extension checks that it was shut down after a successful Init.
func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initialize bool, exts ...T) *ExtensionHost {
	t.Helper()

	// testExtensionIface is a subset of the methods implemented by [testExtension] that are used here.
	// We use testExtensionIface in type assertions instead of using the [testExtension] type directly,
	// which supports scenarios where an extension type embeds a [testExtension].
	type testExtensionIface interface {
		Name() string
		setName(string)
		setT(*testing.T)
		checkShutdown()
	}

	logf := tstest.WhileTestRunningLogger(t)
	defs := make([]*ipnext.Definition, len(exts))
	for i, ext := range exts {
		if ext, ok := any(ext).(testExtensionIface); ok {
			ext.setName(cmp.Or(ext.Name(), "Ext-"+strconv.Itoa(i)))
			ext.setT(t)
		}
		defs[i] = ipnext.DefinitionForTest(ext)
	}
	h, err := NewExtensionHost(logf, &tsd.System{}, b, defs...)
	if err != nil {
		t.Fatalf("NewExtensionHost: %v", err)
	}
	// Replace doEnqueueBackendOperation with the one that's marked as a helper,
	// so that we'll have better output if [testExecQueue.Add] fails a test.
	h.doEnqueueBackendOperation = func(f func(Backend)) {
		t.Helper()
		h.workQueue.Add(func() { f(b) })
	}
	// Register the per-extension shutdown checks before h.Shutdown so that,
	// with t.Cleanup's LIFO order, they run after the host has shut down.
	for _, ext := range exts {
		if ext, ok := any(ext).(testExtensionIface); ok {
			t.Cleanup(ext.checkShutdown)
		}
	}
	t.Cleanup(h.Shutdown)
	if initialize {
		h.Init()
	}
	return h
}
// testExtension is an [ipnext.Extension] that:
//   - Calls the optional InitHook and ShutdownHook callbacks
//     when [testExtension.Init] and [testExtension.Shutdown] are called.
//   - Ensures that Init and Shutdown are called at most once,
//     that Shutdown is called after Init, but is not called if Init fails,
//     and that Shutdown is called before the test ends if Init succeeds.
//
// Typically, [testExtension]s are created and passed to [newExtensionHostForTest]
// when creating an [ExtensionHost] for testing.
type testExtension struct {
	t    *testing.T  // test that created the extension
	name string      // name of the extension, used for logging
	host ipnext.Host // set by Init and cleared by Shutdown; nil before Init

	// InitHook and ShutdownHook are optional hooks that can be set by tests.
	InitHook, ShutdownHook func(*testExtension) error

	// initCnt, initOkCnt and shutdownCnt are used to verify that Init and Shutdown
	// are called at most once and in the correct order.
	initCnt, initOkCnt, shutdownCnt atomic.Int32

	// mu protects the following fields.
	mu sync.Mutex
	// state is the optional state used by tests. The first Init call
	// replaces it with an empty map. Tests access it using [SetExtState],
	// [UpdateExtState], [GetExtStateOk] and [GetExtState].
	state map[string]any
	// cleanup are functions registered via [testExtension.Cleanup],
	// called in registration order on shutdown.
	cleanup []func()
}

var _ ipnext.Extension = (*testExtension)(nil)
// setT records the test that owns the extension; it is used for reporting errors.
func (e *testExtension) setT(t *testing.T) { e.t = t }
// setName overrides the name reported by [testExtension.Name].
func (e *testExtension) setName(name string) { e.name = name }
// Name implements [ipnext.Extension]. It returns the name assigned at
// construction or via setName.
func (e *testExtension) Name() string { return e.name }
// Init implements [ipnext.Extension].
//
// It stores host, resets the extension's state map on the first call
// (and reports a test error on any later call), then runs the optional
// InitHook. The call counts as successful only if the hook returns nil;
// the hook's error, if any, is returned unchanged.
func (e *testExtension) Init(host ipnext.Host) error {
	e.t.Helper()
	e.host = host

	if calls := e.initCnt.Add(1); calls != 1 {
		e.t.Errorf("%q: Init called more than once", e.name)
	} else {
		e.mu.Lock()
		e.state = make(map[string]any)
		e.mu.Unlock()
	}

	var hookErr error
	if hook := e.InitHook; hook != nil {
		hookErr = hook(e)
	}
	if hookErr != nil {
		return hookErr
	}
	e.initOkCnt.Add(1)
	return nil
}
// InitCalled reports whether [testExtension.Init] has been invoked at least
// once on the receiver, regardless of whether it succeeded.
func (e *testExtension) InitCalled() bool {
	calls := e.initCnt.Load()
	return calls != 0
}
// Cleanup registers f to be run when the extension is shut down.
// Registered funcs run in registration order, before the ShutdownHook.
func (e *testExtension) Cleanup(f func()) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.cleanup = append(e.cleanup, f)
}
// Shutdown implements [ipnext.Extension].
//
// It first runs the funcs registered with [testExtension.Cleanup] and then the
// optional ShutdownHook. It reports a test error in three cases: Shutdown is
// called more than once, it is called without a prior Init, or it is called
// after a failed Init. It clears the host and returns the ShutdownHook's error,
// if any.
func (e *testExtension) Shutdown() (err error) {
	e.t.Helper()

	// Detach the registered cleanups under the lock, but run them without
	// holding it. Otherwise a cleanup that touches the extension's state
	// would deadlock.
	e.mu.Lock()
	pending := e.cleanup
	e.cleanup = nil
	e.mu.Unlock()
	for i := range pending {
		pending[i]()
	}

	if hook := e.ShutdownHook; hook != nil {
		err = hook(e)
	}

	if calls := e.shutdownCnt.Add(1); calls != 1 {
		e.t.Errorf("%q: Shutdown called more than once", e.name)
	}
	switch {
	case e.initCnt.Load() == 0:
		e.t.Errorf("%q: Shutdown called without Init", e.name)
	case e.initOkCnt.Load() == 0:
		e.t.Errorf("%q: Shutdown called despite failed Init", e.name)
	}

	e.host = nil
	return err
}
// checkShutdown reports a test error if Init succeeded but Shutdown has not
// been called yet. It is meant to run at the end of the test, e.g. registered
// via t.Cleanup.
func (e *testExtension) checkShutdown() {
	e.t.Helper()
	if e.initOkCnt.Load() == 0 {
		return // Init never succeeded, so no Shutdown is expected.
	}
	if e.shutdownCnt.Load() != 0 {
		return
	}
	e.t.Errorf("%q: Shutdown has not been called before test end", e.name)
}
// ShutdownCalled reports whether [testExtension.Shutdown] has been invoked
// at least once on the receiver.
func (e *testExtension) ShutdownCalled() bool {
	calls := e.shutdownCnt.Load()
	return calls != 0
}
// SetExtState stores value under key in the state of e, replacing any
// previous value. Tests use it to propagate test-specific state throughout
// the extension lifecycle (e.g., between [testExtension.Init],
// [testExtension.Shutdown], and registered callbacks).
//
// NOTE(review): the state map is only allocated by Init, so calling this
// before Init writes to a nil map and panics.
func SetExtState[T any](e *testExtension, key string, value T) {
	e.mu.Lock()
	defer e.mu.Unlock()
	state := e.state
	state[key] = value
}
// UpdateExtState replaces the keyed state of the extension with the result of
// calling update on its current value. The read and write happen under the
// extension's mutex.
//
// The value passed to update is the zero value of T if the state is not set
// or holds a value of a different type.
//
// The mutex is held while update runs. For that reason, update must not call
// SetExtState, GetExtState, or similar accessors on the same extension; doing
// so would deadlock.
func UpdateExtState[T any](e *testExtension, key string, update func(T) T) {
	e.mu.Lock()
	defer e.mu.Unlock()
	old, _ := e.state[key].(T) // zero T when missing or of another type
	e.state[key] = update(old)
}
// GetExtState returns the value stored under key in the extension's state.
// It returns the zero value of T if no value is stored under key or the
// stored value is not a T.
func GetExtState[T any](e *testExtension, key string) T {
	if v, ok := GetExtStateOk[T](e, key); ok {
		return v
	}
	var zero T
	return zero
}
// GetExtStateOk is like [GetExtState], but also reports whether the state
// with the given key exists and is of the expected type.
// If ok is false, the returned value is the zero value of T.
func GetExtStateOk[T any](e *testExtension, key string) (_ T, ok bool) {
	e.mu.Lock()
	defer e.mu.Unlock()
	// Reading from a nil map is safe, so this works even before Init.
	v, ok := e.state[key].(T)
	return v, ok
}
// testExecQueue is a test implementation of [execQueue]
// that defers execution of the enqueued funcs until
// [testExecQueue.Drain] is called, and fails the test
// if [testExecQueue.Add] is called before the host is initialized.
//
// It is typically used by calling [ExtensionHost.SetWorkQueueForTest].
type testExecQueue struct {
	t *testing.T     // test that created the queue
	h *ExtensionHost // host to own the queue

	mu    sync.Mutex // protects queue
	queue []func()   // funcs added by Add and not yet run by Drain
}

var _ execQueue = (*testExecQueue)(nil)
// SetWorkQueueForTest is a helper function that creates a new [testExecQueue]
// and sets it as the work queue for the specified [ExtensionHost],
// returning the new queue.
//
// It fails the test (via t.Fatal) if the host is already initialized.
func (h *ExtensionHost) SetWorkQueueForTest(t *testing.T) *testExecQueue {
	t.Helper()
	if h.initialized.Load() {
		// t.Fatal stops the test goroutine, so no return is needed here.
		t.Fatal("SetWorkQueueForTest: host is already initialized")
	}
	q := &testExecQueue{t: t, h: h}
	h.workQueue = q
	return q
}
// Add implements [execQueue].
//
// It appends f to the queue. f does not run until [testExecQueue.Drain] is
// called. Add fails the test via t.Fatal if the owning host is not yet
// initialized.
//
// NOTE(review): t.Fatal must be called from the goroutine running the test.
// Confirm that Add is never invoked from another goroutine.
func (q *testExecQueue) Add(f func()) {
	q.t.Helper()
	if q.h.initialized.Load() {
		q.mu.Lock()
		defer q.mu.Unlock()
		q.queue = append(q.queue, f)
		return
	}
	q.t.Fatal("ExecQueue.Add must not be called until the host is initialized")
}
// Drain runs every func currently in the queue, in the order they were added.
//
// The queue is detached under the lock and the funcs run without holding it,
// so a func may itself call [testExecQueue.Add]. Funcs added that way are not
// run by this call; they remain queued for the next Drain.
func (q *testExecQueue) Drain() {
	var pending []func()
	q.mu.Lock()
	pending, q.queue = q.queue, nil
	q.mu.Unlock()
	for i := range pending {
		pending[i]()
	}
}
// Shutdown implements [execQueue].
// It does nothing; queued funcs run only when Drain is called explicitly.
func (*testExecQueue) Shutdown() {
}
// Wait implements [execQueue].
// It returns nil immediately without consulting the context.
func (*testExecQueue) Wait(context.Context) error {
	return nil
}
// testBackend implements [ipnext.Backend] for testing purposes
// by calling the provided hooks when its methods are called.
type testBackend struct {
	// switchToBestProfileHook, if non-nil, is called by
	// [testBackend.SwitchToBestProfile] with its reason argument while mu is held.
	switchToBestProfileHook func(reason string)
	// mu protects the backend state.
	// It is acquired on entry to the exported methods of the backend
	// and released on exit, mimicking the behavior of the [LocalBackend].
	mu sync.Mutex
}
// SwitchToBestProfile invokes the configured switchToBestProfileHook, if any,
// with reason. It holds b.mu for the duration of the call, matching the
// locking behavior the backend is meant to mimic.
func (b *testBackend) SwitchToBestProfile(reason string) {
	b.mu.Lock()
	defer b.mu.Unlock()
	hook := b.switchToBestProfileHook
	if hook == nil {
		return
	}
	hook(reason)
}