mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-27 11:12:24 +00:00
1140 lines
35 KiB
Go
1140 lines
35 KiB
Go
![]() |
// Copyright (c) Tailscale Inc & AUTHORS
|
|||
|
// SPDX-License-Identifier: BSD-3-Clause
|
|||
|
|
|||
|
package ipnlocal
|
|||
|
|
|||
|
import (
|
|||
|
"cmp"
|
|||
|
"context"
|
|||
|
"errors"
|
|||
|
"net/netip"
|
|||
|
"reflect"
|
|||
|
"slices"
|
|||
|
"strconv"
|
|||
|
"strings"
|
|||
|
"sync"
|
|||
|
"sync/atomic"
|
|||
|
"testing"
|
|||
|
|
|||
|
deepcmp "github.com/google/go-cmp/cmp"
|
|||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
|||
|
|
|||
|
"tailscale.com/health"
|
|||
|
"tailscale.com/ipn"
|
|||
|
"tailscale.com/ipn/ipnauth"
|
|||
|
"tailscale.com/ipn/ipnext"
|
|||
|
"tailscale.com/ipn/store/mem"
|
|||
|
"tailscale.com/tailcfg"
|
|||
|
"tailscale.com/tsd"
|
|||
|
"tailscale.com/tstest"
|
|||
|
"tailscale.com/types/key"
|
|||
|
"tailscale.com/types/persist"
|
|||
|
"tailscale.com/util/must"
|
|||
|
)
|
|||
|
|
|||
|
// TestExtensionInitShutdown tests that [ExtensionHost] correctly initializes
|
|||
|
// and shuts down extensions.
|
|||
|
func TestExtensionInitShutdown(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
// As of 2025-04-08, [ipn.Host.Init] and [ipn.Host.Shutdown] do not return errors
|
|||
|
// as extension initialization and shutdown errors are not fatal.
|
|||
|
// If these methods are updated to return errors, this test should also be updated.
|
|||
|
// The conversions below will fail to compile if their signatures change, reminding us to update the test.
|
|||
|
_ = (func(*ExtensionHost))((*ExtensionHost).Init)
|
|||
|
_ = (func(*ExtensionHost))((*ExtensionHost).Shutdown)
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
nilHost bool
|
|||
|
exts []*testExtension
|
|||
|
wantInit []string
|
|||
|
wantShutdown []string
|
|||
|
skipInit bool
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "nil-host",
|
|||
|
nilHost: true,
|
|||
|
exts: []*testExtension{},
|
|||
|
wantInit: []string{},
|
|||
|
wantShutdown: []string{},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "empty-extensions",
|
|||
|
exts: []*testExtension{},
|
|||
|
wantInit: []string{},
|
|||
|
wantShutdown: []string{},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "single-extension",
|
|||
|
exts: []*testExtension{{name: "A"}},
|
|||
|
wantInit: []string{"A"},
|
|||
|
wantShutdown: []string{"A"},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple-extensions/all-ok",
|
|||
|
exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}},
|
|||
|
wantInit: []string{"A", "B", "C"},
|
|||
|
wantShutdown: []string{"C", "B", "A"},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple-extensions/no-init-no-shutdown",
|
|||
|
exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}},
|
|||
|
wantInit: []string{},
|
|||
|
wantShutdown: []string{},
|
|||
|
skipInit: true,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple-extensions/init-failed/first",
|
|||
|
exts: []*testExtension{{
|
|||
|
name: "A",
|
|||
|
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|||
|
}, {
|
|||
|
name: "B",
|
|||
|
InitHook: func(*testExtension) error { return nil },
|
|||
|
}, {
|
|||
|
name: "C",
|
|||
|
InitHook: func(*testExtension) error { return nil },
|
|||
|
}},
|
|||
|
wantInit: []string{"A", "B", "C"},
|
|||
|
wantShutdown: []string{"C", "B"},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple-extensions/init-failed/second",
|
|||
|
exts: []*testExtension{{
|
|||
|
name: "A",
|
|||
|
InitHook: func(*testExtension) error { return nil },
|
|||
|
}, {
|
|||
|
name: "B",
|
|||
|
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|||
|
}, {
|
|||
|
name: "C",
|
|||
|
InitHook: func(*testExtension) error { return nil },
|
|||
|
}},
|
|||
|
wantInit: []string{"A", "B", "C"},
|
|||
|
wantShutdown: []string{"C", "A"},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple-extensions/init-failed/third",
|
|||
|
exts: []*testExtension{{
|
|||
|
name: "A",
|
|||
|
InitHook: func(*testExtension) error { return nil },
|
|||
|
}, {
|
|||
|
name: "B",
|
|||
|
InitHook: func(*testExtension) error { return nil },
|
|||
|
}, {
|
|||
|
name: "C",
|
|||
|
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|||
|
}},
|
|||
|
wantInit: []string{"A", "B", "C"},
|
|||
|
wantShutdown: []string{"B", "A"},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple-extensions/init-failed/all",
|
|||
|
exts: []*testExtension{{
|
|||
|
name: "A",
|
|||
|
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|||
|
}, {
|
|||
|
name: "B",
|
|||
|
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|||
|
}, {
|
|||
|
name: "C",
|
|||
|
InitHook: func(*testExtension) error { return errors.New("init failed") },
|
|||
|
}},
|
|||
|
wantInit: []string{"A", "B", "C"},
|
|||
|
wantShutdown: []string{},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "multiple-extensions/init-skipped",
|
|||
|
exts: []*testExtension{{
|
|||
|
name: "A",
|
|||
|
InitHook: func(*testExtension) error { return nil },
|
|||
|
}, {
|
|||
|
name: "B",
|
|||
|
InitHook: func(*testExtension) error { return ipnext.SkipExtension },
|
|||
|
}, {
|
|||
|
name: "C",
|
|||
|
InitHook: func(*testExtension) error { return nil },
|
|||
|
}},
|
|||
|
wantInit: []string{"A", "B", "C"},
|
|||
|
wantShutdown: []string{"C", "A"},
|
|||
|
},
|
|||
|
}
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
// Configure all extensions to append their names
|
|||
|
// to the gotInit and gotShutdown slices
|
|||
|
// during initialization and shutdown,
|
|||
|
// so we can check that they are called in the right order
|
|||
|
// and that shutdown is not unless init succeeded.
|
|||
|
var gotInit, gotShutdown []string
|
|||
|
for _, ext := range tt.exts {
|
|||
|
oldInitHook := ext.InitHook
|
|||
|
ext.InitHook = func(e *testExtension) error {
|
|||
|
gotInit = append(gotInit, e.name)
|
|||
|
if oldInitHook == nil {
|
|||
|
return nil
|
|||
|
}
|
|||
|
return oldInitHook(e)
|
|||
|
}
|
|||
|
ext.ShutdownHook = func(e *testExtension) error {
|
|||
|
gotShutdown = append(gotShutdown, e.name)
|
|||
|
return nil
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
var h *ExtensionHost
|
|||
|
if !tt.nilHost {
|
|||
|
h = newExtensionHostForTest(t, &testBackend{}, false, tt.exts...)
|
|||
|
}
|
|||
|
|
|||
|
if !tt.skipInit {
|
|||
|
h.Init()
|
|||
|
}
|
|||
|
|
|||
|
// Check that the extensions were initialized in the right order.
|
|||
|
if !slices.Equal(gotInit, tt.wantInit) {
|
|||
|
t.Errorf("Init extensions: got %v; want %v", gotInit, tt.wantInit)
|
|||
|
}
|
|||
|
|
|||
|
// Calling Init again on the host should be a no-op.
|
|||
|
// The [testExtension.Init] method fails the test if called more than once,
|
|||
|
// regardless of which test is running, so we don't need to check it here.
|
|||
|
// Similarly, calling Shutdown again on the host should be a no-op as well.
|
|||
|
// It is verified by the [testExtension.Shutdown] method itself.
|
|||
|
if !tt.skipInit {
|
|||
|
h.Init()
|
|||
|
}
|
|||
|
|
|||
|
// Extensions should not be shut down before the host is shut down,
|
|||
|
// even if they are not initialized successfully.
|
|||
|
for _, ext := range tt.exts {
|
|||
|
if gotShutdown := ext.ShutdownCalled(); gotShutdown {
|
|||
|
t.Errorf("%q: Extension shutdown called before host shutdown", ext.name)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
h.Shutdown()
|
|||
|
// Check that the extensions were shut down in the right order,
|
|||
|
// and that they were not shut down if they were not initialized successfully.
|
|||
|
if !slices.Equal(gotShutdown, tt.wantShutdown) {
|
|||
|
t.Errorf("Shutdown extensions: got %v; want %v", gotShutdown, tt.wantShutdown)
|
|||
|
}
|
|||
|
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// TestNewExtensionHost tests that [NewExtensionHost] correctly creates
|
|||
|
// an [ExtensionHost], instantiates the extensions and handles errors
|
|||
|
// if an extension cannot be created.
|
|||
|
func TestNewExtensionHost(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
defs []*ipnext.Definition
|
|||
|
wantErr bool
|
|||
|
wantExts []string
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "no-exts",
|
|||
|
defs: []*ipnext.Definition{},
|
|||
|
wantErr: false,
|
|||
|
wantExts: []string{},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "exts-ok",
|
|||
|
defs: []*ipnext.Definition{
|
|||
|
ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
|||
|
ipnext.DefinitionForTest(&testExtension{name: "B"}),
|
|||
|
ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
|||
|
},
|
|||
|
wantErr: false,
|
|||
|
wantExts: []string{"A", "B", "C"},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "exts-skipped",
|
|||
|
defs: []*ipnext.Definition{
|
|||
|
ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
|||
|
ipnext.DefinitionWithErrForTest("B", ipnext.SkipExtension),
|
|||
|
ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
|||
|
},
|
|||
|
wantErr: false, // extension B is skipped, that's ok
|
|||
|
wantExts: []string{"A", "C"},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "exts-fail",
|
|||
|
defs: []*ipnext.Definition{
|
|||
|
ipnext.DefinitionForTest(&testExtension{name: "A"}),
|
|||
|
ipnext.DefinitionWithErrForTest("B", errors.New("failed creating Ext-2")),
|
|||
|
ipnext.DefinitionForTest(&testExtension{name: "C"}),
|
|||
|
},
|
|||
|
wantErr: true, // extension B failed to create, that's not ok
|
|||
|
wantExts: []string{},
|
|||
|
},
|
|||
|
}
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
logf := tstest.WhileTestRunningLogger(t)
|
|||
|
h, err := NewExtensionHost(logf, &tsd.System{}, &testBackend{}, tt.defs...)
|
|||
|
if gotErr := err != nil; gotErr != tt.wantErr {
|
|||
|
t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr)
|
|||
|
}
|
|||
|
if err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
var gotExts []string
|
|||
|
for _, ext := range h.allExtensions {
|
|||
|
gotExts = append(gotExts, ext.Name())
|
|||
|
}
|
|||
|
|
|||
|
if !slices.Equal(gotExts, tt.wantExts) {
|
|||
|
t.Errorf("Shutdown extensions: got %v; want %v", gotExts, tt.wantExts)
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues
|
|||
|
// backend operations and executes them asynchronously in the order they were received.
|
|||
|
// It also checks that operations requested before the host and all extensions are initialized
|
|||
|
// are not executed immediately but rather after the host and extensions are initialized.
|
|||
|
func TestExtensionHostEnqueueBackendOperation(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
preInitCalls []string // before host init
|
|||
|
extInitCalls []string // from [Extension.Init]; "" means no call
|
|||
|
wantInitCalls []string // what we expect to be called after host init
|
|||
|
postInitCalls []string // after host init
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "no-calls",
|
|||
|
preInitCalls: []string{},
|
|||
|
extInitCalls: []string{},
|
|||
|
wantInitCalls: []string{},
|
|||
|
postInitCalls: []string{},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "pre-init-calls",
|
|||
|
preInitCalls: []string{"pre-init-1", "pre-init-2"},
|
|||
|
extInitCalls: []string{},
|
|||
|
wantInitCalls: []string{"pre-init-1", "pre-init-2"},
|
|||
|
postInitCalls: []string{},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "init-calls",
|
|||
|
preInitCalls: []string{},
|
|||
|
extInitCalls: []string{"init-1", "init-2"},
|
|||
|
wantInitCalls: []string{"init-1", "init-2"},
|
|||
|
postInitCalls: []string{},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "post-init-calls",
|
|||
|
preInitCalls: []string{},
|
|||
|
extInitCalls: []string{},
|
|||
|
wantInitCalls: []string{},
|
|||
|
postInitCalls: []string{"post-init-1", "post-init-2"},
|
|||
|
},
|
|||
|
{
|
|||
|
name: "mixed-calls",
|
|||
|
preInitCalls: []string{"pre-init-1", "pre-init-2"},
|
|||
|
extInitCalls: []string{"init-1", "", "init-2"},
|
|||
|
wantInitCalls: []string{"pre-init-1", "pre-init-2", "init-1", "init-2"},
|
|||
|
postInitCalls: []string{"post-init-1", "post-init-2"},
|
|||
|
},
|
|||
|
}
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
var gotCalls []string
|
|||
|
var h *ExtensionHost
|
|||
|
b := &testBackend{
|
|||
|
switchToBestProfileHook: func(reason string) {
|
|||
|
gotCalls = append(gotCalls, reason)
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
exts := make([]*testExtension, len(tt.extInitCalls))
|
|||
|
for i, reason := range tt.extInitCalls {
|
|||
|
exts[i] = &testExtension{}
|
|||
|
if reason != "" {
|
|||
|
exts[i].InitHook = func(e *testExtension) error {
|
|||
|
e.host.Profiles().SwitchToBestProfileAsync(reason)
|
|||
|
return nil
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
h = newExtensionHostForTest(t, b, false, exts...)
|
|||
|
wq := h.SetWorkQueueForTest(t) // use a test queue instead of [execqueue.ExecQueue].
|
|||
|
|
|||
|
// Issue some pre-init calls. They should be deferred and not
|
|||
|
// added to the queue until the host is initialized.
|
|||
|
for _, call := range tt.preInitCalls {
|
|||
|
h.Profiles().SwitchToBestProfileAsync(call)
|
|||
|
}
|
|||
|
|
|||
|
// The queue should be empty before the host is initialized.
|
|||
|
wq.Drain()
|
|||
|
if len(gotCalls) != 0 {
|
|||
|
t.Errorf("Pre-init calls: got %v; want (none)", gotCalls)
|
|||
|
}
|
|||
|
gotCalls = nil
|
|||
|
|
|||
|
// Initialize the host and all extensions.
|
|||
|
// The extensions will make their calls during initialization.
|
|||
|
h.Init()
|
|||
|
|
|||
|
// Calls made before or during initialization should now be enqueued and running.
|
|||
|
wq.Drain()
|
|||
|
if diff := deepcmp.Diff(tt.wantInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" {
|
|||
|
t.Errorf("Init calls: (+got -want): %v", diff)
|
|||
|
}
|
|||
|
gotCalls = nil
|
|||
|
|
|||
|
// Let's make some more calls, as if extensions were making them in a response
|
|||
|
// to external events.
|
|||
|
for _, call := range tt.postInitCalls {
|
|||
|
h.Profiles().SwitchToBestProfileAsync(call)
|
|||
|
}
|
|||
|
|
|||
|
// Any calls made after initialization should be enqueued and running.
|
|||
|
wq.Drain()
|
|||
|
if diff := deepcmp.Diff(tt.postInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" {
|
|||
|
t.Errorf("Init calls: (+got -want): %v", diff)
|
|||
|
}
|
|||
|
gotCalls = nil
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// TestExtensionHostProfileChangeCallback verifies that [ExtensionHost] correctly handles the registration,
|
|||
|
// invocation, and unregistration of profile change callbacks. It also checks that the callbacks are called
|
|||
|
// with the correct arguments and that any private keys are stripped from [ipn.Prefs] before being passed to the callback.
|
|||
|
func TestExtensionHostProfileChangeCallback(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
type profileChange struct {
|
|||
|
Profile *ipn.LoginProfile
|
|||
|
Prefs *ipn.Prefs
|
|||
|
SameNode bool
|
|||
|
}
|
|||
|
// newProfileChange creates a new profile change with deep copies of the profile and prefs.
|
|||
|
newProfileChange := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) profileChange {
|
|||
|
return profileChange{
|
|||
|
Profile: profile.AsStruct(),
|
|||
|
Prefs: prefs.AsStruct(),
|
|||
|
SameNode: sameNode,
|
|||
|
}
|
|||
|
}
|
|||
|
// makeProfileChangeAppender returns a callback that appends profile changes to the extension's state.
|
|||
|
makeProfileChangeAppender := func(e *testExtension) ipnext.ProfileChangeCallback {
|
|||
|
return func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
|
|||
|
UpdateExtState(e, "changes", func(changes []profileChange) []profileChange {
|
|||
|
return append(changes, newProfileChange(profile, prefs, sameNode))
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
// getProfileChanges returns the profile changes stored in the extension's state.
|
|||
|
getProfileChanges := func(e *testExtension) []profileChange {
|
|||
|
changes, _ := GetExtStateOk[[]profileChange](e, "changes")
|
|||
|
return changes
|
|||
|
}
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
ext *testExtension
|
|||
|
calls []profileChange
|
|||
|
wantCalls []profileChange
|
|||
|
}{
|
|||
|
{
|
|||
|
// Register the callback for the lifetime of the extension.
|
|||
|
name: "Register/Lifetime",
|
|||
|
ext: &testExtension{},
|
|||
|
calls: []profileChange{
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
|
|||
|
},
|
|||
|
wantCalls: []profileChange{ // all calls are received by the callback
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true},
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
// Override the default InitHook used in the test to unregister the callback
|
|||
|
// after the first call.
|
|||
|
name: "Register/Once",
|
|||
|
ext: &testExtension{
|
|||
|
InitHook: func(e *testExtension) error {
|
|||
|
var unregister func()
|
|||
|
handler := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) {
|
|||
|
makeProfileChangeAppender(e)(profile, prefs, sameNode)
|
|||
|
unregister()
|
|||
|
}
|
|||
|
unregister = e.host.Profiles().RegisterProfileChangeCallback(handler)
|
|||
|
return nil
|
|||
|
},
|
|||
|
},
|
|||
|
calls: []profileChange{
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-2"}},
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-3"}},
|
|||
|
},
|
|||
|
wantCalls: []profileChange{ // only the first call is received by the callback
|
|||
|
{Profile: &ipn.LoginProfile{ID: "profile-1"}},
|
|||
|
},
|
|||
|
},
|
|||
|
{
|
|||
|
// Ensure that ipn.Prefs are passed to the callback.
|
|||
|
name: "CheckPrefs",
|
|||
|
ext: &testExtension{},
|
|||
|
calls: []profileChange{{
|
|||
|
Profile: &ipn.LoginProfile{ID: "profile-1"},
|
|||
|
Prefs: &ipn.Prefs{
|
|||
|
WantRunning: true,
|
|||
|
LoggedOut: false,
|
|||
|
AdvertiseRoutes: []netip.Prefix{
|
|||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
|||
|
netip.MustParsePrefix("192.168.2.0/24"),
|
|||
|
},
|
|||
|
},
|
|||
|
}},
|
|||
|
wantCalls: []profileChange{{
|
|||
|
Profile: &ipn.LoginProfile{ID: "profile-1"},
|
|||
|
Prefs: &ipn.Prefs{
|
|||
|
WantRunning: true,
|
|||
|
LoggedOut: false,
|
|||
|
AdvertiseRoutes: []netip.Prefix{
|
|||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
|||
|
netip.MustParsePrefix("192.168.2.0/24"),
|
|||
|
},
|
|||
|
},
|
|||
|
}},
|
|||
|
},
|
|||
|
{
|
|||
|
// Ensure that private keys are stripped from persist.Persist shared with extensions.
|
|||
|
name: "StripPrivateKeys",
|
|||
|
ext: &testExtension{},
|
|||
|
calls: []profileChange{{
|
|||
|
Profile: &ipn.LoginProfile{ID: "profile-1"},
|
|||
|
Prefs: &ipn.Prefs{
|
|||
|
Persist: &persist.Persist{
|
|||
|
NodeID: "12345",
|
|||
|
PrivateNodeKey: key.NewNode(),
|
|||
|
OldPrivateNodeKey: key.NewNode(),
|
|||
|
NetworkLockKey: key.NewNLPrivate(),
|
|||
|
UserProfile: tailcfg.UserProfile{
|
|||
|
ID: 12345,
|
|||
|
LoginName: "test@example.com",
|
|||
|
DisplayName: "Test User",
|
|||
|
ProfilePicURL: "https://example.com/profile.png",
|
|||
|
},
|
|||
|
},
|
|||
|
},
|
|||
|
}},
|
|||
|
wantCalls: []profileChange{{
|
|||
|
Profile: &ipn.LoginProfile{ID: "profile-1"},
|
|||
|
Prefs: &ipn.Prefs{
|
|||
|
Persist: &persist.Persist{
|
|||
|
NodeID: "12345",
|
|||
|
PrivateNodeKey: key.NodePrivate{}, // stripped
|
|||
|
OldPrivateNodeKey: key.NodePrivate{}, // stripped
|
|||
|
NetworkLockKey: key.NLPrivate{}, // stripped
|
|||
|
UserProfile: tailcfg.UserProfile{
|
|||
|
ID: 12345,
|
|||
|
LoginName: "test@example.com",
|
|||
|
DisplayName: "Test User",
|
|||
|
ProfilePicURL: "https://example.com/profile.png",
|
|||
|
},
|
|||
|
},
|
|||
|
},
|
|||
|
}},
|
|||
|
},
|
|||
|
}
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
// Use the default InitHook if not provided by the test.
|
|||
|
if tt.ext.InitHook == nil {
|
|||
|
tt.ext.InitHook = func(e *testExtension) error {
|
|||
|
// Create and register the callback on init.
|
|||
|
handler := makeProfileChangeAppender(e)
|
|||
|
e.Cleanup(e.host.Profiles().RegisterProfileChangeCallback(handler))
|
|||
|
return nil
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
h := newExtensionHostForTest(t, &testBackend{}, true, tt.ext)
|
|||
|
for _, call := range tt.calls {
|
|||
|
h.NotifyProfileChange(call.Profile.View(), call.Prefs.View(), call.SameNode)
|
|||
|
}
|
|||
|
opts := []deepcmp.Option{
|
|||
|
cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}),
|
|||
|
}
|
|||
|
if diff := deepcmp.Diff(tt.wantCalls, getProfileChanges(tt.ext), opts...); diff != "" {
|
|||
|
t.Errorf("ProfileChange callbacks: (-want +got): %v", diff)
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// TestBackgroundProfileResolver tests that the background profile resolvers
|
|||
|
// are correctly registered, unregistered and invoked by the [ExtensionHost].
|
|||
|
func TestBackgroundProfileResolver(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
profiles []ipn.LoginProfile // the first one is the current profile
|
|||
|
resolvers []ipnext.ProfileResolver
|
|||
|
wantProfile *ipn.LoginProfile
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "No-Profiles/No-Resolvers",
|
|||
|
profiles: nil,
|
|||
|
resolvers: nil,
|
|||
|
wantProfile: nil,
|
|||
|
},
|
|||
|
{
|
|||
|
// TODO(nickkhyl): update this test as we change "background profile resolvers"
|
|||
|
// to just "profile resolvers". The wantProfile should be the current profile by default.
|
|||
|
name: "Has-Profiles/No-Resolvers",
|
|||
|
profiles: []ipn.LoginProfile{{ID: "profile-1"}},
|
|||
|
resolvers: nil,
|
|||
|
wantProfile: nil,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "Has-Profiles/Single-Resolver",
|
|||
|
profiles: []ipn.LoginProfile{{ID: "profile-1"}},
|
|||
|
resolvers: []ipnext.ProfileResolver{
|
|||
|
func(ps ipnext.ProfileStore) ipn.LoginProfileView {
|
|||
|
return ps.CurrentProfile()
|
|||
|
},
|
|||
|
},
|
|||
|
wantProfile: &ipn.LoginProfile{ID: "profile-1"},
|
|||
|
},
|
|||
|
// TODO(nickkhyl): add more tests for multiple resolvers and different profiles
|
|||
|
// once we change "background profile resolvers" to just "profile resolvers"
|
|||
|
// and add proper conflict resolution logic.
|
|||
|
}
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
// Create a new profile manager and add the profiles to it.
|
|||
|
// We expose the profile manager to the extensions via the read-only [ipnext.ProfileStore] interface.
|
|||
|
pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker)))
|
|||
|
for i, p := range tt.profiles {
|
|||
|
// Generate a unique ID and key for each profile,
|
|||
|
// unless the profile already has them set
|
|||
|
// or is an empty, unnamed profile.
|
|||
|
if p.Name != "" {
|
|||
|
if p.ID == "" {
|
|||
|
p.ID = ipn.ProfileID("profile-" + strconv.Itoa(i))
|
|||
|
}
|
|||
|
if p.Key == "" {
|
|||
|
p.Key = "key-" + ipn.StateKey(p.ID)
|
|||
|
}
|
|||
|
}
|
|||
|
pv := p.View()
|
|||
|
pm.knownProfiles[p.ID] = pv
|
|||
|
if i == 0 {
|
|||
|
// Set the first profile as the current one.
|
|||
|
// A profileManager starts with an empty profile,
|
|||
|
// so it's okay if the list of profiles is empty.
|
|||
|
pm.SwitchToProfile(pv)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true)
|
|||
|
|
|||
|
// Register the resolvers with the host.
|
|||
|
// This is typically done by the extensions themselves,
|
|||
|
// but we do it here for testing purposes.
|
|||
|
for _, r := range tt.resolvers {
|
|||
|
t.Cleanup(h.Profiles().RegisterBackgroundProfileResolver(r))
|
|||
|
}
|
|||
|
|
|||
|
// Call the resolver to get the profile.
|
|||
|
gotProfile := h.DetermineBackgroundProfile(pm)
|
|||
|
if !gotProfile.Equals(tt.wantProfile.View()) {
|
|||
|
t.Errorf("Resolved profile: got %v; want %v", gotProfile, tt.wantProfile)
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// TestAuditLogProviders tests that the [ExtensionHost] correctly handles
|
|||
|
// the registration and invocation of audit log providers. It verifies that
|
|||
|
// the audit loggers are called with the correct actions and details,
|
|||
|
// and that any errors returned by the providers are properly propagated.
|
|||
|
func TestAuditLogProviders(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
tests := []struct {
|
|||
|
name string
|
|||
|
auditLoggers []ipnauth.AuditLogFunc // each represents an extension
|
|||
|
actions []tailcfg.ClientAuditAction
|
|||
|
wantErr bool
|
|||
|
}{
|
|||
|
{
|
|||
|
name: "No-Providers",
|
|||
|
auditLoggers: nil,
|
|||
|
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|||
|
wantErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "Single-Provider/Ok",
|
|||
|
auditLoggers: []ipnauth.AuditLogFunc{
|
|||
|
func(tailcfg.ClientAuditAction, string) error { return nil },
|
|||
|
},
|
|||
|
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|||
|
wantErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "Single-Provider/Err",
|
|||
|
auditLoggers: []ipnauth.AuditLogFunc{
|
|||
|
func(tailcfg.ClientAuditAction, string) error {
|
|||
|
return errors.New("failed to log")
|
|||
|
},
|
|||
|
},
|
|||
|
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|||
|
wantErr: true,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "Many-Providers/Ok",
|
|||
|
auditLoggers: []ipnauth.AuditLogFunc{
|
|||
|
func(tailcfg.ClientAuditAction, string) error { return nil },
|
|||
|
func(tailcfg.ClientAuditAction, string) error { return nil },
|
|||
|
},
|
|||
|
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|||
|
wantErr: false,
|
|||
|
},
|
|||
|
{
|
|||
|
name: "Many-Providers/Err",
|
|||
|
auditLoggers: []ipnauth.AuditLogFunc{
|
|||
|
func(tailcfg.ClientAuditAction, string) error {
|
|||
|
return errors.New("failed to log")
|
|||
|
},
|
|||
|
func(tailcfg.ClientAuditAction, string) error {
|
|||
|
return nil // all good
|
|||
|
},
|
|||
|
func(tailcfg.ClientAuditAction, string) error {
|
|||
|
return errors.New("also failed to log")
|
|||
|
},
|
|||
|
},
|
|||
|
actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"},
|
|||
|
wantErr: true, // some providers failed to log, so that's an error
|
|||
|
},
|
|||
|
}
|
|||
|
for _, tt := range tests {
|
|||
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
// Create extensions that register the audit log providers.
|
|||
|
// Each extension/provider will append auditable actions to its state,
|
|||
|
// then call the test's auditLogger function.
|
|||
|
var exts []*testExtension
|
|||
|
for _, auditLogger := range tt.auditLoggers {
|
|||
|
ext := &testExtension{}
|
|||
|
provider := func() ipnauth.AuditLogFunc {
|
|||
|
return func(action tailcfg.ClientAuditAction, details string) error {
|
|||
|
UpdateExtState(ext, "actions", func(actions []tailcfg.ClientAuditAction) []tailcfg.ClientAuditAction {
|
|||
|
return append(actions, action)
|
|||
|
})
|
|||
|
return auditLogger(action, details)
|
|||
|
}
|
|||
|
}
|
|||
|
ext.InitHook = func(e *testExtension) error {
|
|||
|
e.Cleanup(e.host.RegisterAuditLogProvider(provider))
|
|||
|
return nil
|
|||
|
}
|
|||
|
exts = append(exts, ext)
|
|||
|
}
|
|||
|
|
|||
|
// Initialize the host and the extensions.
|
|||
|
h := newExtensionHostForTest(t, &testBackend{}, true, exts...)
|
|||
|
|
|||
|
// Use [ExtensionHost.AuditLogger] to log actions.
|
|||
|
for _, action := range tt.actions {
|
|||
|
err := h.AuditLogger()(action, "Test details")
|
|||
|
if gotErr := err != nil; gotErr != tt.wantErr {
|
|||
|
t.Errorf("AuditLogger: gotErr %v (%v); wantErr %v", gotErr, err, tt.wantErr)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// Check that the actions were logged correctly by each provider.
|
|||
|
for _, ext := range exts {
|
|||
|
gotActions := GetExtState[[]tailcfg.ClientAuditAction](ext, "actions")
|
|||
|
if !slices.Equal(gotActions, tt.actions) {
|
|||
|
t.Errorf("Actions: got %v; want %v", gotActions, tt.actions)
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// TestNilExtensionHostMethodCall tests that calling exported methods
|
|||
|
// on a nil [ExtensionHost] does not panic. We should treat it as a valid
|
|||
|
// value since it's used in various tests that instantiate [LocalBackend]
|
|||
|
// manually without calling [NewLocalBackend]. It also verifies that if
|
|||
|
// a method returns a single func value (e.g., a cleanup function),
|
|||
|
// it should not be nil. This is a basic sanity check to ensure that
|
|||
|
// typical method calls on a nil receiver work as expected.
|
|||
|
// It does not replace the need for more thorough testing of specific methods.
|
|||
|
func TestNilExtensionHostMethodCall(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
|
|||
|
var h *ExtensionHost
|
|||
|
typ := reflect.TypeOf(h)
|
|||
|
for i := range typ.NumMethod() {
|
|||
|
m := typ.Method(i)
|
|||
|
if strings.HasSuffix(m.Name, "ForTest") {
|
|||
|
// Skip methods that are only for testing.
|
|||
|
continue
|
|||
|
}
|
|||
|
|
|||
|
t.Run(m.Name, func(t *testing.T) {
|
|||
|
t.Parallel()
|
|||
|
// Calling the method on the nil receiver should not panic.
|
|||
|
ret := checkMethodCallWithZeroArgs(t, m, h)
|
|||
|
if len(ret) == 1 && ret[0].Kind() == reflect.Func {
|
|||
|
// If the method returns a single func, such as a cleanup function,
|
|||
|
// it should not be nil.
|
|||
|
fn := ret[0]
|
|||
|
if fn.IsNil() {
|
|||
|
t.Fatalf("(%T).%s returned a nil func", h, m.Name)
|
|||
|
}
|
|||
|
// We expect it to be a no-op and calling it should not panic.
|
|||
|
args := makeZeroArgsFor(fn)
|
|||
|
func() {
|
|||
|
defer func() {
|
|||
|
if e := recover(); e != nil {
|
|||
|
t.Fatalf("panic calling the func returned by (%T).%s: %v", e, m.Name, e)
|
|||
|
}
|
|||
|
}()
|
|||
|
fn.Call(args)
|
|||
|
}()
|
|||
|
}
|
|||
|
})
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// checkMethodCallWithZeroArgs calls the method m on the receiver r
|
|||
|
// with zero values for all its arguments, except the receiver itself.
|
|||
|
// It returns the result of the method call, or fails the test if the call panics.
|
|||
|
func checkMethodCallWithZeroArgs[T any](t *testing.T, m reflect.Method, r T) []reflect.Value {
|
|||
|
t.Helper()
|
|||
|
args := makeZeroArgsFor(m.Func)
|
|||
|
// The first arg is the receiver.
|
|||
|
args[0] = reflect.ValueOf(r)
|
|||
|
// Calling the method should not panic.
|
|||
|
defer func() {
|
|||
|
if e := recover(); e != nil {
|
|||
|
t.Fatalf("panic calling (%T).%s: %v", r, m.Name, e)
|
|||
|
}
|
|||
|
}()
|
|||
|
return m.Func.Call(args)
|
|||
|
}
|
|||
|
|
|||
|
// makeZeroArgsFor returns one zero [reflect.Value] for each input parameter
// of the function fn, in parameter order.
func makeZeroArgsFor(fn reflect.Value) []reflect.Value {
	typ := fn.Type()
	n := typ.NumIn()
	zeros := make([]reflect.Value, 0, n)
	for i := 0; i < n; i++ {
		zeros = append(zeros, reflect.Zero(typ.In(i)))
	}
	return zeros
}
|
|||
|
|
|||
|
// newExtensionHostForTest creates an [ExtensionHost] with the given backend and extensions.
|
|||
|
// It associates each extension that either is or embeds a [testExtension] with the test
|
|||
|
// and assigns a name if one isn’t already set.
|
|||
|
//
|
|||
|
// If the host cannot be created, it fails the test.
|
|||
|
//
|
|||
|
// The host is initialized if the initialize parameter is true.
|
|||
|
// It is shut down automatically when the test ends.
|
|||
|
func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initialize bool, exts ...T) *ExtensionHost {
|
|||
|
t.Helper()
|
|||
|
|
|||
|
// testExtensionIface is a subset of the methods implemented by [testExtension] that are used here.
|
|||
|
// We use testExtensionIface in type assertions instead of using the [testExtension] type directly,
|
|||
|
// which supports scenarios where an extension type embeds a [testExtension].
|
|||
|
type testExtensionIface interface {
|
|||
|
Name() string
|
|||
|
setName(string)
|
|||
|
setT(*testing.T)
|
|||
|
checkShutdown()
|
|||
|
}
|
|||
|
|
|||
|
logf := tstest.WhileTestRunningLogger(t)
|
|||
|
defs := make([]*ipnext.Definition, len(exts))
|
|||
|
for i, ext := range exts {
|
|||
|
if ext, ok := any(ext).(testExtensionIface); ok {
|
|||
|
ext.setName(cmp.Or(ext.Name(), "Ext-"+strconv.Itoa(i)))
|
|||
|
ext.setT(t)
|
|||
|
}
|
|||
|
defs[i] = ipnext.DefinitionForTest(ext)
|
|||
|
}
|
|||
|
h, err := NewExtensionHost(logf, &tsd.System{}, b, defs...)
|
|||
|
if err != nil {
|
|||
|
t.Fatalf("NewExtensionHost: %v", err)
|
|||
|
}
|
|||
|
// Replace doEnqueueBackendOperation with the one that's marked as a helper,
|
|||
|
// so that we'll have better output if [testExecQueue.Add] fails a test.
|
|||
|
h.doEnqueueBackendOperation = func(f func(Backend)) {
|
|||
|
t.Helper()
|
|||
|
h.workQueue.Add(func() { f(b) })
|
|||
|
}
|
|||
|
for _, ext := range exts {
|
|||
|
if ext, ok := any(ext).(testExtensionIface); ok {
|
|||
|
t.Cleanup(ext.checkShutdown)
|
|||
|
}
|
|||
|
}
|
|||
|
t.Cleanup(h.Shutdown)
|
|||
|
if initialize {
|
|||
|
h.Init()
|
|||
|
}
|
|||
|
return h
|
|||
|
}
|
|||
|
|
|||
|
// testExtension is an [ipnext.Extension] that:
|
|||
|
// - Calls the provided init and shutdown callbacks
|
|||
|
// when [Init] and [Shutdown] are called.
|
|||
|
// - Ensures that [Init] and [Shutdown] are called at most once,
|
|||
|
// that [Shutdown] is called after [Init], but is not called if [Init] fails
|
|||
|
// and is called before the test ends if [Init] succeeds.
|
|||
|
//
|
|||
|
// Typically, [testExtension]s are created and passed to [newExtensionHostForTest]
|
|||
|
// when creating an [ExtensionHost] for testing.
|
|||
|
type testExtension struct {
|
|||
|
t *testing.T // test that created the extension
|
|||
|
name string // name of the extension, used for logging
|
|||
|
|
|||
|
host ipnext.Host // or nil if not initialized
|
|||
|
|
|||
|
// InitHook and ShutdownHook are optional hooks that can be set by tests.
|
|||
|
InitHook, ShutdownHook func(*testExtension) error
|
|||
|
|
|||
|
// initCnt, initOkCnt and shutdownCnt are used to verify that Init and Shutdown
|
|||
|
// are called at most once and in the correct order.
|
|||
|
initCnt, initOkCnt, shutdownCnt atomic.Int32
|
|||
|
|
|||
|
// mu protects the following fields.
|
|||
|
mu sync.Mutex
|
|||
|
// state is the optional state used by tests.
|
|||
|
// It can be accessed by tests using [setTestExtensionState],
|
|||
|
// [getTestExtensionStateOk] and [getTestExtensionState].
|
|||
|
state map[string]any
|
|||
|
// cleanup are functions to be called on shutdown.
|
|||
|
cleanup []func()
|
|||
|
}
|
|||
|
|
|||
|
// Compile-time check that *testExtension satisfies [ipnext.Extension].
var _ ipnext.Extension = (*testExtension)(nil)
|
|||
|
|
|||
|
func (e *testExtension) setT(t *testing.T) {
|
|||
|
e.t = t
|
|||
|
}
|
|||
|
|
|||
|
func (e *testExtension) setName(name string) {
|
|||
|
e.name = name
|
|||
|
}
|
|||
|
|
|||
|
// Name implements [ipnext.Extension].
|
|||
|
func (e *testExtension) Name() string {
|
|||
|
return e.name
|
|||
|
}
|
|||
|
|
|||
|
// Init implements [ipnext.Extension].
|
|||
|
func (e *testExtension) Init(host ipnext.Host) (err error) {
|
|||
|
e.t.Helper()
|
|||
|
e.host = host
|
|||
|
if e.initCnt.Add(1) == 1 {
|
|||
|
e.mu.Lock()
|
|||
|
e.state = make(map[string]any)
|
|||
|
e.mu.Unlock()
|
|||
|
} else {
|
|||
|
e.t.Errorf("%q: Init called more than once", e.name)
|
|||
|
}
|
|||
|
if e.InitHook != nil {
|
|||
|
err = e.InitHook(e)
|
|||
|
}
|
|||
|
if err == nil {
|
|||
|
e.initOkCnt.Add(1)
|
|||
|
}
|
|||
|
return err // may be nil or non-nil
|
|||
|
}
|
|||
|
|
|||
|
// InitCalled reports whether the Init method was called on the receiver.
|
|||
|
func (e *testExtension) InitCalled() bool {
|
|||
|
return e.initCnt.Load() != 0
|
|||
|
}
|
|||
|
|
|||
|
func (e *testExtension) Cleanup(f func()) {
|
|||
|
e.mu.Lock()
|
|||
|
e.cleanup = append(e.cleanup, f)
|
|||
|
e.mu.Unlock()
|
|||
|
}
|
|||
|
|
|||
|
// Shutdown implements [ipnext.Extension].
|
|||
|
func (e *testExtension) Shutdown() (err error) {
|
|||
|
e.t.Helper()
|
|||
|
e.mu.Lock()
|
|||
|
cleanup := e.cleanup
|
|||
|
e.cleanup = nil
|
|||
|
e.mu.Unlock()
|
|||
|
for _, f := range cleanup {
|
|||
|
f()
|
|||
|
}
|
|||
|
if e.ShutdownHook != nil {
|
|||
|
err = e.ShutdownHook(e)
|
|||
|
}
|
|||
|
if e.shutdownCnt.Add(1) != 1 {
|
|||
|
e.t.Errorf("%q: Shutdown called more than once", e.name)
|
|||
|
}
|
|||
|
if e.initCnt.Load() == 0 {
|
|||
|
e.t.Errorf("%q: Shutdown called without Init", e.name)
|
|||
|
} else if e.initOkCnt.Load() == 0 {
|
|||
|
e.t.Errorf("%q: Shutdown called despite failed Init", e.name)
|
|||
|
}
|
|||
|
e.host = nil
|
|||
|
return err // may be nil or non-nil
|
|||
|
}
|
|||
|
|
|||
|
func (e *testExtension) checkShutdown() {
|
|||
|
e.t.Helper()
|
|||
|
if e.initOkCnt.Load() != 0 && e.shutdownCnt.Load() == 0 {
|
|||
|
e.t.Errorf("%q: Shutdown has not been called before test end", e.name)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// ShutdownCalled reports whether the Shutdown method was called on the receiver.
|
|||
|
func (e *testExtension) ShutdownCalled() bool {
|
|||
|
return e.shutdownCnt.Load() != 0
|
|||
|
}
|
|||
|
|
|||
|
// SetExtState sets a keyed state on [testExtension] to the given value.
|
|||
|
// Tests use it to propagate test-specific state throughout the extension lifecycle
|
|||
|
// (e.g., between [testExtension.Init], [testExtension.Shutdown], and registered callbacks)
|
|||
|
func SetExtState[T any](e *testExtension, key string, value T) {
|
|||
|
e.mu.Lock()
|
|||
|
defer e.mu.Unlock()
|
|||
|
e.state[key] = value
|
|||
|
}
|
|||
|
|
|||
|
// UpdateExtState updates a keyed state of the extension using the provided update function.
|
|||
|
func UpdateExtState[T any](e *testExtension, key string, update func(T) T) {
|
|||
|
e.mu.Lock()
|
|||
|
defer e.mu.Unlock()
|
|||
|
old, _ := e.state[key].(T)
|
|||
|
new := update(old)
|
|||
|
e.state[key] = new
|
|||
|
}
|
|||
|
|
|||
|
// GetExtState returns the value of the keyed state of the extension.
|
|||
|
// It returns a zero value of T if the state is not set or is of a different type.
|
|||
|
func GetExtState[T any](e *testExtension, key string) T {
|
|||
|
v, _ := GetExtStateOk[T](e, key)
|
|||
|
return v
|
|||
|
}
|
|||
|
|
|||
|
// GetExtStateOk is like [getExtState], but also reports whether the state
|
|||
|
// with the given key exists and is of the expected type.
|
|||
|
func GetExtStateOk[T any](e *testExtension, key string) (_ T, ok bool) {
|
|||
|
e.mu.Lock()
|
|||
|
defer e.mu.Unlock()
|
|||
|
v, ok := e.state[key].(T)
|
|||
|
return v, ok
|
|||
|
}
|
|||
|
|
|||
|
// testExecQueue is a test implementation of [execQueue]
|
|||
|
// that defers execution of the enqueued funcs until
|
|||
|
// [testExecQueue.Drain] is called, and fails the test if
|
|||
|
// if [execQueue.Add] is called before the host is initialized.
|
|||
|
//
|
|||
|
// It is typically used by calling [ExtensionHost.SetWorkQueueForTest].
|
|||
|
type testExecQueue struct {
|
|||
|
t *testing.T // test that created the queue
|
|||
|
h *ExtensionHost // host to own the queue
|
|||
|
|
|||
|
mu sync.Mutex
|
|||
|
queue []func()
|
|||
|
}
|
|||
|
|
|||
|
// Compile-time check that *testExecQueue satisfies [execQueue].
var _ execQueue = (*testExecQueue)(nil)
|
|||
|
|
|||
|
// SetWorkQueueForTest is a helper function that creates a new [testExecQueue]
|
|||
|
// and sets it as the work queue for the specified [ExtensionHost],
|
|||
|
// returning the new queue.
|
|||
|
//
|
|||
|
// It fails the test if the host is already initialized.
|
|||
|
func (h *ExtensionHost) SetWorkQueueForTest(t *testing.T) *testExecQueue {
|
|||
|
t.Helper()
|
|||
|
if h.initialized.Load() {
|
|||
|
t.Fatalf("UseTestWorkQueue: host is already initialized")
|
|||
|
return nil
|
|||
|
}
|
|||
|
q := &testExecQueue{t: t, h: h}
|
|||
|
h.workQueue = q
|
|||
|
return q
|
|||
|
}
|
|||
|
|
|||
|
// Add implements [execQueue].
|
|||
|
func (q *testExecQueue) Add(f func()) {
|
|||
|
q.t.Helper()
|
|||
|
|
|||
|
if !q.h.initialized.Load() {
|
|||
|
q.t.Fatal("ExecQueue.Add must not be called until the host is initialized")
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
q.mu.Lock()
|
|||
|
q.queue = append(q.queue, f)
|
|||
|
q.mu.Unlock()
|
|||
|
}
|
|||
|
|
|||
|
// Drain executes all queued functions in the order they were added.
|
|||
|
func (q *testExecQueue) Drain() {
|
|||
|
q.mu.Lock()
|
|||
|
queue := q.queue
|
|||
|
q.queue = nil
|
|||
|
q.mu.Unlock()
|
|||
|
|
|||
|
for _, f := range queue {
|
|||
|
f()
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// Shutdown implements [execQueue].
|
|||
|
func (q *testExecQueue) Shutdown() {}
|
|||
|
|
|||
|
// Wait implements [execQueue].
|
|||
|
func (q *testExecQueue) Wait(context.Context) error { return nil }
|
|||
|
|
|||
|
// testBackend implements [ipnext.Backend] for testing purposes.
// Its methods forward to the hooks configured on it.
type testBackend struct {
	// switchToBestProfileHook, if non-nil, is called by SwitchToBestProfile
	// with the reason it was given.
	switchToBestProfileHook func(reason string)

	// mu protects the backend state.
	// Exported methods acquire it on entry and release it on exit,
	// mimicking the behavior of the [LocalBackend].
	mu sync.Mutex
}
|
|||
|
|
|||
|
func (b *testBackend) SwitchToBestProfile(reason string) {
|
|||
|
b.mu.Lock()
|
|||
|
defer b.mu.Unlock()
|
|||
|
if b.switchToBestProfileHook != nil {
|
|||
|
b.switchToBestProfileHook(reason)
|
|||
|
}
|
|||
|
}
|