mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-25 10:14:36 +00:00

In this PR, we add two methods to facilitate extension lookup by both extensions and non-extensions (e.g., PeerAPI or LocalAPI handlers):
- FindExtensionByName returns an extension with the specified name. It can then be type-asserted to a given type.
- FindMatchingExtension is like errors.As, but for extensions. It returns the first extension that matches the target type (either a specific extension type or an interface).

Updates tailscale/corp#27645
Updates tailscale/corp#27502
Signed-off-by: Nick Khyl <nickk@tailscale.com>
1234 lines
38 KiB
Go
1234 lines
38 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
||
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
||
package ipnlocal
|
||
|
||
import (
|
||
"cmp"
|
||
"context"
|
||
"errors"
|
||
"net/netip"
|
||
"reflect"
|
||
"slices"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"testing"
|
||
|
||
deepcmp "github.com/google/go-cmp/cmp"
|
||
"github.com/google/go-cmp/cmp/cmpopts"
|
||
|
||
"tailscale.com/health"
|
||
"tailscale.com/ipn"
|
||
"tailscale.com/ipn/ipnauth"
|
||
"tailscale.com/ipn/ipnext"
|
||
"tailscale.com/ipn/store/mem"
|
||
"tailscale.com/tailcfg"
|
||
"tailscale.com/tsd"
|
||
"tailscale.com/tstest"
|
||
"tailscale.com/types/key"
|
||
"tailscale.com/types/persist"
|
||
"tailscale.com/util/must"
|
||
)
|
||
|
||
// TestExtensionInitShutdown tests that [ExtensionHost] correctly initializes
|
||
// and shuts down extensions.
|
||
func TestExtensionInitShutdown(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// As of 2025-04-08, [ipn.Host.Init] and [ipn.Host.Shutdown] do not return errors
|
||
// as extension initialization and shutdown errors are not fatal.
|
||
// If these methods are updated to return errors, this test should also be updated.
|
||
// The conversions below will fail to compile if their signatures change, reminding us to update the test.
|
||
_ = (func(*ExtensionHost))((*ExtensionHost).Init)
|
||
_ = (func(*ExtensionHost))((*ExtensionHost).Shutdown)
|
||
|
||
tests := []struct {
|
||
name string
|
||
nilHost bool
|
||
exts []*testExtension
|
||
wantInit []string
|
||
wantShutdown []string
|
||
skipInit bool
|
||
}{
|
||
{
|
||
name: "nil-host",
|
||
nilHost: true,
|
||
exts: []*testExtension{},
|
||
wantInit: []string{},
|
||
wantShutdown: []string{},
|
||
},
|
||
{
|
||
name: "empty-extensions",
|
||
exts: []*testExtension{},
|
||
wantInit: []string{},
|
||
wantShutdown: []string{},
|
||
},
|
||
{
|
||
name: "single-extension",
|
||
exts: []*testExtension{{name: "A"}},
|
||
wantInit: []string{"A"},
|
||
wantShutdown: []string{"A"},
|
||
},
|
||
{
|
||
name: "multiple-extensions/all-ok",
|
||
exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}},
|
||
wantInit: []string{"A", "B", "C"},
|
||
wantShutdown: []string{"C", "B", "A"},
|
||
},
|
||
{
|
||
name: "multiple-extensions/no-init-no-shutdown",
|
||
exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}},
|
||
wantInit: []string{},
|
||
wantShutdown: []string{},
|
||
skipInit: true,
|
||
},
|
||
{
|
||
name: "multiple-extensions/init-failed/first",
|
||
exts: []*testExtension{{
|
||
name: "A",
|
||
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
||
}, {
|
||
name: "B",
|
||
InitHook: func(*testExtension) error { return nil },
|
||
}, {
|
||
name: "C",
|
||
InitHook: func(*testExtension) error { return nil },
|
||
}},
|
||
wantInit: []string{"A", "B", "C"},
|
||
wantShutdown: []string{"C", "B"},
|
||
},
|
||
{
|
||
name: "multiple-extensions/init-failed/second",
|
||
exts: []*testExtension{{
|
||
name: "A",
|
||
InitHook: func(*testExtension) error { return nil },
|
||
}, {
|
||
name: "B",
|
||
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
||
}, {
|
||
name: "C",
|
||
InitHook: func(*testExtension) error { return nil },
|
||
}},
|
||
wantInit: []string{"A", "B", "C"},
|
||
wantShutdown: []string{"C", "A"},
|
||
},
|
||
{
|
||
name: "multiple-extensions/init-failed/third",
|
||
exts: []*testExtension{{
|
||
name: "A",
|
||
InitHook: func(*testExtension) error { return nil },
|
||
}, {
|
||
name: "B",
|
||
InitHook: func(*testExtension) error { return nil },
|
||
}, {
|
||
name: "C",
|
||
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
||
}},
|
||
wantInit: []string{"A", "B", "C"},
|
||
wantShutdown: []string{"B", "A"},
|
||
},
|
||
{
|
||
name: "multiple-extensions/init-failed/all",
|
||
exts: []*testExtension{{
|
||
name: "A",
|
||
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
||
}, {
|
||
name: "B",
|
||
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
||
}, {
|
||
name: "C",
|
||
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
||
}},
|
||
wantInit: []string{"A", "B", "C"},
|
||
wantShutdown: []string{},
|
||
},
|
||
{
|
||
name: "multiple-extensions/init-skipped",
|
||
exts: []*testExtension{{
|
||
name: "A",
|
||
InitHook: func(*testExtension) error { return nil },
|
||
}, {
|
||
name: "B",
|
||
InitHook: func(*testExtension) error { return ipnext.SkipExtension },
|
||
}, {
|
||
name: "C",
|
||
InitHook: func(*testExtension) error { return nil },
|
||
}},
|
||
wantInit: []string{"A", "B", "C"},
|
||
wantShutdown: []string{"C", "A"},
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// Configure all extensions to append their names
|
||
// to the gotInit and gotShutdown slices
|
||
// during initialization and shutdown,
|
||
// so we can check that they are called in the right order
|
||
// and that shutdown is not unless init succeeded.
|
||
var gotInit, gotShutdown []string
|
||
for _, ext := range tt.exts {
|
||
oldInitHook := ext.InitHook
|
||
ext.InitHook = func(e *testExtension) error {
|
||
gotInit = append(gotInit, e.name)
|
||
if oldInitHook == nil {
|
||
return nil
|
||
}
|
||
return oldInitHook(e)
|
||
}
|
||
ext.ShutdownHook = func(e *testExtension) error {
|
||
gotShutdown = append(gotShutdown, e.name)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
var h *ExtensionHost
|
||
if !tt.nilHost {
|
||
h = newExtensionHostForTest(t, &testBackend{}, false, tt.exts...)
|
||
}
|
||
|
||
if !tt.skipInit {
|
||
h.Init()
|
||
}
|
||
|
||
// Check that the extensions were initialized in the right order.
|
||
if !slices.Equal(gotInit, tt.wantInit) {
|
||
t.Errorf("Init extensions: got %v; want %v", gotInit, tt.wantInit)
|
||
}
|
||
|
||
// Calling Init again on the host should be a no-op.
|
||
// The [testExtension.Init] method fails the test if called more than once,
|
||
// regardless of which test is running, so we don't need to check it here.
|
||
// Similarly, calling Shutdown again on the host should be a no-op as well.
|
||
// It is verified by the [testExtension.Shutdown] method itself.
|
||
if !tt.skipInit {
|
||
h.Init()
|
||
}
|
||
|
||
// Extensions should not be shut down before the host is shut down,
|
||
// even if they are not initialized successfully.
|
||
for _, ext := range tt.exts {
|
||
if gotShutdown := ext.ShutdownCalled(); gotShutdown {
|
||
t.Errorf("%q: Extension shutdown called before host shutdown", ext.name)
|
||
}
|
||
}
|
||
|
||
h.Shutdown()
|
||
// Check that the extensions were shut down in the right order,
|
||
// and that they were not shut down if they were not initialized successfully.
|
||
if !slices.Equal(gotShutdown, tt.wantShutdown) {
|
||
t.Errorf("Shutdown extensions: got %v; want %v", gotShutdown, tt.wantShutdown)
|
||
}
|
||
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestNewExtensionHost tests that [NewExtensionHost] correctly creates
|
||
// an [ExtensionHost], instantiates the extensions and handles errors
|
||
// if an extension cannot be created.
|
||
func TestNewExtensionHost(t *testing.T) {
|
||
t.Parallel()
|
||
tests := []struct {
|
||
name string
|
||
defs []*ipnext.Definition
|
||
wantErr bool
|
||
wantExts []string
|
||
}{
|
||
{
|
||
name: "no-exts",
|
||
defs: []*ipnext.Definition{},
|
||
wantErr: false,
|
||
wantExts: []string{},
|
||
},
|
||
{
|
||
name: "exts-ok",
|
||
defs: []*ipnext.Definition{
|
||
ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
||
ipnext.DefinitionForTest(&testExtension{name: "B"}),
|
||
ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
||
},
|
||
wantErr: false,
|
||
wantExts: []string{"A", "B", "C"},
|
||
},
|
||
{
|
||
name: "exts-skipped",
|
||
defs: []*ipnext.Definition{
|
||
ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
||
ipnext.DefinitionWithErrForTest("B", ipnext.SkipExtension),
|
||
ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
||
},
|
||
wantErr: false, // extension B is skipped, that's ok
|
||
wantExts: []string{"A", "C"},
|
||
},
|
||
{
|
||
name: "exts-fail",
|
||
defs: []*ipnext.Definition{
|
||
ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
||
ipnext.DefinitionWithErrForTest("B", errors.New("failed creating Ext-2")),
|
||
ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
||
},
|
||
wantErr: true, // extension B failed to create, that's not ok
|
||
wantExts: []string{},
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
logf := tstest.WhileTestRunningLogger(t)
|
||
h, err := NewExtensionHost(logf, &tsd.System{}, &testBackend{}, tt.defs...)
|
||
if gotErr := err != nil; gotErr != tt.wantErr {
|
||
t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr)
|
||
}
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
var gotExts []string
|
||
for _, ext := range h.allExtensions {
|
||
gotExts = append(gotExts, ext.Name())
|
||
}
|
||
|
||
if !slices.Equal(gotExts, tt.wantExts) {
|
||
t.Errorf("Shutdown extensions: got %v; want %v", gotExts, tt.wantExts)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestFindMatchingExtension tests that [ExtensionHost.FindMatchingExtension] correctly
|
||
// finds extensions by their type or interface.
|
||
func TestFindMatchingExtension(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// Define test extension types and a couple of interfaces
|
||
type (
|
||
extensionA struct {
|
||
testExtension
|
||
}
|
||
extensionB struct {
|
||
testExtension
|
||
}
|
||
extensionC struct {
|
||
testExtension
|
||
}
|
||
supportedIface interface {
|
||
Name() string
|
||
}
|
||
unsupportedIface interface {
|
||
Unsupported()
|
||
}
|
||
)
|
||
|
||
// Register extensions A and B, but not C.
|
||
extA := &extensionA{testExtension: testExtension{name: "A"}}
|
||
extB := &extensionB{testExtension: testExtension{name: "B"}}
|
||
h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true, extA, extB)
|
||
|
||
var gotA *extensionA
|
||
if !h.FindMatchingExtension(&gotA) {
|
||
t.Errorf("LookupExtension(%T): not found", gotA)
|
||
} else if gotA != extA {
|
||
t.Errorf("LookupExtension(%T): got %v; want %v", gotA, gotA, extA)
|
||
}
|
||
|
||
var gotB *extensionB
|
||
if !h.FindMatchingExtension(&gotB) {
|
||
t.Errorf("LookupExtension(%T): extension B not found", gotB)
|
||
} else if gotB != extB {
|
||
t.Errorf("LookupExtension(%T): got %v; want %v", gotB, gotB, extB)
|
||
}
|
||
|
||
var gotC *extensionC
|
||
if h.FindMatchingExtension(&gotC) {
|
||
t.Errorf("LookupExtension(%T): found, but it should not exist", gotC)
|
||
}
|
||
|
||
// All extensions implement the supportedIface interface,
|
||
// but LookupExtension should only return the first one found,
|
||
// which is extA.
|
||
var gotSupportedIface supportedIface
|
||
if !h.FindMatchingExtension(&gotSupportedIface) {
|
||
t.Errorf("LookupExtension(%T): not found", gotSupportedIface)
|
||
} else if gotName, wantName := gotSupportedIface.Name(), extA.Name(); gotName != wantName {
|
||
t.Errorf("LookupExtension(%T): name: got %v; want %v", gotSupportedIface, gotName, wantName)
|
||
} else if gotSupportedIface != extA {
|
||
t.Errorf("LookupExtension(%T): got %v; want %v", gotSupportedIface, gotSupportedIface, extA)
|
||
}
|
||
|
||
var gotUnsupportedIface unsupportedIface
|
||
if h.FindMatchingExtension(&gotUnsupportedIface) {
|
||
t.Errorf("LookupExtension(%T): found, but it should not exist", gotUnsupportedIface)
|
||
}
|
||
}
|
||
|
||
// TestFindExtensionByName tests that [ExtensionHost.FindExtensionByName] correctly
|
||
// finds extensions by their name.
|
||
func TestFindExtensionByName(t *testing.T) {
|
||
// Register extensions A and B, but not C.
|
||
extA := &testExtension{name: "A"}
|
||
extB := &testExtension{name: "B"}
|
||
h := newExtensionHostForTest(t, &testBackend{}, true, extA, extB)
|
||
|
||
gotA, ok := h.FindExtensionByName(extA.Name()).(*testExtension)
|
||
if !ok {
|
||
t.Errorf("FindExtensionByName(%q): not found", extA.Name())
|
||
} else if gotA != extA {
|
||
t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extA.Name(), gotA, extA)
|
||
}
|
||
|
||
gotB, ok := h.FindExtensionByName(extB.Name()).(*testExtension)
|
||
if !ok {
|
||
t.Errorf("FindExtensionByName(%q): not found", extB.Name())
|
||
} else if gotB != extB {
|
||
t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extB.Name(), gotB, extB)
|
||
}
|
||
|
||
gotC, ok := h.FindExtensionByName("C").(*testExtension)
|
||
if ok {
|
||
t.Errorf(`FindExtensionByName("C"): found, but it should not exist: %v`, gotC)
|
||
}
|
||
}
|
||
|
||
// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues
|
||
// backend operations and executes them asynchronously in the order they were received.
|
||
// It also checks that operations requested before the host and all extensions are initialized
|
||
// are not executed immediately but rather after the host and extensions are initialized.
|
||
func TestExtensionHostEnqueueBackendOperation(t *testing.T) {
|
||
t.Parallel()
|
||
tests := []struct {
|
||
name string
|
||
preInitCalls []string // before host init
|
||
extInitCalls []string // from [Extension.Init]; "" means no call
|
||
wantInitCalls []string // what we expect to be called after host init
|
||
postInitCalls []string // after host init
|
||
}{
|
||
{
|
||
name: "no-calls",
|
||
preInitCalls: []string{},
|
||
extInitCalls: []string{},
|
||
wantInitCalls: []string{},
|
||
postInitCalls: []string{},
|
||
},
|
||
{
|
||
name: "pre-init-calls",
|
||
preInitCalls: []string{"pre-init-1", "pre-init-2"},
|
||
extInitCalls: []string{},
|
||
wantInitCalls: []string{"pre-init-1", "pre-init-2"},
|
||
postInitCalls: []string{},
|
||
},
|
||
{
|
||
name: "init-calls",
|
||
preInitCalls: []string{},
|
||
extInitCalls: []string{"init-1", "init-2"},
|
||
wantInitCalls: []string{"init-1", "init-2"},
|
||
postInitCalls: []string{},
|
||
},
|
||
{
|
||
name: "post-init-calls",
|
||
preInitCalls: []string{},
|
||
extInitCalls: []string{},
|
||
wantInitCalls: []string{},
|
||
postInitCalls: []string{"post-init-1", "post-init-2"},
|
||
},
|
||
{
|
||
name: "mixed-calls",
|
||
preInitCalls: []string{"pre-init-1", "pre-init-2"},
|
||
extInitCalls: []string{"init-1", "", "init-2"},
|
||
wantInitCalls: []string{"pre-init-1", "pre-init-2", "init-1", "init-2"},
|
||
postInitCalls: []string{"post-init-1", "post-init-2"},
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var gotCalls []string
|
||
var h *ExtensionHost
|
||
b := &testBackend{
|
||
switchToBestProfileHook: func(reason string) {
|
||
gotCalls = append(gotCalls, reason)
|
||
},
|
||
}
|
||
|
||
exts := make([]*testExtension, len(tt.extInitCalls))
|
||
for i, reason := range tt.extInitCalls {
|
||
exts[i] = &testExtension{}
|
||
if reason != "" {
|
||
exts[i].InitHook = func(e *testExtension) error {
|
||
e.host.Profiles().SwitchToBestProfileAsync(reason)
|
||
return nil
|
||
}
|
||
}
|
||
}
|
||
|
||
h = newExtensionHostForTest(t, b, false, exts...)
|
||
wq := h.SetWorkQueueForTest(t) // use a test queue instead of [execqueue.ExecQueue].
|
||
|
||
// Issue some pre-init calls. They should be deferred and not
|
||
// added to the queue until the host is initialized.
|
||
for _, call := range tt.preInitCalls {
|
||
h.Profiles().SwitchToBestProfileAsync(call)
|
||
}
|
||
|
||
// The queue should be empty before the host is initialized.
|
||
wq.Drain()
|
||
if len(gotCalls) != 0 {
|
||
t.Errorf("Pre-init calls: got %v; want (none)", gotCalls)
|
||
}
|
||
gotCalls = nil
|
||
|
||
// Initialize the host and all extensions.
|
||
// The extensions will make their calls during initialization.
|
||
h.Init()
|
||
|
||
// Calls made before or during initialization should now be enqueued and running.
|
||
wq.Drain()
|
||
if diff := deepcmp.Diff(tt.wantInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" {
|
||
t.Errorf("Init calls: (+got -want): %v", diff)
|
||
}
|
||
gotCalls = nil
|
||
|
||
// Let's make some more calls, as if extensions were making them in a response
|
||
// to external events.
|
||
for _, call := range tt.postInitCalls {
|
||
h.Profiles().SwitchToBestProfileAsync(call)
|
||
}
|
||
|
||
// Any calls made after initialization should be enqueued and running.
|
||
wq.Drain()
|
||
if diff := deepcmp.Diff(tt.postInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" {
|
||
t.Errorf("Init calls: (+got -want): %v", diff)
|
||
}
|
||
gotCalls = nil
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestExtensionHostProfileChangeCallback verifies that [ExtensionHost] correctly handles the registration,
|
||
// invocation, and unregistration of profile change callbacks. It also checks that the callbacks are called
|
||
// with the correct arguments and that any private keys are stripped from [ipn.Prefs] before being passed to the callback.
|
||
func TestExtensionHostProfileChangeCallback(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
type profileChange struct {
|
||
Profile *ipn.LoginProfile
|
||
Prefs *ipn.Prefs
|
||
SameNode bool
|
||
}
|
||
// newProfileChange creates a new profile change with deep copies of the profile and prefs.
|
||
newProfileChange := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) profileChange {
|
||
return profileChange{
|
||
Profile: profile.AsStruct(),
|
||
Prefs: prefs.AsStruct(),
|
||
SameNode: sameNode,
|
||
}
|
||
}
|
||
// makeProfileChangeAppender returns a callback that appends profile changes to the extension's state.
|
||
makeProfileChangeAppender := func(e *testExtension) ipnext.ProfileChangeCallback {
|
||
return func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
|
||
UpdateExtState(e, "changes", func(changes []profileChange) []profileChange {
|
||
return append(changes, newProfileChange(profile, prefs, sameNode))
|
||
})
|
||
}
|
||
}
|
||
// getProfileChanges returns the profile changes stored in the extension's state.
|
||
getProfileChanges := func(e *testExtension) []profileChange {
|
||
changes, _ := GetExtStateOk[[]profileChange](e, "changes")
|
||
return changes
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
ext *testExtension
|
||
calls []profileChange
|
||
wantCalls []profileChange
|
||
}{
|
||
{
|
||
// Register the callback for the lifetime of the extension.
|
||
name: "Register/Lifetime",
|
||
ext: &testExtension{},
|
||
calls: []profileChange{
|
||
{Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
||
{Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
||
{Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
||
{Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
|
||
},
|
||
wantCalls: []profileChange{ // all calls are received by the callback
|
||
{Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
||
{Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
||
{Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
||
{Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
|
||
},
|
||
},
|
||
{
|
||
// Override the default InitHook used in the test to unregister the callback
|
||
// after the first call.
|
||
name: "Register/Once",
|
||
ext: &testExtension{
|
||
InitHook: func(e *testExtension) error {
|
||
var unregister func()
|
||
handler := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
|
||
makeProfileChangeAppender(e)(profile, prefs, sameNode)
|
||
unregister()
|
||
}
|
||
unregister = e.host.Profiles().RegisterProfileChangeCallback(handler)
|
||
return nil
|
||
},
|
||
},
|
||
calls: []profileChange{
|
||
{Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
||
{Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
||
{Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
||
},
|
||
wantCalls: []profileChange{ // only the first call is received by the callback
|
||
{Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
||
},
|
||
},
|
||
{
|
||
// Ensure that ipn.Prefs are passed to the callback.
|
||
name: "CheckPrefs",
|
||
ext: &testExtension{},
|
||
calls: []profileChange{{
|
||
Profile: &ipn.LoginProfile{ID: "profile-1"},
|
||
Prefs: &ipn.Prefs{
|
||
WantRunning: true,
|
||
LoggedOut: false,
|
||
AdvertiseRoutes: []netip.Prefix{
|
||
netip.MustParsePrefix("192.168.1.0/24"),
|
||
netip.MustParsePrefix("192.168.2.0/24"),
|
||
},
|
||
},
|
||
}},
|
||
wantCalls: []profileChange{{
|
||
Profile: &ipn.LoginProfile{ID: "profile-1"},
|
||
Prefs: &ipn.Prefs{
|
||
WantRunning: true,
|
||
LoggedOut: false,
|
||
AdvertiseRoutes: []netip.Prefix{
|
||
netip.MustParsePrefix("192.168.1.0/24"),
|
||
netip.MustParsePrefix("192.168.2.0/24"),
|
||
},
|
||
},
|
||
}},
|
||
},
|
||
{
|
||
// Ensure that private keys are stripped from persist.Persist shared with extensions.
|
||
name: "StripPrivateKeys",
|
||
ext: &testExtension{},
|
||
calls: []profileChange{{
|
||
Profile: &ipn.LoginProfile{ID: "profile-1"},
|
||
Prefs: &ipn.Prefs{
|
||
Persist: &persist.Persist{
|
||
NodeID: "12345",
|
||
PrivateNodeKey: key.NewNode(),
|
||
OldPrivateNodeKey: key.NewNode(),
|
||
NetworkLockKey: key.NewNLPrivate(),
|
||
UserProfile: tailcfg.UserProfile{
|
||
ID: 12345,
|
||
LoginName: "test@example.com",
|
||
DisplayName: "Test User",
|
||
ProfilePicURL: "https://example.com/profile.png",
|
||
},
|
||
},
|
||
},
|
||
}},
|
||
wantCalls: []profileChange{{
|
||
Profile: &ipn.LoginProfile{ID: "profile-1"},
|
||
Prefs: &ipn.Prefs{
|
||
Persist: &persist.Persist{
|
||
NodeID: "12345",
|
||
PrivateNodeKey: key.NodePrivate{}, // stripped
|
||
OldPrivateNodeKey: key.NodePrivate{}, // stripped
|
||
NetworkLockKey: key.NLPrivate{}, // stripped
|
||
UserProfile: tailcfg.UserProfile{
|
||
ID: 12345,
|
||
LoginName: "test@example.com",
|
||
DisplayName: "Test User",
|
||
ProfilePicURL: "https://example.com/profile.png",
|
||
},
|
||
},
|
||
},
|
||
}},
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// Use the default InitHook if not provided by the test.
|
||
if tt.ext.InitHook == nil {
|
||
tt.ext.InitHook = func(e *testExtension) error {
|
||
// Create and register the callback on init.
|
||
handler := makeProfileChangeAppender(e)
|
||
e.Cleanup(e.host.Profiles().RegisterProfileChangeCallback(handler))
|
||
return nil
|
||
}
|
||
}
|
||
|
||
h := newExtensionHostForTest(t, &testBackend{}, true, tt.ext)
|
||
for _, call := range tt.calls {
|
||
h.NotifyProfileChange(call.Profile.View(), call.Prefs.View(), call.SameNode)
|
||
}
|
||
opts := []deepcmp.Option{
|
||
cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}),
|
||
}
|
||
if diff := deepcmp.Diff(tt.wantCalls, getProfileChanges(tt.ext), opts...); diff != "" {
|
||
t.Errorf("ProfileChange callbacks: (-want +got): %v", diff)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestBackgroundProfileResolver tests that the background profile resolvers
|
||
// are correctly registered, unregistered and invoked by the [ExtensionHost].
|
||
func TestBackgroundProfileResolver(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
profiles []ipn.LoginProfile // the first one is the current profile
|
||
resolvers []ipnext.ProfileResolver
|
||
wantProfile *ipn.LoginProfile
|
||
}{
|
||
{
|
||
name: "No-Profiles/No-Resolvers",
|
||
profiles: nil,
|
||
resolvers: nil,
|
||
wantProfile: nil,
|
||
},
|
||
{
|
||
// TODO(nickkhyl): update this test as we change "background profile resolvers"
|
||
// to just "profile resolvers". The wantProfile should be the current profile by default.
|
||
name: "Has-Profiles/No-Resolvers",
|
||
profiles: []ipn.LoginProfile{{ID: "profile-1"}},
|
||
resolvers: nil,
|
||
wantProfile: nil,
|
||
},
|
||
{
|
||
name: "Has-Profiles/Single-Resolver",
|
||
profiles: []ipn.LoginProfile{{ID: "profile-1"}},
|
||
resolvers: []ipnext.ProfileResolver{
|
||
func(ps ipnext.ProfileStore) ipn.LoginProfileView {
|
||
return ps.CurrentProfile()
|
||
},
|
||
},
|
||
wantProfile: &ipn.LoginProfile{ID: "profile-1"},
|
||
},
|
||
// TODO(nickkhyl): add more tests for multiple resolvers and different profiles
|
||
// once we change "background profile resolvers" to just "profile resolvers"
|
||
// and add proper conflict resolution logic.
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
// Create a new profile manager and add the profiles to it.
|
||
// We expose the profile manager to the extensions via the read-only [ipnext.ProfileStore] interface.
|
||
pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker)))
|
||
for i, p := range tt.profiles {
|
||
// Generate a unique ID and key for each profile,
|
||
// unless the profile already has them set
|
||
// or is an empty, unnamed profile.
|
||
if p.Name != "" {
|
||
if p.ID == "" {
|
||
p.ID = ipn.ProfileID("profile-" + strconv.Itoa(i))
|
||
}
|
||
if p.Key == "" {
|
||
p.Key = "key-" + ipn.StateKey(p.ID)
|
||
}
|
||
}
|
||
pv := p.View()
|
||
pm.knownProfiles[p.ID] = pv
|
||
if i == 0 {
|
||
// Set the first profile as the current one.
|
||
// A profileManager starts with an empty profile,
|
||
// so it's okay if the list of profiles is empty.
|
||
pm.SwitchToProfile(pv)
|
||
}
|
||
}
|
||
|
||
h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true)
|
||
|
||
// Register the resolvers with the host.
|
||
// This is typically done by the extensions themselves,
|
||
// but we do it here for testing purposes.
|
||
for _, r := range tt.resolvers {
|
||
t.Cleanup(h.Profiles().RegisterBackgroundProfileResolver(r))
|
||
}
|
||
|
||
// Call the resolver to get the profile.
|
||
gotProfile := h.DetermineBackgroundProfile(pm)
|
||
if !gotProfile.Equals(tt.wantProfile.View()) {
|
||
t.Errorf("Resolved profile: got %v; want %v", gotProfile, tt.wantProfile)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestAuditLogProviders tests that the [ExtensionHost] correctly handles
|
||
// the registration and invocation of audit log providers. It verifies that
|
||
// the audit loggers are called with the correct actions and details,
|
||
// and that any errors returned by the providers are properly propagated.
|
||
func TestAuditLogProviders(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tests := []struct {
|
||
name string
|
||
auditLoggers []ipnauth.AuditLogFunc // each represents an extension
|
||
actions []tailcfg.ClientAuditAction
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "No-Providers",
|
||
auditLoggers: nil,
|
||
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "Single-Provider/Ok",
|
||
auditLoggers: []ipnauth.AuditLogFunc{
|
||
func(tailcfg.ClientAuditAction, string) error { return nil },
|
||
},
|
||
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "Single-Provider/Err",
|
||
auditLoggers: []ipnauth.AuditLogFunc{
|
||
func(tailcfg.ClientAuditAction, string) error {
|
||
return errors.New("failed to log")
|
||
},
|
||
},
|
||
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "Many-Providers/Ok",
|
||
auditLoggers: []ipnauth.AuditLogFunc{
|
||
func(tailcfg.ClientAuditAction, string) error { return nil },
|
||
func(tailcfg.ClientAuditAction, string) error { return nil },
|
||
},
|
||
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "Many-Providers/Err",
|
||
auditLoggers: []ipnauth.AuditLogFunc{
|
||
func(tailcfg.ClientAuditAction, string) error {
|
||
return errors.New("failed to log")
|
||
},
|
||
func(tailcfg.ClientAuditAction, string) error {
|
||
return nil // all good
|
||
},
|
||
func(tailcfg.ClientAuditAction, string) error {
|
||
return errors.New("also failed to log")
|
||
},
|
||
},
|
||
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
||
wantErr: true, // some providers failed to log, so that's an error
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
// Create extensions that register the audit log providers.
|
||
// Each extension/provider will append auditable actions to its state,
|
||
// then call the test's auditLogger function.
|
||
var exts []*testExtension
|
||
for _, auditLogger := range tt.auditLoggers {
|
||
ext := &testExtension{}
|
||
provider := func() ipnauth.AuditLogFunc {
|
||
return func(action tailcfg.ClientAuditAction, details string) error {
|
||
UpdateExtState(ext, "actions", func(actions []tailcfg.ClientAuditAction) []tailcfg.ClientAuditAction {
|
||
return append(actions, action)
|
||
})
|
||
return auditLogger(action, details)
|
||
}
|
||
}
|
||
ext.InitHook = func(e *testExtension) error {
|
||
e.Cleanup(e.host.RegisterAuditLogProvider(provider))
|
||
return nil
|
||
}
|
||
exts = append(exts, ext)
|
||
}
|
||
|
||
// Initialize the host and the extensions.
|
||
h := newExtensionHostForTest(t, &testBackend{}, true, exts...)
|
||
|
||
// Use [ExtensionHost.AuditLogger] to log actions.
|
||
for _, action := range tt.actions {
|
||
err := h.AuditLogger()(action, "Test details")
|
||
if gotErr := err != nil; gotErr != tt.wantErr {
|
||
t.Errorf("AuditLogger: gotErr %v (%v); wantErr %v", gotErr, err, tt.wantErr)
|
||
}
|
||
}
|
||
|
||
// Check that the actions were logged correctly by each provider.
|
||
for _, ext := range exts {
|
||
gotActions := GetExtState[[]tailcfg.ClientAuditAction](ext, "actions")
|
||
if !slices.Equal(gotActions, tt.actions) {
|
||
t.Errorf("Actions: got %v; want %v", gotActions, tt.actions)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestNilExtensionHostMethodCall tests that calling exported methods
|
||
// on a nil [ExtensionHost] does not panic. We should treat it as a valid
|
||
// value since it's used in various tests that instantiate [LocalBackend]
|
||
// manually without calling [NewLocalBackend]. It also verifies that if
|
||
// a method returns a single func value (e.g., a cleanup function),
|
||
// it should not be nil. This is a basic sanity check to ensure that
|
||
// typical method calls on a nil receiver work as expected.
|
||
// It does not replace the need for more thorough testing of specific methods.
|
||
func TestNilExtensionHostMethodCall(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
var h *ExtensionHost
|
||
typ := reflect.TypeOf(h)
|
||
for i := range typ.NumMethod() {
|
||
m := typ.Method(i)
|
||
if strings.HasSuffix(m.Name, "ForTest") {
|
||
// Skip methods that are only for testing.
|
||
continue
|
||
}
|
||
|
||
t.Run(m.Name, func(t *testing.T) {
|
||
t.Parallel()
|
||
// Calling the method on the nil receiver should not panic.
|
||
ret := checkMethodCallWithZeroArgs(t, m, h)
|
||
if len(ret) == 1 && ret[0].Kind() == reflect.Func {
|
||
// If the method returns a single func, such as a cleanup function,
|
||
// it should not be nil.
|
||
fn := ret[0]
|
||
if fn.IsNil() {
|
||
t.Fatalf("(%T).%s returned a nil func", h, m.Name)
|
||
}
|
||
// We expect it to be a no-op and calling it should not panic.
|
||
args := makeZeroArgsFor(fn)
|
||
func() {
|
||
defer func() {
|
||
if e := recover(); e != nil {
|
||
t.Fatalf("panic calling the func returned by (%T).%s: %v", e, m.Name, e)
|
||
}
|
||
}()
|
||
fn.Call(args)
|
||
}()
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// checkMethodCallWithZeroArgs calls the method m on the receiver r
|
||
// with zero values for all its arguments, except the receiver itself.
|
||
// It returns the result of the method call, or fails the test if the call panics.
|
||
func checkMethodCallWithZeroArgs[T any](t *testing.T, m reflect.Method, r T) []reflect.Value {
|
||
t.Helper()
|
||
args := makeZeroArgsFor(m.Func)
|
||
// The first arg is the receiver.
|
||
args[0] = reflect.ValueOf(r)
|
||
// Calling the method should not panic.
|
||
defer func() {
|
||
if e := recover(); e != nil {
|
||
t.Fatalf("panic calling (%T).%s: %v", r, m.Name, e)
|
||
}
|
||
}()
|
||
return m.Func.Call(args)
|
||
}
|
||
|
||
// makeZeroArgsFor returns one zero [reflect.Value] per parameter of fn, in
// declaration order. For a method expression (such as [reflect.Method.Func])
// the first element is the zero receiver; for a variadic function the last
// element is a nil slice of the variadic parameter's type.
func makeZeroArgsFor(fn reflect.Value) []reflect.Value {
	ft := fn.Type()
	count := ft.NumIn()
	zeros := make([]reflect.Value, 0, count)
	for idx := 0; idx < count; idx++ {
		zeros = append(zeros, reflect.Zero(ft.In(idx)))
	}
	return zeros
}
|
||
|
||
// newExtensionHostForTest creates an [ExtensionHost] with the given backend and extensions.
|
||
// It associates each extension that either is or embeds a [testExtension] with the test
|
||
// and assigns a name if one isn’t already set.
|
||
//
|
||
// If the host cannot be created, it fails the test.
|
||
//
|
||
// The host is initialized if the initialize parameter is true.
|
||
// It is shut down automatically when the test ends.
|
||
func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initialize bool, exts ...T) *ExtensionHost {
|
||
t.Helper()
|
||
|
||
// testExtensionIface is a subset of the methods implemented by [testExtension] that are used here.
|
||
// We use testExtensionIface in type assertions instead of using the [testExtension] type directly,
|
||
// which supports scenarios where an extension type embeds a [testExtension].
|
||
type testExtensionIface interface {
|
||
Name() string
|
||
setName(string)
|
||
setT(*testing.T)
|
||
checkShutdown()
|
||
}
|
||
|
||
logf := tstest.WhileTestRunningLogger(t)
|
||
defs := make([]*ipnext.Definition, len(exts))
|
||
for i, ext := range exts {
|
||
if ext, ok := any(ext).(testExtensionIface); ok {
|
||
ext.setName(cmp.Or(ext.Name(), "Ext-"+strconv.Itoa(i)))
|
||
ext.setT(t)
|
||
}
|
||
defs[i] = ipnext.DefinitionForTest(ext)
|
||
}
|
||
h, err := NewExtensionHost(logf, &tsd.System{}, b, defs...)
|
||
if err != nil {
|
||
t.Fatalf("NewExtensionHost: %v", err)
|
||
}
|
||
// Replace doEnqueueBackendOperation with the one that's marked as a helper,
|
||
// so that we'll have better output if [testExecQueue.Add] fails a test.
|
||
h.doEnqueueBackendOperation = func(f func(Backend)) {
|
||
t.Helper()
|
||
h.workQueue.Add(func() { f(b) })
|
||
}
|
||
for _, ext := range exts {
|
||
if ext, ok := any(ext).(testExtensionIface); ok {
|
||
t.Cleanup(ext.checkShutdown)
|
||
}
|
||
}
|
||
t.Cleanup(h.Shutdown)
|
||
if initialize {
|
||
h.Init()
|
||
}
|
||
return h
|
||
}
|
||
|
||
// testExtension is an [ipnext.Extension] that:
|
||
// - Calls the provided init and shutdown callbacks
|
||
// when [Init] and [Shutdown] are called.
|
||
// - Ensures that [Init] and [Shutdown] are called at most once,
|
||
// that [Shutdown] is called after [Init], but is not called if [Init] fails
|
||
// and is called before the test ends if [Init] succeeds.
|
||
//
|
||
// Typically, [testExtension]s are created and passed to [newExtensionHostForTest]
|
||
// when creating an [ExtensionHost] for testing.
|
||
type testExtension struct {
|
||
t *testing.T // test that created the extension
|
||
name string // name of the extension, used for logging
|
||
|
||
host ipnext.Host // or nil if not initialized
|
||
|
||
// InitHook and ShutdownHook are optional hooks that can be set by tests.
|
||
InitHook, ShutdownHook func(*testExtension) error
|
||
|
||
// initCnt, initOkCnt and shutdownCnt are used to verify that Init and Shutdown
|
||
// are called at most once and in the correct order.
|
||
initCnt, initOkCnt, shutdownCnt atomic.Int32
|
||
|
||
// mu protects the following fields.
|
||
mu sync.Mutex
|
||
// state is the optional state used by tests.
|
||
// It can be accessed by tests using [setTestExtensionState],
|
||
// [getTestExtensionStateOk] and [getTestExtensionState].
|
||
state map[string]any
|
||
// cleanup are functions to be called on shutdown.
|
||
cleanup []func()
|
||
}
|
||
|
||
var _ ipnext.Extension = (*testExtension)(nil)
|
||
|
||
func (e *testExtension) setT(t *testing.T) {
|
||
e.t = t
|
||
}
|
||
|
||
func (e *testExtension) setName(name string) {
|
||
e.name = name
|
||
}
|
||
|
||
// Name implements [ipnext.Extension].
|
||
func (e *testExtension) Name() string {
|
||
return e.name
|
||
}
|
||
|
||
// Init implements [ipnext.Extension].
|
||
func (e *testExtension) Init(host ipnext.Host) (err error) {
|
||
e.t.Helper()
|
||
e.host = host
|
||
if e.initCnt.Add(1) == 1 {
|
||
e.mu.Lock()
|
||
e.state = make(map[string]any)
|
||
e.mu.Unlock()
|
||
} else {
|
||
e.t.Errorf("%q: Init called more than once", e.name)
|
||
}
|
||
if e.InitHook != nil {
|
||
err = e.InitHook(e)
|
||
}
|
||
if err == nil {
|
||
e.initOkCnt.Add(1)
|
||
}
|
||
return err // may be nil or non-nil
|
||
}
|
||
|
||
// InitCalled reports whether the Init method was called on the receiver.
|
||
func (e *testExtension) InitCalled() bool {
|
||
return e.initCnt.Load() != 0
|
||
}
|
||
|
||
func (e *testExtension) Cleanup(f func()) {
|
||
e.mu.Lock()
|
||
e.cleanup = append(e.cleanup, f)
|
||
e.mu.Unlock()
|
||
}
|
||
|
||
// Shutdown implements [ipnext.Extension].
|
||
func (e *testExtension) Shutdown() (err error) {
|
||
e.t.Helper()
|
||
e.mu.Lock()
|
||
cleanup := e.cleanup
|
||
e.cleanup = nil
|
||
e.mu.Unlock()
|
||
for _, f := range cleanup {
|
||
f()
|
||
}
|
||
if e.ShutdownHook != nil {
|
||
err = e.ShutdownHook(e)
|
||
}
|
||
if e.shutdownCnt.Add(1) != 1 {
|
||
e.t.Errorf("%q: Shutdown called more than once", e.name)
|
||
}
|
||
if e.initCnt.Load() == 0 {
|
||
e.t.Errorf("%q: Shutdown called without Init", e.name)
|
||
} else if e.initOkCnt.Load() == 0 {
|
||
e.t.Errorf("%q: Shutdown called despite failed Init", e.name)
|
||
}
|
||
e.host = nil
|
||
return err // may be nil or non-nil
|
||
}
|
||
|
||
func (e *testExtension) checkShutdown() {
|
||
e.t.Helper()
|
||
if e.initOkCnt.Load() != 0 && e.shutdownCnt.Load() == 0 {
|
||
e.t.Errorf("%q: Shutdown has not been called before test end", e.name)
|
||
}
|
||
}
|
||
|
||
// ShutdownCalled reports whether the Shutdown method was called on the receiver.
|
||
func (e *testExtension) ShutdownCalled() bool {
|
||
return e.shutdownCnt.Load() != 0
|
||
}
|
||
|
||
// SetExtState sets a keyed state on [testExtension] to the given value.
|
||
// Tests use it to propagate test-specific state throughout the extension lifecycle
|
||
// (e.g., between [testExtension.Init], [testExtension.Shutdown], and registered callbacks)
|
||
func SetExtState[T any](e *testExtension, key string, value T) {
|
||
e.mu.Lock()
|
||
defer e.mu.Unlock()
|
||
e.state[key] = value
|
||
}
|
||
|
||
// UpdateExtState updates a keyed state of the extension using the provided update function.
|
||
func UpdateExtState[T any](e *testExtension, key string, update func(T) T) {
|
||
e.mu.Lock()
|
||
defer e.mu.Unlock()
|
||
old, _ := e.state[key].(T)
|
||
new := update(old)
|
||
e.state[key] = new
|
||
}
|
||
|
||
// GetExtState returns the value of the keyed state of the extension.
|
||
// It returns a zero value of T if the state is not set or is of a different type.
|
||
func GetExtState[T any](e *testExtension, key string) T {
|
||
v, _ := GetExtStateOk[T](e, key)
|
||
return v
|
||
}
|
||
|
||
// GetExtStateOk is like [getExtState], but also reports whether the state
|
||
// with the given key exists and is of the expected type.
|
||
func GetExtStateOk[T any](e *testExtension, key string) (_ T, ok bool) {
|
||
e.mu.Lock()
|
||
defer e.mu.Unlock()
|
||
v, ok := e.state[key].(T)
|
||
return v, ok
|
||
}
|
||
|
||
// testExecQueue is a test implementation of [execQueue]
|
||
// that defers execution of the enqueued funcs until
|
||
// [testExecQueue.Drain] is called, and fails the test if
|
||
// if [execQueue.Add] is called before the host is initialized.
|
||
//
|
||
// It is typically used by calling [ExtensionHost.SetWorkQueueForTest].
|
||
type testExecQueue struct {
|
||
t *testing.T // test that created the queue
|
||
h *ExtensionHost // host to own the queue
|
||
|
||
mu sync.Mutex
|
||
queue []func()
|
||
}
|
||
|
||
var _ execQueue = (*testExecQueue)(nil)
|
||
|
||
// SetWorkQueueForTest is a helper function that creates a new [testExecQueue]
|
||
// and sets it as the work queue for the specified [ExtensionHost],
|
||
// returning the new queue.
|
||
//
|
||
// It fails the test if the host is already initialized.
|
||
func (h *ExtensionHost) SetWorkQueueForTest(t *testing.T) *testExecQueue {
|
||
t.Helper()
|
||
if h.initialized.Load() {
|
||
t.Fatalf("UseTestWorkQueue: host is already initialized")
|
||
return nil
|
||
}
|
||
q := &testExecQueue{t: t, h: h}
|
||
h.workQueue = q
|
||
return q
|
||
}
|
||
|
||
// Add implements [execQueue].
|
||
func (q *testExecQueue) Add(f func()) {
|
||
q.t.Helper()
|
||
|
||
if !q.h.initialized.Load() {
|
||
q.t.Fatal("ExecQueue.Add must not be called until the host is initialized")
|
||
return
|
||
}
|
||
|
||
q.mu.Lock()
|
||
q.queue = append(q.queue, f)
|
||
q.mu.Unlock()
|
||
}
|
||
|
||
// Drain executes all queued functions in the order they were added.
|
||
func (q *testExecQueue) Drain() {
|
||
q.mu.Lock()
|
||
queue := q.queue
|
||
q.queue = nil
|
||
q.mu.Unlock()
|
||
|
||
for _, f := range queue {
|
||
f()
|
||
}
|
||
}
|
||
|
||
// Shutdown implements [execQueue].
|
||
func (q *testExecQueue) Shutdown() {}
|
||
|
||
// Wait implements [execQueue].
|
||
func (q *testExecQueue) Wait(context.Context) error { return nil }
|
||
|
||
// testBackend implements [ipnext.Backend] for testing purposes
// by calling the provided hooks when its methods are called.
//
// NOTE(review): newExtensionHostForTest accepts this package's unqualified
// Backend type; confirm that [ipnext.Backend] is the intended doc link.
type testBackend struct {
	// switchToBestProfileHook, if non-nil, is called by
	// [testBackend.SwitchToBestProfile] with its reason argument
	// while mu is held.
	switchToBestProfileHook func(reason string)

	// mu protects the backend state.
	// It is acquired on entry to the exported methods of the backend
	// and released on exit, mimicking the behavior of the [LocalBackend].
	mu sync.Mutex
}
|
||
|
||
func (b *testBackend) SwitchToBestProfile(reason string) {
|
||
b.mu.Lock()
|
||
defer b.mu.Unlock()
|
||
if b.switchToBestProfileHook != nil {
|
||
b.switchToBestProfileHook(reason)
|
||
}
|
||
}
|