diff --git a/types/views/views.go b/types/views/views.go index eae8c0b16..4addc6448 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -513,6 +513,47 @@ func (m Map[K, V]) AsMap() map[K]V { return maps.Clone(m.ж) } +// NOTE: the type constraints for MapViewsEqual and MapViewsEqualFunc are based +// on those for maps.Equal and maps.EqualFunc. + +// MapViewsEqual returns whether the two given [Map]s are equal. Both K and V +// must be comparable; if V is non-comparable, use [MapViewsEqualFunc] instead. +func MapViewsEqual[K, V comparable](a, b Map[K, V]) bool { + if a.Len() != b.Len() || a.IsNil() != b.IsNil() { + return false + } + if a.IsNil() { + return true // both nil; can exit early + } + + for k, v := range a.All() { + bv, ok := b.GetOk(k) + if !ok || v != bv { + return false + } + } + return true +} + +// MapViewsEqualFunc returns whether the two given [Map]s are equal, using the +// given function to compare two values. +func MapViewsEqualFunc[K comparable, V1, V2 any](a Map[K, V1], b Map[K, V2], eq func(V1, V2) bool) bool { + if a.Len() != b.Len() || a.IsNil() != b.IsNil() { + return false + } + if a.IsNil() { + return true // both nil; can exit early + } + + for k, v := range a.All() { + bv, ok := b.GetOk(k) + if !ok || !eq(v, bv) { + return false + } + } + return true +} + // MapRangeFn is the func called from a Map.Range call. // Implementations should return false to stop range. type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool) diff --git a/types/views/views_test.go b/types/views/views_test.go index 8a1ff3fdd..51b086a4e 100644 --- a/types/views/views_test.go +++ b/types/views/views_test.go @@ -15,6 +15,7 @@ import ( "unsafe" qt "github.com/frankban/quicktest" + "tailscale.com/types/structs" ) type viewStruct struct { @@ -501,3 +502,87 @@ func TestMapFnIter(t *testing.T) { t.Errorf("got %q; want %q", got, want) } } + +func TestMapViewsEqual(t *testing.T) { + testCases := []struct { + name string + a, b map[string]string + want bool + }{ + { + name: "both_nil", + a: nil, + b: nil, + want: true, + }, + { + name: "both_empty", + a: map[string]string{}, + b: map[string]string{}, + want: true, + }, + { + name: "one_nil", + a: nil, + b: map[string]string{"a": "1"}, + want: false, + }, + { + name: "different_length", + a: map[string]string{"a": "1"}, + b: map[string]string{"a": "1", "b": "2"}, + want: false, + }, + { + name: "different_values", + a: map[string]string{"a": "1"}, + b: map[string]string{"a": "2"}, + want: false, + }, + { + name: "different_keys", + a: map[string]string{"a": "1"}, + b: map[string]string{"b": "1"}, + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := MapViewsEqual(MapOf(tc.a), MapOf(tc.b)) + if got != tc.want { + t.Errorf("MapViewsEqual: got=%v, want %v", got, tc.want) + } + + got = MapViewsEqualFunc(MapOf(tc.a), MapOf(tc.b), func(a, b string) bool { + return a == b + }) + if got != tc.want { + t.Errorf("MapViewsEqualFunc: got=%v, want %v", got, tc.want) + } + }) + } +} + +func TestMapViewsEqualFunc(t *testing.T) { + // Test that we can compare maps with two different non-comparable + // values using a custom comparison function. + type customStruct1 struct { + _ structs.Incomparable + Field1 string + } + type customStruct2 struct { + _ structs.Incomparable + Field2 string + } + + a := map[string]customStruct1{"a": {Field1: "1"}} + b := map[string]customStruct2{"a": {Field2: "1"}} + + got := MapViewsEqualFunc(MapOf(a), MapOf(b), func(a customStruct1, b customStruct2) bool { + return a.Field1 == b.Field2 + }) + if !got { + t.Errorf("MapViewsEqualFunc: got=%v, want true", got) + } +}