mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-11-04 09:25:11 +00:00 
			
		
		
		
	types/views: make SliceOf/MapOf panic if they see a pointer
Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
		@@ -9,6 +9,8 @@ package views
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
 | 
			
		||||
	"inet.af/netaddr"
 | 
			
		||||
	"tailscale.com/net/tsaddr"
 | 
			
		||||
@@ -97,8 +99,14 @@ type Slice[T any] struct {
 | 
			
		||||
	ж []T
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SliceOf returns a Slice for the provided slice.
 | 
			
		||||
func SliceOf[T any](x []T) Slice[T] { return Slice[T]{x} }
 | 
			
		||||
// SliceOf returns a Slice for the provided slice for immutable values.
 | 
			
		||||
// It panics if the value type contains pointers.
 | 
			
		||||
func SliceOf[T any](x []T) Slice[T] {
 | 
			
		||||
	if ev := reflect.TypeOf(x).Elem(); containsMutable(ev) {
 | 
			
		||||
		panic(fmt.Sprintf("slice value type %q has pointers", ev.Name()))
 | 
			
		||||
	}
 | 
			
		||||
	return Slice[T]{x}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MarshalJSON implements json.Marshaler.
 | 
			
		||||
func (v Slice[T]) MarshalJSON() ([]byte, error) {
 | 
			
		||||
@@ -186,8 +194,52 @@ func (v *IPPrefixSlice) UnmarshalJSON(b []byte) error {
 | 
			
		||||
	return v.ж.UnmarshalJSON(b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MapOf returns a read-only view over m.
 | 
			
		||||
// containsMutable reports whether the provided type has anything mutable.
 | 
			
		||||
func containsMutable(t reflect.Type) bool {
 | 
			
		||||
	switch x := fmt.Sprintf("%v.%v", t.PkgPath(), t.Name()); x {
 | 
			
		||||
	case "time.Time",
 | 
			
		||||
		"inet.af/netaddr.IP":
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	k := t.Kind()
 | 
			
		||||
	switch k {
 | 
			
		||||
	case reflect.Bool,
 | 
			
		||||
		reflect.Int,
 | 
			
		||||
		reflect.Int8,
 | 
			
		||||
		reflect.Int16,
 | 
			
		||||
		reflect.Int32,
 | 
			
		||||
		reflect.Int64,
 | 
			
		||||
		reflect.Uint,
 | 
			
		||||
		reflect.Uint8,
 | 
			
		||||
		reflect.Uint16,
 | 
			
		||||
		reflect.Uint32,
 | 
			
		||||
		reflect.Uint64,
 | 
			
		||||
		reflect.Float32,
 | 
			
		||||
		reflect.Float64,
 | 
			
		||||
		reflect.Complex64,
 | 
			
		||||
		reflect.Complex128,
 | 
			
		||||
		reflect.String:
 | 
			
		||||
		return false
 | 
			
		||||
	case reflect.Array: // Not a slice.
 | 
			
		||||
		return containsMutable(t.Elem()) && t.Len() > 0
 | 
			
		||||
	case reflect.Struct:
 | 
			
		||||
		for i := 0; i < t.NumField(); i++ {
 | 
			
		||||
			f := t.Field(i)
 | 
			
		||||
			if containsMutable(f.Type) {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MapOf returns a read-only view over m for immutable values.
 | 
			
		||||
// It panics if the value type contains pointers.
 | 
			
		||||
func MapOf[K comparable, V comparable](m map[K]V) Map[K, V] {
 | 
			
		||||
	if ev := reflect.TypeOf(m).Elem(); containsMutable(ev) {
 | 
			
		||||
		panic(fmt.Sprintf("map value type %q has pointers", ev.Name()))
 | 
			
		||||
	}
 | 
			
		||||
	return Map[K, V]{m}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -226,10 +278,17 @@ func (m Map[K, V]) GetOk(k K) (V, bool) {
 | 
			
		||||
	return v, ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ForEach calls f for every k,v pair in the underlying map.
 | 
			
		||||
func (m Map[K, V]) ForEach(f func(k K, v V)) {
 | 
			
		||||
// MapRangeFn is the func called from a Map.Range call.
 | 
			
		||||
// Implementations should return false to stop range.
 | 
			
		||||
type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool)
 | 
			
		||||
 | 
			
		||||
// Range calls f for every k,v pair in the underlying map.
 | 
			
		||||
// It stops iteration immediately if f returns false.
 | 
			
		||||
func (m Map[K, V]) Range(f MapRangeFn[K, V]) {
 | 
			
		||||
	for k, v := range m.ж {
 | 
			
		||||
		f(k, v)
 | 
			
		||||
		if !f(k, v) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -278,9 +337,12 @@ func (m MapFn[K, T, V]) GetOk(k K) (V, bool) {
 | 
			
		||||
	return m.wrapv(v), ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ForEach calls f for every k,v pair in the underlying map.
 | 
			
		||||
func (m MapFn[K, T, V]) ForEach(f func(k K, v V)) {
 | 
			
		||||
// Range calls f for every k,v pair in the underlying map.
 | 
			
		||||
// It stops iteration immediately if f returns false.
 | 
			
		||||
func (m MapFn[K, T, V]) Range(f MapRangeFn[K, V]) {
 | 
			
		||||
	for k, v := range m.ж {
 | 
			
		||||
		f(k, m.wrapv(v))
 | 
			
		||||
		if !f(k, m.wrapv(v)) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -10,10 +10,43 @@ import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go4.org/mem"
 | 
			
		||||
	"inet.af/netaddr"
 | 
			
		||||
	"tailscale.com/types/structs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestContainsPointers(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name string
 | 
			
		||||
		in   any
 | 
			
		||||
		want bool
 | 
			
		||||
	}{
 | 
			
		||||
		{name: "string", in: "foo", want: false},
 | 
			
		||||
		{name: "int", in: 42, want: false},
 | 
			
		||||
		{name: "struct", in: struct{ string }{"foo"}, want: false},
 | 
			
		||||
		{name: "mem.RO", in: mem.B([]byte{1}), want: false},
 | 
			
		||||
		{name: "time.Time", in: time.Now(), want: false},
 | 
			
		||||
		{name: "netaddr.IP", in: netaddr.MustParseIP("1.1.1.1"), want: false},
 | 
			
		||||
		{name: "netaddr.IPPrefix", in: netaddr.MustParseIP("1.1.1.1"), want: false},
 | 
			
		||||
		{name: "structs.Incomparable", in: structs.Incomparable{}, want: false},
 | 
			
		||||
 | 
			
		||||
		{name: "*int", in: (*int)(nil), want: true},
 | 
			
		||||
		{name: "*string", in: (*string)(nil), want: true},
 | 
			
		||||
		{name: "struct-with-pointer", in: struct{ X *string }{}, want: true},
 | 
			
		||||
		{name: "slice-with-pointer", in: []struct{ X *string }{}, want: true},
 | 
			
		||||
		{name: "slice-of-struct", in: []struct{ string }{}, want: true},
 | 
			
		||||
	}
 | 
			
		||||
	for _, tc := range tests {
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			if containsMutable(reflect.TypeOf(tc.in)) != tc.want {
 | 
			
		||||
				t.Errorf("containsPointers %T; want %v", tc.in, tc.want)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestViewsJSON(t *testing.T) {
 | 
			
		||||
	mustCIDR := func(cidrs ...string) (out []netaddr.IPPrefix) {
 | 
			
		||||
		for _, cidr := range cidrs {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user