types/views: add MapViewsEqual and MapViewsEqualFunc

Extracted from some code written in the other repo.

Updates tailscale/corp#25479

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I92c97a63a8f35cace6e89a730938ea587dcefd9b
This commit is contained in:
Andrew Dunham 2025-01-08 13:21:54 -05:00
parent 1d4fd2fb34
commit 9f17260e21
2 changed files with 126 additions and 0 deletions

View File

@ -513,6 +513,47 @@ func (m Map[K, V]) AsMap() map[K]V {
return maps.Clone(m.ж) return maps.Clone(m.ж)
} }
// NOTE: the type constraints for MapViewsEqual and MapViewsEqualFunc are based
// on those for maps.Equal and maps.EqualFunc.
// MapViewsEqual returns whether the two given [Map]s are equal. Both K and V
// must be comparable; if V is non-comparable, use [MapViewsEqualFunc] instead.
func MapViewsEqual[K, V comparable](a, b Map[K, V]) bool {
if a.Len() != b.Len() || a.IsNil() != b.IsNil() {
return false
}
if a.IsNil() {
return true // both nil; can exit early
}
for k, v := range a.All() {
bv, ok := b.GetOk(k)
if !ok || v != bv {
return false
}
}
return true
}
// MapViewsEqualFunc returns whether the two given [Map]s are equal, using the
// given function to compare two values.
func MapViewsEqualFunc[K comparable, V1, V2 any](a Map[K, V1], b Map[K, V2], eq func(V1, V2) bool) bool {
if a.Len() != b.Len() || a.IsNil() != b.IsNil() {
return false
}
if a.IsNil() {
return true // both nil; can exit early
}
for k, v := range a.All() {
bv, ok := b.GetOk(k)
if !ok || !eq(v, bv) {
return false
}
}
return true
}
// MapRangeFn is the func called from a Map.Range call. // MapRangeFn is the func called from a Map.Range call.
// Implementations should return false to stop range. // Implementations should return false to stop range.
type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool) type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool)

View File

@ -15,6 +15,7 @@ import (
"unsafe" "unsafe"
qt "github.com/frankban/quicktest" qt "github.com/frankban/quicktest"
"tailscale.com/types/structs"
) )
type viewStruct struct { type viewStruct struct {
@ -501,3 +502,87 @@ func TestMapFnIter(t *testing.T) {
t.Errorf("got %q; want %q", got, want) t.Errorf("got %q; want %q", got, want)
} }
} }
func TestMapViewsEqual(t *testing.T) {
testCases := []struct {
name string
a, b map[string]string
want bool
}{
{
name: "both_nil",
a: nil,
b: nil,
want: true,
},
{
name: "both_empty",
a: map[string]string{},
b: map[string]string{},
want: true,
},
{
name: "one_nil",
a: nil,
b: map[string]string{"a": "1"},
want: false,
},
{
name: "different_length",
a: map[string]string{"a": "1"},
b: map[string]string{"a": "1", "b": "2"},
want: false,
},
{
name: "different_values",
a: map[string]string{"a": "1"},
b: map[string]string{"a": "2"},
want: false,
},
{
name: "different_keys",
a: map[string]string{"a": "1"},
b: map[string]string{"b": "1"},
want: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := MapViewsEqual(MapOf(tc.a), MapOf(tc.b))
if got != tc.want {
t.Errorf("MapViewsEqual: got=%v, want %v", got, tc.want)
}
got = MapViewsEqualFunc(MapOf(tc.a), MapOf(tc.b), func(a, b string) bool {
return a == b
})
if got != tc.want {
t.Errorf("MapViewsEqualFunc: got=%v, want %v", got, tc.want)
}
})
}
}
func TestMapViewsEqualFunc(t *testing.T) {
// Test that we can compare maps with two different non-comparable
// values using a custom comparison function.
type customStruct1 struct {
_ structs.Incomparable
Field1 string
}
type customStruct2 struct {
_ structs.Incomparable
Field2 string
}
a := map[string]customStruct1{"a": {Field1: "1"}}
b := map[string]customStruct2{"a": {Field2: "1"}}
got := MapViewsEqualFunc(MapOf(a), MapOf(b), func(a customStruct1, b customStruct2) bool {
return a.Field1 == b.Field2
})
if !got {
t.Errorf("MapViewsEqualFunc: got=%v, want true", got)
}
}