diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 37a1be6e3..4cc4a8d46 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -921,6 +921,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/types/lazy from tailscale.com/ipn/ipnlocal+ tailscale.com/types/logger from tailscale.com/appc+ tailscale.com/types/logid from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext tailscale.com/types/netlogtype from tailscale.com/net/connstats+ tailscale.com/types/netmap from tailscale.com/control/controlclient+ tailscale.com/types/nettype from tailscale.com/ipn/localapi+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 31881822f..329c00e93 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -373,6 +373,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/lazy from tailscale.com/ipn/ipnlocal+ tailscale.com/types/logger from tailscale.com/appc+ tailscale.com/types/logid from tailscale.com/cmd/tailscaled+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext tailscale.com/types/netlogtype from tailscale.com/net/connstats+ tailscale.com/types/netmap from tailscale.com/control/controlclient+ tailscale.com/types/nettype from tailscale.com/ipn/localapi+ diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go index a671874d1..bd8d3d79c 100644 --- a/ipn/ipnext/ipnext.go +++ b/ipn/ipnext/ipnext.go @@ -8,6 +8,7 @@ package ipnext import ( "errors" "fmt" + "iter" "tailscale.com/control/controlclient" "tailscale.com/feature" @@ -16,8 +17,7 @@ import ( "tailscale.com/tsd" "tailscale.com/tstime" "tailscale.com/types/logger" - "tailscale.com/types/views" - "tailscale.com/util/mak" + "tailscale.com/types/mapx" ) // Extension augments LocalBackend with additional functionality. @@ -91,13 +91,9 @@ func (d *Definition) MakeExtension(logf logger.Logf, sb SafeBackend) (Extension, return ext, nil } -// extensionsByName is a map of registered extensions, +// extensions is a map of registered extensions, // where the key is the name of the extension. -var extensionsByName map[string]*Definition - -// extensionsByOrder is a slice of registered extensions, -// in the order they were registered. -var extensionsByOrder []*Definition +var extensions mapx.OrderedMap[string, *Definition] // RegisterExtension registers a function that instantiates an [Extension]. // The name must be the same as returned by the extension's [Extension.Name]. @@ -111,19 +107,16 @@ func RegisterExtension(name string, newExt NewExtensionFn) { if newExt == nil { panic(fmt.Sprintf("ipnext: newExt is nil: %q", name)) } - if _, ok := extensionsByName[name]; ok { + if extensions.Contains(name) { panic(fmt.Sprintf("ipnext: duplicate extensions: %q", name)) } - ext := &Definition{name, newExt} - mak.Set(&extensionsByName, name, ext) - extensionsByOrder = append(extensionsByOrder, ext) + extensions.Set(name, &Definition{name, newExt}) } -// Extensions returns a read-only view of the extensions -// registered via [RegisterExtension]. It preserves the order -// in which the extensions were registered. -func Extensions() views.Slice[*Definition] { - return views.SliceOf(extensionsByOrder) +// Extensions iterates over the extensions in the order they were registered +// via [RegisterExtension]. +func Extensions() iter.Seq[*Definition] { + return extensions.Values() } // DefinitionForTest returns a [Definition] for the specified [Extension]. diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go index 6aa42ba12..bf0e6091c 100644 --- a/ipn/ipnlocal/extension_host.go +++ b/ipn/ipnlocal/extension_host.go @@ -162,17 +162,14 @@ func newExtensionHost(logf logger.Logf, b Backend, overrideExts ...*ipnext.Defin } // Use registered extensions. - exts := ipnext.Extensions().All() - numExts := ipnext.Extensions().Len() + extDef := ipnext.Extensions() if overrideExts != nil { // Use the provided, potentially empty, overrideExts // instead of the registered ones. - exts = slices.All(overrideExts) - numExts = len(overrideExts) + extDef = slices.Values(overrideExts) } - host.allExtensions = make([]ipnext.Extension, 0, numExts) - for _, d := range exts { + for d := range extDef { ext, err := d.MakeExtension(logf, b) if errors.Is(err, ipnext.SkipExtension) { // The extension wants to be skipped. diff --git a/types/mapx/ordered.go b/types/mapx/ordered.go new file mode 100644 index 000000000..1991f039d --- /dev/null +++ b/types/mapx/ordered.go @@ -0,0 +1,111 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mapx contains extra map types and functions. +package mapx + +import ( + "iter" + "slices" +) + +// OrderedMap is a map that maintains the order of its keys. +// +// It is meant for maps that only grow or that are small; +// is it not optimized for deleting keys. +// +// The zero value is ready to use. +// +// Locking-wise, it has the same rules as a regular Go map: +// concurrent reads are safe, but not writes. +type OrderedMap[K comparable, V any] struct { + // m is the underlying map. + m map[K]V + + // keys is the order of keys in the map. + keys []K +} + +func (m *OrderedMap[K, V]) init() { + if m.m == nil { + m.m = make(map[K]V) + } +} + +// Set sets the value for the given key in the map. +// +// If the key already exists, it updates the value and keeps the order. +func (m *OrderedMap[K, V]) Set(key K, value V) { + m.init() + len0 := len(m.keys) + m.m[key] = value + if len(m.m) > len0 { + // New key (not an update) + m.keys = append(m.keys, key) + } +} + +// Get returns the value for the given key in the map. +// If the key does not exist, it returns the zero value for V. +func (m *OrderedMap[K, V]) Get(key K) V { + return m.m[key] +} + +// GetOk returns the value for the given key in the map +// and whether it was present in the map. +func (m *OrderedMap[K, V]) GetOk(key K) (_ V, ok bool) { + v, ok := m.m[key] + return v, ok +} + +// Contains reports whether the map contains the given key. +func (m *OrderedMap[K, V]) Contains(key K) bool { + _, ok := m.m[key] + return ok +} + +// Delete removes the key from the map. +// +// The cost is O(n) in the number of keys in the map. +func (m *OrderedMap[K, V]) Delete(key K) { + len0 := len(m.m) + delete(m.m, key) + if len(m.m) == len0 { + // Wasn't present; no need to adjust keys. + return + } + was := m.keys + m.keys = m.keys[:0] + for _, k := range was { + if k != key { + m.keys = append(m.keys, k) + } + } +} + +// All yields all the keys and values, in the order they were inserted. +func (m *OrderedMap[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + for _, k := range m.keys { + if !yield(k, m.m[k]) { + return + } + } + } +} + +// Keys yields the map keys, in the order they were inserted. +func (m *OrderedMap[K, V]) Keys() iter.Seq[K] { + return slices.Values(m.keys) +} + +// Values yields the map values, in the order they were inserted. +func (m *OrderedMap[K, V]) Values() iter.Seq[V] { + return func(yield func(V) bool) { + for _, k := range m.keys { + if !yield(m.m[k]) { + return + } + } + } +} diff --git a/types/mapx/ordered_test.go b/types/mapx/ordered_test.go new file mode 100644 index 000000000..7dcb7e405 --- /dev/null +++ b/types/mapx/ordered_test.go @@ -0,0 +1,56 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package mapx + +import ( + "fmt" + "slices" + "testing" +) + +func TestOrderedMap(t *testing.T) { + // Test the OrderedMap type and its methods. + var m OrderedMap[string, int] + m.Set("d", 4) + m.Set("a", 1) + m.Set("b", 1) + m.Set("b", 2) + m.Set("c", 3) + m.Delete("d") + m.Delete("e") + + want := map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 0, + } + for k, v := range want { + if m.Get(k) != v { + t.Errorf("Get(%q) = %d, want %d", k, m.Get(k), v) + continue + } + got, ok := m.GetOk(k) + if got != v { + t.Errorf("GetOk(%q) = %d, want %d", k, got, v) + } + if ok != m.Contains(k) { + t.Errorf("GetOk and Contains don't agree for %q", k) + } + } + + if got, want := slices.Collect(m.Keys()), []string{"a", "b", "c"}; !slices.Equal(got, want) { + t.Errorf("Keys() = %q, want %q", got, want) + } + if got, want := slices.Collect(m.Values()), []int{1, 2, 3}; !slices.Equal(got, want) { + t.Errorf("Values() = %v, want %v", got, want) + } + var allGot []string + for k, v := range m.All() { + allGot = append(allGot, fmt.Sprintf("%s:%d", k, v)) + } + if got, want := allGot, []string{"a:1", "b:2", "c:3"}; !slices.Equal(got, want) { + t.Errorf("All() = %q, want %q", got, want) + } +}