mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 16:17:41 +00:00
util/deephash: add IncludeFields, ExcludeFields HasherForType Options
Updates tailscale/corp#6198 Change-Id: Iafc18c5b947522cf07a42a56f35c0319cc7b1c94 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
e7d1538a2d
commit
4af22f3785
@ -23,11 +23,13 @@
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/hashx"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
// There is much overlap between the theory of serialization and hashing.
|
||||
@ -152,12 +154,90 @@ func Hash[T any](v *T) Sum {
|
||||
return h.sum()
|
||||
}
|
||||
|
||||
// Option is an optional argument to HasherForType.
|
||||
type Option interface {
|
||||
isOption()
|
||||
}
|
||||
|
||||
type fieldFilterOpt struct {
|
||||
t reflect.Type
|
||||
fields set.Set[string]
|
||||
includeOnMatch bool // true to include fields, false to exclude them
|
||||
}
|
||||
|
||||
func (fieldFilterOpt) isOption() {}
|
||||
|
||||
func (f fieldFilterOpt) filterStructField(sf reflect.StructField) (include bool) {
|
||||
if f.fields.Contains(sf.Name) {
|
||||
return f.includeOnMatch
|
||||
}
|
||||
return !f.includeOnMatch
|
||||
}
|
||||
|
||||
// IncludeFields returns an option that modifies the hashing for T to only
|
||||
// include the named struct fields.
|
||||
//
|
||||
// T must be a struct type, and must match the type of the value passed to
|
||||
// HasherForType.
|
||||
func IncludeFields[T any](fields ...string) Option {
|
||||
return newFieldFilter[T](true, fields)
|
||||
}
|
||||
|
||||
// ExcludeFields returns an option that modifies the hashing for T to include
|
||||
// all struct fields of T except those provided in fields.
|
||||
//
|
||||
// T must be a struct type, and must match the type of the value passed to
|
||||
// HasherForType.
|
||||
func ExcludeFields[T any](fields ...string) Option {
|
||||
return newFieldFilter[T](false, fields)
|
||||
}
|
||||
|
||||
func newFieldFilter[T any](include bool, fields []string) Option {
|
||||
var zero T
|
||||
t := reflect.TypeOf(&zero).Elem()
|
||||
fieldSet := set.Set[string]{}
|
||||
for _, f := range fields {
|
||||
if _, ok := t.FieldByName(f); !ok {
|
||||
panic(fmt.Sprintf("unknown field %q for type %v", f, t))
|
||||
}
|
||||
fieldSet.Add(f)
|
||||
}
|
||||
return fieldFilterOpt{t, fieldSet, include}
|
||||
}
|
||||
|
||||
// HasherForType returns a hash that is specialized for the provided type.
|
||||
func HasherForType[T any]() func(*T) Sum {
|
||||
//
|
||||
// HasherForType panics if the opts are invalid for the provided type.
|
||||
//
|
||||
// Currently, at most one option can be provided (IncludeFields or
|
||||
// ExcludeFields) and its type must match the type of T. Those restrictions may
|
||||
// be removed in the future, along with documentation about their precedence
|
||||
// when combined.
|
||||
func HasherForType[T any](opts ...Option) func(*T) Sum {
|
||||
var v *T
|
||||
seedOnce.Do(initSeed)
|
||||
if len(opts) > 1 {
|
||||
panic("HasherForType only accepts one optional argument") // for now
|
||||
}
|
||||
t := reflect.TypeOf(v).Elem()
|
||||
hash := lookupTypeHasher(t)
|
||||
var hash typeHasherFunc
|
||||
for _, o := range opts {
|
||||
switch o := o.(type) {
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown HasherOpt %T", o))
|
||||
case fieldFilterOpt:
|
||||
if t.Kind() != reflect.Struct {
|
||||
panic("HasherForStructTypeWithFieldFilter requires T of kind struct")
|
||||
}
|
||||
if t != o.t {
|
||||
panic(fmt.Sprintf("field filter for type %v does not match HasherForType type %v", o.t, t))
|
||||
}
|
||||
hash = makeStructHasher(t, o.filterStructField)
|
||||
}
|
||||
}
|
||||
if hash == nil {
|
||||
hash = lookupTypeHasher(t)
|
||||
}
|
||||
return func(v *T) (s Sum) {
|
||||
// This logic is identical to Hash, but pull out a few statements.
|
||||
h := hasherPool.Get().(*hasher)
|
||||
@ -225,7 +305,7 @@ func makeTypeHasher(t reflect.Type) typeHasherFunc {
|
||||
case reflect.Slice:
|
||||
return makeSliceHasher(t)
|
||||
case reflect.Struct:
|
||||
return makeStructHasher(t)
|
||||
return makeStructHasher(t, keepAllStructFields)
|
||||
case reflect.Map:
|
||||
return makeMapHasher(t)
|
||||
case reflect.Pointer:
|
||||
@ -353,9 +433,12 @@ func makeSliceHasher(t reflect.Type) typeHasherFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func makeStructHasher(t reflect.Type) typeHasherFunc {
|
||||
func keepAllStructFields(keepField reflect.StructField) bool { return true }
|
||||
|
||||
func makeStructHasher(t reflect.Type, keepField func(reflect.StructField) bool) typeHasherFunc {
|
||||
type fieldHasher struct {
|
||||
idx int // index of field for reflect.Type.Field(n); negative if memory is directly hashable
|
||||
keep bool
|
||||
hash typeHasherFunc // only valid if idx is not negative
|
||||
offset uintptr
|
||||
size uintptr
|
||||
@ -365,8 +448,8 @@ type fieldHasher struct {
|
||||
init := func() {
|
||||
for i, numField := 0, t.NumField(); i < numField; i++ {
|
||||
sf := t.Field(i)
|
||||
f := fieldHasher{i, nil, sf.Offset, sf.Type.Size()}
|
||||
if typeIsMemHashable(sf.Type) {
|
||||
f := fieldHasher{i, keepField(sf), nil, sf.Offset, sf.Type.Size()}
|
||||
if f.keep && typeIsMemHashable(sf.Type) {
|
||||
f.idx = -1
|
||||
}
|
||||
|
||||
@ -390,6 +473,9 @@ type fieldHasher struct {
|
||||
return func(h *hasher, p pointer) {
|
||||
once.Do(init)
|
||||
for _, field := range fields {
|
||||
if !field.keep {
|
||||
continue
|
||||
}
|
||||
pf := p.structField(field.idx, field.offset, field.size)
|
||||
if field.idx < 0 {
|
||||
h.HashBytes(pf.asMemory(field.size))
|
||||
|
@ -1066,6 +1066,51 @@ func TestAppendTo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterFields(t *testing.T) {
|
||||
type T struct {
|
||||
A int
|
||||
B int
|
||||
C int
|
||||
}
|
||||
|
||||
hashers := map[string]func(*T) Sum{
|
||||
"all": HasherForType[T](),
|
||||
"ac": HasherForType[T](IncludeFields[T]("A", "C")),
|
||||
"b": HasherForType[T](ExcludeFields[T]("A", "C")),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
hasher string
|
||||
a, b T
|
||||
wantEq bool
|
||||
}{
|
||||
{"all", T{1, 2, 3}, T{1, 2, 3}, true},
|
||||
{"all", T{1, 2, 3}, T{0, 2, 3}, false},
|
||||
{"all", T{1, 2, 3}, T{1, 0, 3}, false},
|
||||
{"all", T{1, 2, 3}, T{1, 2, 0}, false},
|
||||
|
||||
{"ac", T{0, 0, 0}, T{0, 0, 0}, true},
|
||||
{"ac", T{1, 0, 1}, T{1, 1, 1}, true},
|
||||
{"ac", T{1, 1, 1}, T{1, 1, 0}, false},
|
||||
|
||||
{"b", T{0, 0, 0}, T{0, 0, 0}, true},
|
||||
{"b", T{1, 0, 1}, T{1, 1, 1}, false},
|
||||
{"b", T{1, 1, 1}, T{0, 1, 0}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
f, ok := hashers[tt.hasher]
|
||||
if !ok {
|
||||
t.Fatalf("bad test: unknown hasher %q", tt.hasher)
|
||||
}
|
||||
sum1 := f(&tt.a)
|
||||
sum2 := f(&tt.b)
|
||||
got := sum1 == sum2
|
||||
if got != tt.wantEq {
|
||||
t.Errorf("hasher %q, for %+v and %v, got equal = %v; want %v", tt.hasher, tt.a, tt.b, got, tt.wantEq)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendTo(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
v := getVal()
|
||||
|
Loading…
x
Reference in New Issue
Block a user