ipn/ipn{ext,local}: allow extension lookup by name or type

In this PR, we add two methods to facilitate extension lookup by both extensions,
and non-extensions (e.g., PeerAPI or LocalAPI handlers):
 - FindExtensionByName returns an extension with the specified name.
   It can then be type asserted to a given type.
 - FindMatchingExtension is like errors.As, but for extensions.
   It returns the first extension that matches the target type (either a specific extension
   or an interface).

Updates tailscale/corp#27645
Updates tailscale/corp#27502

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2025-04-11 10:09:03 -05:00 committed by Nick Khyl
parent 1e290867bd
commit f28c8d0ec0
4 changed files with 184 additions and 0 deletions

View File

@ -174,6 +174,9 @@ func DefinitionWithErrForTest(name string, err error) *Definition {
//
// A host must be safe for concurrent use.
type Host interface {
// Extensions returns the host's [ExtensionServices].
Extensions() ExtensionServices
// Profiles returns the host's [ProfileServices].
Profiles() ProfileServices
@ -197,6 +200,22 @@ type Host interface {
RegisterControlClientCallback(NewControlClientCallback) (unregister func())
}
// ExtensionServices provides access to the [Host]'s extension management services,
// such as fetching active extensions.
type ExtensionServices interface {
// FindExtensionByName returns an active extension with the given name,
// or nil if no such extension exists.
FindExtensionByName(name string) any
// FindMatchingExtension finds the first active extension that matches target,
// and if one is found, sets target to that extension and returns true.
// Otherwise, it returns false.
//
// It panics if target is not a non-nil pointer to either a type
// that implements [ipnext.Extension], or to any interface type.
FindMatchingExtension(target any) bool
}
// ProfileServices provides access to the [Host]'s profile management services,
// such as switching profiles and registering profile change callbacks.
type ProfileServices interface {

View File

@ -9,6 +9,7 @@ import (
"fmt"
"iter"
"maps"
"reflect"
"slices"
"strings"
"sync"
@ -233,6 +234,60 @@ func (h *ExtensionHost) init() {
}
// Extensions implements [ipnext.Host].
func (h *ExtensionHost) Extensions() ipnext.ExtensionServices {
// Currently, [ExtensionHost] implements [ExtensionServices] directly.
// We might want to extract it to a separate type in the future.
return h
}
// FindExtensionByName implements [ipnext.ExtensionServices]
// and is also used by the [LocalBackend].
// It returns nil if the extension is not found.
func (h *ExtensionHost) FindExtensionByName(name string) any {
if h == nil {
return nil
}
h.mu.Lock()
defer h.mu.Unlock()
return h.extensionsByName[name]
}
// extensionIfaceType is the runtime type of the [ipnext.Extension] interface.
var extensionIfaceType = reflect.TypeFor[ipnext.Extension]()
// FindMatchingExtension implements [ipnext.ExtensionServices]
// and is also used by the [LocalBackend].
func (h *ExtensionHost) FindMatchingExtension(target any) bool {
if h == nil {
return false
}
if target == nil {
panic("ipnext: target cannot be nil")
}
val := reflect.ValueOf(target)
typ := val.Type()
if typ.Kind() != reflect.Ptr || val.IsNil() {
panic("ipnext: target must be a non-nil pointer")
}
targetType := typ.Elem()
if targetType.Kind() != reflect.Interface && !targetType.Implements(extensionIfaceType) {
panic("ipnext: *target must be interface or implement ipnext.Extension")
}
h.mu.Lock()
defer h.mu.Unlock()
for _, ext := range h.activeExtensions {
if reflect.TypeOf(ext).AssignableTo(targetType) {
val.Elem().Set(reflect.ValueOf(ext))
return true
}
}
return false
}
// Profiles implements [ipnext.Host].
func (h *ExtensionHost) Profiles() ipnext.ProfileServices {
// Currently, [ExtensionHost] implements [ipnext.ProfileServices] directly.

View File

@ -299,6 +299,100 @@ func TestNewExtensionHost(t *testing.T) {
}
}
// TestFindMatchingExtension tests that [ExtensionHost.FindMatchingExtension] correctly
// finds extensions by their type or interface.
func TestFindMatchingExtension(t *testing.T) {
t.Parallel()
// Define test extension types and a couple of interfaces
type (
extensionA struct {
testExtension
}
extensionB struct {
testExtension
}
extensionC struct {
testExtension
}
supportedIface interface {
Name() string
}
unsupportedIface interface {
Unsupported()
}
)
// Register extensions A and B, but not C.
extA := &extensionA{testExtension: testExtension{name: "A"}}
extB := &extensionB{testExtension: testExtension{name: "B"}}
h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true, extA, extB)
var gotA *extensionA
if !h.FindMatchingExtension(&gotA) {
t.Errorf("LookupExtension(%T): not found", gotA)
} else if gotA != extA {
t.Errorf("LookupExtension(%T): got %v; want %v", gotA, gotA, extA)
}
var gotB *extensionB
if !h.FindMatchingExtension(&gotB) {
t.Errorf("LookupExtension(%T): extension B not found", gotB)
} else if gotB != extB {
t.Errorf("LookupExtension(%T): got %v; want %v", gotB, gotB, extB)
}
var gotC *extensionC
if h.FindMatchingExtension(&gotC) {
t.Errorf("LookupExtension(%T): found, but it should not exist", gotC)
}
// All extensions implement the supportedIface interface,
// but LookupExtension should only return the first one found,
// which is extA.
var gotSupportedIface supportedIface
if !h.FindMatchingExtension(&gotSupportedIface) {
t.Errorf("LookupExtension(%T): not found", gotSupportedIface)
} else if gotName, wantName := gotSupportedIface.Name(), extA.Name(); gotName != wantName {
t.Errorf("LookupExtension(%T): name: got %v; want %v", gotSupportedIface, gotName, wantName)
} else if gotSupportedIface != extA {
t.Errorf("LookupExtension(%T): got %v; want %v", gotSupportedIface, gotSupportedIface, extA)
}
var gotUnsupportedIface unsupportedIface
if h.FindMatchingExtension(&gotUnsupportedIface) {
t.Errorf("LookupExtension(%T): found, but it should not exist", gotUnsupportedIface)
}
}
// TestFindExtensionByName tests that [ExtensionHost.FindExtensionByName] correctly
// finds extensions by their name.
func TestFindExtensionByName(t *testing.T) {
// Register extensions A and B, but not C.
extA := &testExtension{name: "A"}
extB := &testExtension{name: "B"}
h := newExtensionHostForTest(t, &testBackend{}, true, extA, extB)
gotA, ok := h.FindExtensionByName(extA.Name()).(*testExtension)
if !ok {
t.Errorf("FindExtensionByName(%q): not found", extA.Name())
} else if gotA != extA {
t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extA.Name(), gotA, extA)
}
gotB, ok := h.FindExtensionByName(extB.Name()).(*testExtension)
if !ok {
t.Errorf("FindExtensionByName(%q): not found", extB.Name())
} else if gotB != extB {
t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extB.Name(), gotB, extB)
}
gotC, ok := h.FindExtensionByName("C").(*testExtension)
if ok {
t.Errorf(`FindExtensionByName("C"): found, but it should not exist: %v`, gotC)
}
}
// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues
// backend operations and executes them asynchronously in the order they were received.
// It also checks that operations requested before the host and all extensions are initialized

View File

@ -589,6 +589,22 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
return b, nil
}
// FindExtensionByName returns an active extension with the given name,
// or nil if no such extension exists.
func (b *LocalBackend) FindExtensionByName(name string) any {
return b.extHost.Extensions().FindExtensionByName(name)
}
// FindMatchingExtension finds the first active extension that matches target,
// and if one is found, sets target to that extension and returns true.
// Otherwise, it returns false.
//
// It panics if target is not a non-nil pointer to either a type
// that implements [ipnext.Extension], or to any interface type.
func (b *LocalBackend) FindMatchingExtension(target any) bool {
return b.extHost.Extensions().FindMatchingExtension(target)
}
type componentLogState struct {
until time.Time
timer tstime.TimerController // if non-nil, the AfterFunc to disable it