diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index 0c6c9e287..f8966d20e 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -188,11 +188,7 @@ func (v *MapView) UnmarshalJSON(b []byte) error { func (v MapView) Int() views.Map[string, int] { return views.MapOf(v.ж.Int) } -func (v MapView) SliceInt() views.MapFn[string, []int, views.Slice[int]] { - return views.MapFnOf(v.ж.SliceInt, func(t []int) views.Slice[int] { - return views.SliceOf(t) - }) -} +func (v MapView) SliceInt() views.MapSlice[string, int] { return views.MapSliceOf(v.ж.SliceInt) } func (v MapView) StructPtrWithPtr() views.MapFn[string, *StructWithPtrs, StructWithPtrsView] { return views.MapFnOf(v.ж.StructPtrWithPtr, func(t *StructWithPtrs) StructWithPtrsView { diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index a83499a69..f89210268 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -92,6 +92,9 @@ func(v {{.ViewName}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueTy return {{.MapFn}} })} {{end}} +{{define "mapSliceField"}} +func(v {{.ViewName}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) } +{{end}} {{define "unsupportedField"}}func(v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")} {{end}} {{define "stringFunc"}}func(v {{.ViewName}}) String() string { return v.ж.String() } @@ -241,9 +244,8 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi case *types.Basic, *types.Named: sElem := it.QualifiedName(sElem) args.MapValueView = fmt.Sprintf("views.Slice[%v]", sElem) - args.MapValueType = "[]" + sElem - args.MapFn = "views.SliceOf(t)" - template = "mapFnField" + args.MapValueType = sElem + template = "mapSliceField" case *types.Pointer: ptr := x pElem := ptr.Elem() diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index c219a391e..a77adf33f 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -6565,7 +6565,7 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug if allowList != nil && !allowList.Contains(peer.StableID()) { continue } - if peer.CapMap().Has(tailcfg.NodeAttrSuggestExitNode) && tsaddr.ContainsExitRoutes(peer.AllowedIPs()) { + if peer.CapMap().Contains(tailcfg.NodeAttrSuggestExitNode) && tsaddr.ContainsExitRoutes(peer.AllowedIPs()) { candidates = append(candidates, peer) } } diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index 4e4af5f99..d7ba1f24c 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -307,7 +307,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string if prevConfig.Valid() { has := func(string) bool { return false } if b.serveConfig.Valid() { - has = b.serveConfig.Foreground().Has + has = b.serveConfig.Foreground().Contains } prevConfig.Foreground().Range(func(k string, v ipn.ServeConfigView) (cont bool) { if !has(k) { @@ -338,7 +338,7 @@ func (b *LocalBackend) ServeConfig() ipn.ServeConfigView { func (b *LocalBackend) DeleteForegroundSession(sessionID string) error { b.mu.Lock() defer b.mu.Unlock() - if !b.serveConfig.Valid() || !b.serveConfig.Foreground().Has(sessionID) { + if !b.serveConfig.Valid() || !b.serveConfig.Foreground().Contains(sessionID) { return nil } sc := b.serveConfig.AsStruct() diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index b5e1c9e80..3bc57ec29 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -168,10 +168,8 @@ func (v NodeView) Online() *bool { func (v NodeView) MachineAuthorized() bool { return v.ж.MachineAuthorized } func (v NodeView) Capabilities() views.Slice[NodeCapability] { return views.SliceOf(v.ж.Capabilities) } -func (v NodeView) CapMap() views.MapFn[NodeCapability, []RawMessage, views.Slice[RawMessage]] { - return views.MapFnOf(v.ж.CapMap, func(t []RawMessage) views.Slice[RawMessage] { - return views.SliceOf(t) - }) +func (v NodeView) CapMap() views.MapSlice[NodeCapability, RawMessage] { + return views.MapSliceOf(v.ж.CapMap) } func (v NodeView) UnsignedPeerAPIOnly() bool { return v.ж.UnsignedPeerAPIOnly } func (v NodeView) ComputedName() string { return v.ж.ComputedName } diff --git a/types/views/views.go b/types/views/views.go index 387ff0fd9..42758966f 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -329,13 +329,88 @@ func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { return true } -// MapOf returns a view over m. It is the caller's responsibility to make sure K -// and V is immutable, if this is being used to provide a read-only view over m. -func MapOf[K comparable, V comparable](m map[K]V) Map[K, V] { - return Map[K, V]{m} +// MapSlice is a view over a map whose values are slices. +type MapSlice[K comparable, V any] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж map[K][]V } -// Map is a view over a map whose values are immutable. +// MapSliceOf returns a MapSlice for the provided map. It is the caller's +// responsibility to make sure V is immutable. +func MapSliceOf[K comparable, V any](m map[K][]V) MapSlice[K, V] { + return MapSlice[K, V]{m} +} + +// Contains reports whether k has an entry in the map. +func (m MapSlice[K, V]) Contains(k K) bool { + _, ok := m.ж[k] + return ok +} + +// IsNil reports whether the underlying map is nil. +func (m MapSlice[K, V]) IsNil() bool { + return m.ж == nil +} + +// Len returns the number of elements in the map. +func (m MapSlice[K, V]) Len() int { return len(m.ж) } + +// Get returns the element with key k. +func (m MapSlice[K, V]) Get(k K) Slice[V] { + return SliceOf(m.ж[k]) +} + +// GetOk returns the element with key k and a bool representing whether the key +// is in map. +func (m MapSlice[K, V]) GetOk(k K) (Slice[V], bool) { + v, ok := m.ж[k] + return SliceOf(v), ok +} + +// MarshalJSON implements json.Marshaler. +func (m MapSlice[K, V]) MarshalJSON() ([]byte, error) { + return json.Marshal(m.ж) +} + +// UnmarshalJSON implements json.Unmarshaler. +// It should only be called on an uninitialized Map. +func (m *MapSlice[K, V]) UnmarshalJSON(b []byte) error { + if m.ж != nil { + return errors.New("already initialized") + } + return json.Unmarshal(b, &m.ж) +} + +// Range calls f for every k,v pair in the underlying map. +// It stops iteration immediately if f returns false. +func (m MapSlice[K, V]) Range(f MapRangeFn[K, Slice[V]]) { + for k, v := range m.ж { + if !f(k, SliceOf(v)) { + return + } + } +} + +// AsMap returns a shallow-clone of the underlying map. +// +// If V is a pointer type, it is the caller's responsibility to make sure the +// values are immutable. The map and slices are cloned, but the values are not. +func (m MapSlice[K, V]) AsMap() map[K][]V { + if m.ж == nil { + return nil + } + out := maps.Clone(m.ж) + for k, v := range out { + out[k] = slices.Clone(v) + } + return out +} + +// Map provides a read-only view of a map. It is the caller's responsibility to +// make sure V is immutable. type Map[K comparable, V any] struct { // ж is the underlying mutable value, named with a hard-to-type // character that looks pointy like a pointer. @@ -344,8 +419,20 @@ type Map[K comparable, V any] struct { ж map[K]V } +// MapOf returns a view over m. It is the caller's responsibility to make sure V +// is immutable. +func MapOf[K comparable, V any](m map[K]V) Map[K, V] { + return Map[K, V]{m} +} + // Has reports whether k has an entry in the map. +// Deprecated: use Contains instead. func (m Map[K, V]) Has(k K) bool { + return m.Contains(k) +} + +// Contains reports whether k has an entry in the map. +func (m Map[K, V]) Contains(k K) bool { _, ok := m.ж[k] return ok } @@ -428,7 +515,13 @@ type MapFn[K comparable, T any, V any] struct { } // Has reports whether k has an entry in the map. +// Deprecated: use Contains instead. func (m MapFn[K, T, V]) Has(k K) bool { + return m.Contains(k) +} + +// Contains reports whether k has an entry in the map. +func (m MapFn[K, T, V]) Contains(k K) bool { _, ok := m.ж[k] return ok }