diff --git a/util/deephash/deephash.go b/util/deephash/deephash.go index 0795c2c07..d21ef7ca0 100644 --- a/util/deephash/deephash.go +++ b/util/deephash/deephash.go @@ -23,11 +23,13 @@ "crypto/sha256" "encoding/binary" "encoding/hex" + "fmt" "reflect" "sync" "time" "tailscale.com/util/hashx" + "tailscale.com/util/set" ) // There is much overlap between the theory of serialization and hashing. @@ -152,12 +154,90 @@ func Hash[T any](v *T) Sum { return h.sum() } +// Option is an optional argument to HasherForType. +type Option interface { + isOption() +} + +type fieldFilterOpt struct { + t reflect.Type + fields set.Set[string] + includeOnMatch bool // true to include fields, false to exclude them +} + +func (fieldFilterOpt) isOption() {} + +func (f fieldFilterOpt) filterStructField(sf reflect.StructField) (include bool) { + if f.fields.Contains(sf.Name) { + return f.includeOnMatch + } + return !f.includeOnMatch +} + +// IncludeFields returns an option that modifies the hashing for T to only +// include the named struct fields. +// +// T must be a struct type, and must match the type of the value passed to +// HasherForType. +func IncludeFields[T any](fields ...string) Option { + return newFieldFilter[T](true, fields) +} + +// ExcludeFields returns an option that modifies the hashing for T to include +// all struct fields of T except those provided in fields. +// +// T must be a struct type, and must match the type of the value passed to +// HasherForType. +func ExcludeFields[T any](fields ...string) Option { + return newFieldFilter[T](false, fields) +} + +func newFieldFilter[T any](include bool, fields []string) Option { + var zero T + t := reflect.TypeOf(&zero).Elem() + fieldSet := set.Set[string]{} + for _, f := range fields { + if _, ok := t.FieldByName(f); !ok { + panic(fmt.Sprintf("unknown field %q for type %v", f, t)) + } + fieldSet.Add(f) + } + return fieldFilterOpt{t, fieldSet, include} +} + // HasherForType returns a hash that is specialized for the provided type. -func HasherForType[T any]() func(*T) Sum { +// +// HasherForType panics if the opts are invalid for the provided type. +// +// Currently, at most one option can be provided (IncludeFields or +// ExcludeFields) and its type must match the type of T. Those restrictions may +// be removed in the future, along with documentation about their precedence +// when combined. +func HasherForType[T any](opts ...Option) func(*T) Sum { var v *T seedOnce.Do(initSeed) + if len(opts) > 1 { + panic("HasherForType only accepts one optional argument") // for now + } t := reflect.TypeOf(v).Elem() - hash := lookupTypeHasher(t) + var hash typeHasherFunc + for _, o := range opts { + switch o := o.(type) { + default: + panic(fmt.Sprintf("unknown HasherOpt %T", o)) + case fieldFilterOpt: + if t.Kind() != reflect.Struct { + panic("HasherForStructTypeWithFieldFilter requires T of kind struct") + } + if t != o.t { + panic(fmt.Sprintf("field filter for type %v does not match HasherForType type %v", o.t, t)) + } + hash = makeStructHasher(t, o.filterStructField) + } + } + if hash == nil { + hash = lookupTypeHasher(t) + } return func(v *T) (s Sum) { // This logic is identical to Hash, but pull out a few statements. h := hasherPool.Get().(*hasher) @@ -225,7 +305,7 @@ func makeTypeHasher(t reflect.Type) typeHasherFunc { case reflect.Slice: return makeSliceHasher(t) case reflect.Struct: - return makeStructHasher(t) + return makeStructHasher(t, keepAllStructFields) case reflect.Map: return makeMapHasher(t) case reflect.Pointer: @@ -353,9 +433,12 @@ func makeSliceHasher(t reflect.Type) typeHasherFunc { } } -func makeStructHasher(t reflect.Type) typeHasherFunc { +func keepAllStructFields(keepField reflect.StructField) bool { return true } + +func makeStructHasher(t reflect.Type, keepField func(reflect.StructField) bool) typeHasherFunc { type fieldHasher struct { - idx int // index of field for reflect.Type.Field(n); negative if memory is directly hashable + idx int // index of field for reflect.Type.Field(n); negative if memory is directly hashable + keep bool hash typeHasherFunc // only valid if idx is not negative offset uintptr size uintptr @@ -365,8 +448,8 @@ type fieldHasher struct { init := func() { for i, numField := 0, t.NumField(); i < numField; i++ { sf := t.Field(i) - f := fieldHasher{i, nil, sf.Offset, sf.Type.Size()} - if typeIsMemHashable(sf.Type) { + f := fieldHasher{i, keepField(sf), nil, sf.Offset, sf.Type.Size()} + if f.keep && typeIsMemHashable(sf.Type) { f.idx = -1 } @@ -390,6 +473,9 @@ type fieldHasher struct { return func(h *hasher, p pointer) { once.Do(init) for _, field := range fields { + if !field.keep { + continue + } pf := p.structField(field.idx, field.offset, field.size) if field.idx < 0 { h.HashBytes(pf.asMemory(field.size)) diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index f9e26995a..1da79a998 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -1066,6 +1066,51 @@ func TestAppendTo(t *testing.T) { } } +func TestFilterFields(t *testing.T) { + type T struct { + A int + B int + C int + } + + hashers := map[string]func(*T) Sum{ + "all": HasherForType[T](), + "ac": HasherForType[T](IncludeFields[T]("A", "C")), + "b": HasherForType[T](ExcludeFields[T]("A", "C")), + } + + tests := []struct { + hasher string + a, b T + wantEq bool + }{ + {"all", T{1, 2, 3}, T{1, 2, 3}, true}, + {"all", T{1, 2, 3}, T{0, 2, 3}, false}, + {"all", T{1, 2, 3}, T{1, 0, 3}, false}, + {"all", T{1, 2, 3}, T{1, 2, 0}, false}, + + {"ac", T{0, 0, 0}, T{0, 0, 0}, true}, + {"ac", T{1, 0, 1}, T{1, 1, 1}, true}, + {"ac", T{1, 1, 1}, T{1, 1, 0}, false}, + + {"b", T{0, 0, 0}, T{0, 0, 0}, true}, + {"b", T{1, 0, 1}, T{1, 1, 1}, false}, + {"b", T{1, 1, 1}, T{0, 1, 0}, true}, + } + for _, tt := range tests { + f, ok := hashers[tt.hasher] + if !ok { + t.Fatalf("bad test: unknown hasher %q", tt.hasher) + } + sum1 := f(&tt.a) + sum2 := f(&tt.b) + got := sum1 == sum2 + if got != tt.wantEq { + t.Errorf("hasher %q, for %+v and %v, got equal = %v; want %v", tt.hasher, tt.a, tt.b, got, tt.wantEq) + } + } +} + func BenchmarkAppendTo(b *testing.B) { b.ReportAllocs() v := getVal()