mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-23 17:31:43 +00:00
ipn/ipn{ext,local}: allow extension lookup by name or type
In this PR, we add two methods to facilitate extension lookup by both extensions, and non-extensions (e.g., PeerAPI or LocalAPI handlers): - FindExtensionByName returns an extension with the specified name. It can then be type asserted to a given type. - FindMatchingExtension is like errors.As, but for extensions. It returns the first extension that matches the target type (either a specific extension or an interface). Updates tailscale/corp#27645 Updates tailscale/corp#27502 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
1e290867bd
commit
f28c8d0ec0
@ -174,6 +174,9 @@ func DefinitionWithErrForTest(name string, err error) *Definition {
|
||||
//
|
||||
// A host must be safe for concurrent use.
|
||||
type Host interface {
|
||||
// Extensions returns the host's [ExtensionServices].
|
||||
Extensions() ExtensionServices
|
||||
|
||||
// Profiles returns the host's [ProfileServices].
|
||||
Profiles() ProfileServices
|
||||
|
||||
@ -197,6 +200,22 @@ type Host interface {
|
||||
RegisterControlClientCallback(NewControlClientCallback) (unregister func())
|
||||
}
|
||||
|
||||
// ExtensionServices provides access to the [Host]'s extension management services,
|
||||
// such as fetching active extensions.
|
||||
type ExtensionServices interface {
|
||||
// FindExtensionByName returns an active extension with the given name,
|
||||
// or nil if no such extension exists.
|
||||
FindExtensionByName(name string) any
|
||||
|
||||
// FindMatchingExtension finds the first active extension that matches target,
|
||||
// and if one is found, sets target to that extension and returns true.
|
||||
// Otherwise, it returns false.
|
||||
//
|
||||
// It panics if target is not a non-nil pointer to either a type
|
||||
// that implements [ipnext.Extension], or to any interface type.
|
||||
FindMatchingExtension(target any) bool
|
||||
}
|
||||
|
||||
// ProfileServices provides access to the [Host]'s profile management services,
|
||||
// such as switching profiles and registering profile change callbacks.
|
||||
type ProfileServices interface {
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"maps"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -233,6 +234,60 @@ func (h *ExtensionHost) init() {
|
||||
|
||||
}
|
||||
|
||||
// Extensions implements [ipnext.Host].
|
||||
func (h *ExtensionHost) Extensions() ipnext.ExtensionServices {
|
||||
// Currently, [ExtensionHost] implements [ExtensionServices] directly.
|
||||
// We might want to extract it to a separate type in the future.
|
||||
return h
|
||||
}
|
||||
|
||||
// FindExtensionByName implements [ipnext.ExtensionServices]
|
||||
// and is also used by the [LocalBackend].
|
||||
// It returns nil if the extension is not found.
|
||||
func (h *ExtensionHost) FindExtensionByName(name string) any {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.extensionsByName[name]
|
||||
}
|
||||
|
||||
// extensionIfaceType is the runtime type of the [ipnext.Extension] interface.
|
||||
var extensionIfaceType = reflect.TypeFor[ipnext.Extension]()
|
||||
|
||||
// FindMatchingExtension implements [ipnext.ExtensionServices]
|
||||
// and is also used by the [LocalBackend].
|
||||
func (h *ExtensionHost) FindMatchingExtension(target any) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if target == nil {
|
||||
panic("ipnext: target cannot be nil")
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(target)
|
||||
typ := val.Type()
|
||||
if typ.Kind() != reflect.Ptr || val.IsNil() {
|
||||
panic("ipnext: target must be a non-nil pointer")
|
||||
}
|
||||
targetType := typ.Elem()
|
||||
if targetType.Kind() != reflect.Interface && !targetType.Implements(extensionIfaceType) {
|
||||
panic("ipnext: *target must be interface or implement ipnext.Extension")
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
for _, ext := range h.activeExtensions {
|
||||
if reflect.TypeOf(ext).AssignableTo(targetType) {
|
||||
val.Elem().Set(reflect.ValueOf(ext))
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Profiles implements [ipnext.Host].
|
||||
func (h *ExtensionHost) Profiles() ipnext.ProfileServices {
|
||||
// Currently, [ExtensionHost] implements [ipnext.ProfileServices] directly.
|
||||
|
@ -299,6 +299,100 @@ func TestNewExtensionHost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindMatchingExtension tests that [ExtensionHost.FindMatchingExtension] correctly
|
||||
// finds extensions by their type or interface.
|
||||
func TestFindMatchingExtension(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Define test extension types and a couple of interfaces
|
||||
type (
|
||||
extensionA struct {
|
||||
testExtension
|
||||
}
|
||||
extensionB struct {
|
||||
testExtension
|
||||
}
|
||||
extensionC struct {
|
||||
testExtension
|
||||
}
|
||||
supportedIface interface {
|
||||
Name() string
|
||||
}
|
||||
unsupportedIface interface {
|
||||
Unsupported()
|
||||
}
|
||||
)
|
||||
|
||||
// Register extensions A and B, but not C.
|
||||
extA := &extensionA{testExtension: testExtension{name: "A"}}
|
||||
extB := &extensionB{testExtension: testExtension{name: "B"}}
|
||||
h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true, extA, extB)
|
||||
|
||||
var gotA *extensionA
|
||||
if !h.FindMatchingExtension(&gotA) {
|
||||
t.Errorf("LookupExtension(%T): not found", gotA)
|
||||
} else if gotA != extA {
|
||||
t.Errorf("LookupExtension(%T): got %v; want %v", gotA, gotA, extA)
|
||||
}
|
||||
|
||||
var gotB *extensionB
|
||||
if !h.FindMatchingExtension(&gotB) {
|
||||
t.Errorf("LookupExtension(%T): extension B not found", gotB)
|
||||
} else if gotB != extB {
|
||||
t.Errorf("LookupExtension(%T): got %v; want %v", gotB, gotB, extB)
|
||||
}
|
||||
|
||||
var gotC *extensionC
|
||||
if h.FindMatchingExtension(&gotC) {
|
||||
t.Errorf("LookupExtension(%T): found, but it should not exist", gotC)
|
||||
}
|
||||
|
||||
// All extensions implement the supportedIface interface,
|
||||
// but LookupExtension should only return the first one found,
|
||||
// which is extA.
|
||||
var gotSupportedIface supportedIface
|
||||
if !h.FindMatchingExtension(&gotSupportedIface) {
|
||||
t.Errorf("LookupExtension(%T): not found", gotSupportedIface)
|
||||
} else if gotName, wantName := gotSupportedIface.Name(), extA.Name(); gotName != wantName {
|
||||
t.Errorf("LookupExtension(%T): name: got %v; want %v", gotSupportedIface, gotName, wantName)
|
||||
} else if gotSupportedIface != extA {
|
||||
t.Errorf("LookupExtension(%T): got %v; want %v", gotSupportedIface, gotSupportedIface, extA)
|
||||
}
|
||||
|
||||
var gotUnsupportedIface unsupportedIface
|
||||
if h.FindMatchingExtension(&gotUnsupportedIface) {
|
||||
t.Errorf("LookupExtension(%T): found, but it should not exist", gotUnsupportedIface)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindExtensionByName tests that [ExtensionHost.FindExtensionByName] correctly
|
||||
// finds extensions by their name.
|
||||
func TestFindExtensionByName(t *testing.T) {
|
||||
// Register extensions A and B, but not C.
|
||||
extA := &testExtension{name: "A"}
|
||||
extB := &testExtension{name: "B"}
|
||||
h := newExtensionHostForTest(t, &testBackend{}, true, extA, extB)
|
||||
|
||||
gotA, ok := h.FindExtensionByName(extA.Name()).(*testExtension)
|
||||
if !ok {
|
||||
t.Errorf("FindExtensionByName(%q): not found", extA.Name())
|
||||
} else if gotA != extA {
|
||||
t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extA.Name(), gotA, extA)
|
||||
}
|
||||
|
||||
gotB, ok := h.FindExtensionByName(extB.Name()).(*testExtension)
|
||||
if !ok {
|
||||
t.Errorf("FindExtensionByName(%q): not found", extB.Name())
|
||||
} else if gotB != extB {
|
||||
t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extB.Name(), gotB, extB)
|
||||
}
|
||||
|
||||
gotC, ok := h.FindExtensionByName("C").(*testExtension)
|
||||
if ok {
|
||||
t.Errorf(`FindExtensionByName("C"): found, but it should not exist: %v`, gotC)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues
|
||||
// backend operations and executes them asynchronously in the order they were received.
|
||||
// It also checks that operations requested before the host and all extensions are initialized
|
||||
|
@ -589,6 +589,22 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// FindExtensionByName returns an active extension with the given name,
|
||||
// or nil if no such extension exists.
|
||||
func (b *LocalBackend) FindExtensionByName(name string) any {
|
||||
return b.extHost.Extensions().FindExtensionByName(name)
|
||||
}
|
||||
|
||||
// FindMatchingExtension finds the first active extension that matches target,
|
||||
// and if one is found, sets target to that extension and returns true.
|
||||
// Otherwise, it returns false.
|
||||
//
|
||||
// It panics if target is not a non-nil pointer to either a type
|
||||
// that implements [ipnext.Extension], or to any interface type.
|
||||
func (b *LocalBackend) FindMatchingExtension(target any) bool {
|
||||
return b.extHost.Extensions().FindMatchingExtension(target)
|
||||
}
|
||||
|
||||
type componentLogState struct {
|
||||
until time.Time
|
||||
timer tstime.TimerController // if non-nil, the AfterFunc to disable it
|
||||
|
Loading…
x
Reference in New Issue
Block a user