types/views: make SliceOf/MapOf panic if they see a pointer

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali
2022-05-09 19:31:45 -07:00
committed by Maisem Ali
parent d04afc697c
commit 395cb588b6
2 changed files with 104 additions and 9 deletions

View File

@@ -9,6 +9,8 @@ package views
import (
"encoding/json"
"errors"
"fmt"
"reflect"
"inet.af/netaddr"
"tailscale.com/net/tsaddr"
@@ -97,8 +99,14 @@ type Slice[T any] struct {
ж []T
}
// SliceOf returns a Slice for the provided slice.
func SliceOf[T any](x []T) Slice[T] { return Slice[T]{x} }
// SliceOf returns a Slice for the provided slice for immutable values.
// It panics if the value type contains pointers.
func SliceOf[T any](x []T) Slice[T] {
if ev := reflect.TypeOf(x).Elem(); containsMutable(ev) {
panic(fmt.Sprintf("slice value type %q has pointers", ev.Name()))
}
return Slice[T]{x}
}
// MarshalJSON implements json.Marshaler.
func (v Slice[T]) MarshalJSON() ([]byte, error) {
@@ -186,8 +194,52 @@ func (v *IPPrefixSlice) UnmarshalJSON(b []byte) error {
return v.ж.UnmarshalJSON(b)
}
// MapOf returns a read-only view over m.
// containsMutable reports whether the provided type has anything mutable.
func containsMutable(t reflect.Type) bool {
switch x := fmt.Sprintf("%v.%v", t.PkgPath(), t.Name()); x {
case "time.Time",
"inet.af/netaddr.IP":
return false
}
k := t.Kind()
switch k {
case reflect.Bool,
reflect.Int,
reflect.Int8,
reflect.Int16,
reflect.Int32,
reflect.Int64,
reflect.Uint,
reflect.Uint8,
reflect.Uint16,
reflect.Uint32,
reflect.Uint64,
reflect.Float32,
reflect.Float64,
reflect.Complex64,
reflect.Complex128,
reflect.String:
return false
case reflect.Array: // Not a slice.
return containsMutable(t.Elem()) && t.Len() > 0
case reflect.Struct:
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if containsMutable(f.Type) {
return true
}
}
return false
}
return true
}
// MapOf returns a read-only view over m for immutable values.
// It panics if the value type contains pointers.
func MapOf[K comparable, V comparable](m map[K]V) Map[K, V] {
if ev := reflect.TypeOf(m).Elem(); containsMutable(ev) {
panic(fmt.Sprintf("map value type %q has pointers", ev.Name()))
}
return Map[K, V]{m}
}
@@ -226,10 +278,17 @@ func (m Map[K, V]) GetOk(k K) (V, bool) {
return v, ok
}
// ForEach calls f for every k,v pair in the underlying map.
func (m Map[K, V]) ForEach(f func(k K, v V)) {
// MapRangeFn is the func called from a Map.Range call.
// Implementations should return false to stop range.
type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool)
// Range calls f for every k,v pair in the underlying map.
// It stops iteration immediately if f returns false.
func (m Map[K, V]) Range(f MapRangeFn[K, V]) {
for k, v := range m.ж {
f(k, v)
if !f(k, v) {
return
}
}
}
@@ -278,9 +337,12 @@ func (m MapFn[K, T, V]) GetOk(k K) (V, bool) {
return m.wrapv(v), ok
}
// ForEach calls f for every k,v pair in the underlying map.
func (m MapFn[K, T, V]) ForEach(f func(k K, v V)) {
// Range calls f for every k,v pair in the underlying map.
// It stops iteration immediately if f returns false.
func (m MapFn[K, T, V]) Range(f MapRangeFn[K, V]) {
for k, v := range m.ж {
f(k, m.wrapv(v))
if !f(k, m.wrapv(v)) {
return
}
}
}