From f28c8d0ec0b4dbdccd87ee43aa13ce13485dc2b1 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 11 Apr 2025 10:09:03 -0500 Subject: [PATCH] ipn/ipn{ext,local}: allow extension lookup by name or type In this PR, we add two methods to facilitate extension lookup by both extensions, and non-extensions (e.g., PeerAPI or LocalAPI handlers): - FindExtensionByName returns an extension with the specified name. It can then be type asserted to a given type. - FindMatchingExtension is like errors.As, but for extensions. It returns the first extension that matches the target type (either a specific extension or an interface). Updates tailscale/corp#27645 Updates tailscale/corp#27502 Signed-off-by: Nick Khyl --- ipn/ipnext/ipnext.go | 19 ++++++ ipn/ipnlocal/extension_host.go | 55 +++++++++++++++++ ipn/ipnlocal/extension_host_test.go | 94 +++++++++++++++++++++++++++++ ipn/ipnlocal/local.go | 16 +++++ 4 files changed, 184 insertions(+) diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go index af870b53a..f8fd500ce 100644 --- a/ipn/ipnext/ipnext.go +++ b/ipn/ipnext/ipnext.go @@ -174,6 +174,9 @@ func DefinitionWithErrForTest(name string, err error) *Definition { // // A host must be safe for concurrent use. type Host interface { + // Extensions returns the host's [ExtensionServices]. + Extensions() ExtensionServices + // Profiles returns the host's [ProfileServices]. Profiles() ProfileServices @@ -197,6 +200,22 @@ type Host interface { RegisterControlClientCallback(NewControlClientCallback) (unregister func()) } +// ExtensionServices provides access to the [Host]'s extension management services, +// such as fetching active extensions. +type ExtensionServices interface { + // FindExtensionByName returns an active extension with the given name, + // or nil if no such extension exists. + FindExtensionByName(name string) any + + // FindMatchingExtension finds the first active extension that matches target, + // and if one is found, sets target to that extension and returns true. + // Otherwise, it returns false. + // + // It panics if target is not a non-nil pointer to either a type + // that implements [ipnext.Extension], or to any interface type. + FindMatchingExtension(target any) bool +} + // ProfileServices provides access to the [Host]'s profile management services, // such as switching profiles and registering profile change callbacks. type ProfileServices interface { diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go index 4a617ed72..9c6b6d44c 100644 --- a/ipn/ipnlocal/extension_host.go +++ b/ipn/ipnlocal/extension_host.go @@ -9,6 +9,7 @@ import ( "fmt" "iter" "maps" + "reflect" "slices" "strings" "sync" @@ -233,6 +234,60 @@ func (h *ExtensionHost) init() { } +// Extensions implements [ipnext.Host]. +func (h *ExtensionHost) Extensions() ipnext.ExtensionServices { + // Currently, [ExtensionHost] implements [ExtensionServices] directly. + // We might want to extract it to a separate type in the future. + return h +} + +// FindExtensionByName implements [ipnext.ExtensionServices] +// and is also used by the [LocalBackend]. +// It returns nil if the extension is not found. +func (h *ExtensionHost) FindExtensionByName(name string) any { + if h == nil { + return nil + } + h.mu.Lock() + defer h.mu.Unlock() + return h.extensionsByName[name] +} + +// extensionIfaceType is the runtime type of the [ipnext.Extension] interface. +var extensionIfaceType = reflect.TypeFor[ipnext.Extension]() + +// FindMatchingExtension implements [ipnext.ExtensionServices] +// and is also used by the [LocalBackend]. +func (h *ExtensionHost) FindMatchingExtension(target any) bool { + if h == nil { + return false + } + + if target == nil { + panic("ipnext: target cannot be nil") + } + + val := reflect.ValueOf(target) + typ := val.Type() + if typ.Kind() != reflect.Ptr || val.IsNil() { + panic("ipnext: target must be a non-nil pointer") + } + targetType := typ.Elem() + if targetType.Kind() != reflect.Interface && !targetType.Implements(extensionIfaceType) { + panic("ipnext: *target must be interface or implement ipnext.Extension") + } + + h.mu.Lock() + defer h.mu.Unlock() + for _, ext := range h.activeExtensions { + if reflect.TypeOf(ext).AssignableTo(targetType) { + val.Elem().Set(reflect.ValueOf(ext)) + return true + } + } + return false +} + // Profiles implements [ipnext.Host]. func (h *ExtensionHost) Profiles() ipnext.ProfileServices { // Currently, [ExtensionHost] implements [ipnext.ProfileServices] directly. diff --git a/ipn/ipnlocal/extension_host_test.go b/ipn/ipnlocal/extension_host_test.go index 1e03abaa1..cefe9339d 100644 --- a/ipn/ipnlocal/extension_host_test.go +++ b/ipn/ipnlocal/extension_host_test.go @@ -299,6 +299,100 @@ func TestNewExtensionHost(t *testing.T) { } } +// TestFindMatchingExtension tests that [ExtensionHost.FindMatchingExtension] correctly +// finds extensions by their type or interface. +func TestFindMatchingExtension(t *testing.T) { + t.Parallel() + + // Define test extension types and a couple of interfaces + type ( + extensionA struct { + testExtension + } + extensionB struct { + testExtension + } + extensionC struct { + testExtension + } + supportedIface interface { + Name() string + } + unsupportedIface interface { + Unsupported() + } + ) + + // Register extensions A and B, but not C. + extA := &extensionA{testExtension: testExtension{name: "A"}} + extB := &extensionB{testExtension: testExtension{name: "B"}} + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true, extA, extB) + + var gotA *extensionA + if !h.FindMatchingExtension(&gotA) { + t.Errorf("LookupExtension(%T): not found", gotA) + } else if gotA != extA { + t.Errorf("LookupExtension(%T): got %v; want %v", gotA, gotA, extA) + } + + var gotB *extensionB + if !h.FindMatchingExtension(&gotB) { + t.Errorf("LookupExtension(%T): extension B not found", gotB) + } else if gotB != extB { + t.Errorf("LookupExtension(%T): got %v; want %v", gotB, gotB, extB) + } + + var gotC *extensionC + if h.FindMatchingExtension(&gotC) { + t.Errorf("LookupExtension(%T): found, but it should not exist", gotC) + } + + // All extensions implement the supportedIface interface, + // but LookupExtension should only return the first one found, + // which is extA. + var gotSupportedIface supportedIface + if !h.FindMatchingExtension(&gotSupportedIface) { + t.Errorf("LookupExtension(%T): not found", gotSupportedIface) + } else if gotName, wantName := gotSupportedIface.Name(), extA.Name(); gotName != wantName { + t.Errorf("LookupExtension(%T): name: got %v; want %v", gotSupportedIface, gotName, wantName) + } else if gotSupportedIface != extA { + t.Errorf("LookupExtension(%T): got %v; want %v", gotSupportedIface, gotSupportedIface, extA) + } + + var gotUnsupportedIface unsupportedIface + if h.FindMatchingExtension(&gotUnsupportedIface) { + t.Errorf("LookupExtension(%T): found, but it should not exist", gotUnsupportedIface) + } +} + +// TestFindExtensionByName tests that [ExtensionHost.FindExtensionByName] correctly +// finds extensions by their name. +func TestFindExtensionByName(t *testing.T) { + // Register extensions A and B, but not C. + extA := &testExtension{name: "A"} + extB := &testExtension{name: "B"} + h := newExtensionHostForTest(t, &testBackend{}, true, extA, extB) + + gotA, ok := h.FindExtensionByName(extA.Name()).(*testExtension) + if !ok { + t.Errorf("FindExtensionByName(%q): not found", extA.Name()) + } else if gotA != extA { + t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extA.Name(), gotA, extA) + } + + gotB, ok := h.FindExtensionByName(extB.Name()).(*testExtension) + if !ok { + t.Errorf("FindExtensionByName(%q): not found", extB.Name()) + } else if gotB != extB { + t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extB.Name(), gotB, extB) + } + + gotC, ok := h.FindExtensionByName("C").(*testExtension) + if ok { + t.Errorf(`FindExtensionByName("C"): found, but it should not exist: %v`, gotC) + } +} + // TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues // backend operations and executes them asynchronously in the order they were received. // It also checks that operations requested before the host and all extensions are initialized diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 0f3ea1fbb..9ec4b4767 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -589,6 +589,22 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo return b, nil } +// FindExtensionByName returns an active extension with the given name, +// or nil if no such extension exists. +func (b *LocalBackend) FindExtensionByName(name string) any { + return b.extHost.Extensions().FindExtensionByName(name) +} + +// FindMatchingExtension finds the first active extension that matches target, +// and if one is found, sets target to that extension and returns true. +// Otherwise, it returns false. +// +// It panics if target is not a non-nil pointer to either a type +// that implements [ipnext.Extension], or to any interface type. +func (b *LocalBackend) FindMatchingExtension(target any) bool { + return b.extHost.Extensions().FindMatchingExtension(target) +} + type componentLogState struct { until time.Time timer tstime.TimerController // if non-nil, the AfterFunc to disable it