From fc28c8e7f39d83e75dfd6009c789c0a9739ba9bd Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 8 Jul 2024 10:11:00 -0500 Subject: [PATCH] cmd/cloner, cmd/viewer, util/codegen: add support for generic types and interfaces This adds support for generic types and interfaces to our cloner and viewer codegens. It updates these packages to determine whether to make shallow or deep copies based on the type parameter constraints. Additionally, if a template parameter or an interface type has View() and Clone() methods, we'll use them for getters and the cloner of the owning structure. Updates #12736 Signed-off-by: Nick Khyl --- cmd/cloner/cloner.go | 87 +++++++++-- cmd/viewer/tests/tests.go | 56 +++++++- cmd/viewer/tests/tests_clone.go | 232 +++++++++++++++++++++++++++--- cmd/viewer/tests/tests_view.go | 246 ++++++++++++++++++++++++++++++-- cmd/viewer/viewer.go | 182 ++++++++++++++++------- ipn/ipn_clone.go | 31 +++- tailcfg/tailcfg_clone.go | 36 ++++- util/codegen/codegen.go | 107 ++++++++++++-- util/codegen/codegen_test.go | 176 +++++++++++++++++++++++ 9 files changed, 1039 insertions(+), 114 deletions(-) create mode 100644 util/codegen/codegen_test.go diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index 937fd9059..b4e940b2d 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -91,16 +91,19 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { } name := typ.Obj().Name() + typeParams := typ.Origin().TypeParams() + _, typeParamNames := codegen.FormatTypeParams(typeParams, it) + nameWithParams := name + typeParamNames fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name) fmt.Fprintf(buf, "// The result aliases no memory with the original.\n") - fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", name, name) + fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", nameWithParams, nameWithParams) writef := func(format string, args ...any) { fmt.Fprintf(buf, "\t"+format+"\n", args...) } writef("if src == nil {") writef("\treturn nil") writef("}") - writef("dst := new(%s)", name) + writef("dst := new(%s)", nameWithParams) writef("*dst = *src") for i := range t.NumFields() { fname := t.Field(i).Name() @@ -126,16 +129,23 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) writef("for i := range dst.%s {", fname) if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr { - if _, isBasic := ptr.Elem().Underlying().(*types.Basic); isBasic { - it.Import("tailscale.com/types/ptr") - writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname) - writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname) - writef("}") + writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname) + if codegen.ContainsPointers(ptr.Elem()) { + if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface { + it.Import("tailscale.com/types/ptr") + writef("\tdst.%s[i] = ptr.To((*src.%s[i]).Clone())", fname, fname) + } else { + writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) + } } else { - writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) + it.Import("tailscale.com/types/ptr") + writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname) } + writef("}") } else if ft.Elem().String() == "encoding/json.RawMessage" { writef("\tdst.%s[i] = append(src.%s[i][:0:0], src.%s[i]...)", fname, fname, fname) + } else if _, isIface := ft.Elem().Underlying().(*types.Interface); isIface { + writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) } else { writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) } @@ -145,14 +155,19 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname) } case *types.Pointer: - if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) { + base := ft.Elem() + hasPtrs := codegen.ContainsPointers(base) + if named, _ := base.(*types.Named); named != nil && hasPtrs { writef("dst.%s = src.%s.Clone()", fname, fname) continue } it.Import("tailscale.com/types/ptr") writef("if dst.%s != nil {", fname) - writef("\tdst.%s = ptr.To(*src.%s)", fname, fname) - if codegen.ContainsPointers(ft.Elem()) { + if _, isIface := base.Underlying().(*types.Interface); isIface && hasPtrs { + writef("\tdst.%s = ptr.To((*src.%s).Clone())", fname, fname) + } else if !hasPtrs { + writef("\tdst.%s = ptr.To(*src.%s)", fname, fname) + } else { writef("\t" + `panic("TODO pointers in pointers")`) } writef("}") @@ -172,18 +187,50 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("if dst.%s != nil {", fname) writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) writef("\tfor k, v := range src.%s {", fname) - switch elem.(type) { + + switch elem := elem.Underlying().(type) { case *types.Pointer: - writef("\t\tdst.%s[k] = v.Clone()", fname) + writef("\t\tif v == nil { dst.%s[k] = nil } else {", fname) + if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) { + if _, isIface := base.(*types.Interface); isIface { + it.Import("tailscale.com/types/ptr") + writef("\t\t\tdst.%s[k] = ptr.To((*v).Clone())", fname) + } else { + writef("\t\t\tdst.%s[k] = v.Clone()", fname) + } + } else { + it.Import("tailscale.com/types/ptr") + writef("\t\t\tdst.%s[k] = ptr.To(*v)", fname) + } + writef("}") + case *types.Interface: + if cloneResultType := methodResultType(elem, "Clone"); cloneResultType != nil { + if _, isPtr := cloneResultType.(*types.Pointer); isPtr { + writef("\t\tdst.%s[k] = *(v.Clone())", fname) + } else { + writef("\t\tdst.%s[k] = v.Clone()", fname) + } + } else { + writef(`panic("%s (%v) does not have a Clone method")`, fname, elem) + } default: writef("\t\tdst.%s[k] = *(v.Clone())", fname) } + writef("\t}") writef("}") } else { it.Import("maps") writef("\tdst.%s = maps.Clone(src.%s)", fname, fname) } + case *types.Interface: + // If ft is an interface with a "Clone() ft" method, it can be used to clone the field. + // This includes scenarios where ft is a constrained type parameter. + if cloneResultType := methodResultType(ft, "Clone"); cloneResultType.Underlying() == ft { + writef("dst.%s = src.%s.Clone()", fname, fname) + continue + } + writef(`panic("%s (%v) does not have a compatible Clone method")`, fname, ft) default: writef(`panic("TODO: %s (%T)")`, fname, ft) } @@ -191,7 +238,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("return dst") fmt.Fprintf(buf, "}\n\n") - buf.Write(codegen.AssertStructUnchanged(t, name, "Clone", it)) + buf.Write(codegen.AssertStructUnchanged(t, name, typeParams, "Clone", it)) } // hasBasicUnderlying reports true when typ.Underlying() is a slice or a map. @@ -203,3 +250,15 @@ func hasBasicUnderlying(typ types.Type) bool { return false } } + +func methodResultType(typ types.Type, method string) types.Type { + viewMethod := codegen.LookupMethod(typ, method) + if viewMethod == nil { + return nil + } + sig, ok := viewMethod.Type().(*types.Signature) + if !ok || sig.Results().Len() != 1 { + return nil + } + return sig.Results().At(0).Type() +} diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 55413403b..ed4d6914a 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -7,9 +7,12 @@ import ( "fmt" "net/netip" + + "golang.org/x/exp/constraints" + "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded --clone-only-type=OnlyGetClone +//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct --clone-only-type=OnlyGetClone type StructWithoutPtrs struct { Int int @@ -25,12 +28,12 @@ type Map struct { SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs StructWithoutPtrKey map[StructWithoutPtrs]int `json:"-"` + StructWithPtr map[string]StructWithPtrs // Unsupported views. SliceIntPtr map[string][]*int PointerKey map[*string]int `json:"-"` StructWithPtrKey map[StructWithPtrs]int `json:"-"` - StructWithPtr map[string]StructWithPtrs } type StructWithPtrs struct { @@ -50,12 +53,14 @@ type StructWithSlices struct { Values []StructWithoutPtrs ValuePointers []*StructWithoutPtrs StructPointers []*StructWithPtrs - Structs []StructWithPtrs - Ints []*int Slice []string Prefixes []netip.Prefix Data []byte + + // Unsupported views. + Structs []StructWithPtrs + Ints []*int } type OnlyGetClone struct { @@ -66,3 +71,46 @@ type StructWithEmbedded struct { A *StructWithPtrs StructWithSlices } + +type GenericIntStruct[T constraints.Integer] struct { + Value T + Pointer *T + Slice []T + Map map[string]T + + // Unsupported views. + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T +} + +type BasicType interface { + ~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string +} + +type GenericNoPtrsStruct[T StructWithoutPtrs | netip.Prefix | BasicType] struct { + Value T + Pointer *T + Slice []T + Map map[string]T + + // Unsupported views. + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T +} + +type GenericCloneableStruct[T views.ViewCloner[T, V], V views.StructView[T]] struct { + Value T + Slice []T + Map map[string]T + + // Unsupported views. + Pointer *T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T +} diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 2d8c1ba31..ec5631da9 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -9,7 +9,9 @@ "maps" "net/netip" + "golang.org/x/exp/constraints" "tailscale.com/types/ptr" + "tailscale.com/types/views" ) // Clone makes a deep copy of StructWithPtrs. @@ -71,13 +73,21 @@ func (src *Map) Clone() *Map { if dst.StructPtrWithPtr != nil { dst.StructPtrWithPtr = map[string]*StructWithPtrs{} for k, v := range src.StructPtrWithPtr { - dst.StructPtrWithPtr[k] = v.Clone() + if v == nil { + dst.StructPtrWithPtr[k] = nil + } else { + dst.StructPtrWithPtr[k] = v.Clone() + } } } if dst.StructPtrWithoutPtr != nil { dst.StructPtrWithoutPtr = map[string]*StructWithoutPtrs{} for k, v := range src.StructPtrWithoutPtr { - dst.StructPtrWithoutPtr[k] = v.Clone() + if v == nil { + dst.StructPtrWithoutPtr[k] = nil + } else { + dst.StructPtrWithoutPtr[k] = ptr.To(*v) + } } } dst.StructWithoutPtr = maps.Clone(src.StructWithoutPtr) @@ -94,6 +104,12 @@ func (src *Map) Clone() *Map { } } dst.StructWithoutPtrKey = maps.Clone(src.StructWithoutPtrKey) + if dst.StructWithPtr != nil { + dst.StructWithPtr = map[string]StructWithPtrs{} + for k, v := range src.StructWithPtr { + dst.StructWithPtr[k] = *(v.Clone()) + } + } if dst.SliceIntPtr != nil { dst.SliceIntPtr = map[string][]*int{} for k := range src.SliceIntPtr { @@ -102,12 +118,6 @@ func (src *Map) Clone() *Map { } dst.PointerKey = maps.Clone(src.PointerKey) dst.StructWithPtrKey = maps.Clone(src.StructWithPtrKey) - if dst.StructWithPtr != nil { - dst.StructWithPtr = map[string]StructWithPtrs{} - for k, v := range src.StructWithPtr { - dst.StructWithPtr[k] = *(v.Clone()) - } - } return dst } @@ -121,10 +131,10 @@ func (src *Map) Clone() *Map { SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs StructWithoutPtrKey map[StructWithoutPtrs]int + StructWithPtr map[string]StructWithPtrs SliceIntPtr map[string][]*int PointerKey map[*string]int StructWithPtrKey map[StructWithPtrs]int - StructWithPtr map[string]StructWithPtrs }{}) // Clone makes a deep copy of StructWithSlices. @@ -139,15 +149,26 @@ func (src *StructWithSlices) Clone() *StructWithSlices { if src.ValuePointers != nil { dst.ValuePointers = make([]*StructWithoutPtrs, len(src.ValuePointers)) for i := range dst.ValuePointers { - dst.ValuePointers[i] = src.ValuePointers[i].Clone() + if src.ValuePointers[i] == nil { + dst.ValuePointers[i] = nil + } else { + dst.ValuePointers[i] = ptr.To(*src.ValuePointers[i]) + } } } if src.StructPointers != nil { dst.StructPointers = make([]*StructWithPtrs, len(src.StructPointers)) for i := range dst.StructPointers { - dst.StructPointers[i] = src.StructPointers[i].Clone() + if src.StructPointers[i] == nil { + dst.StructPointers[i] = nil + } else { + dst.StructPointers[i] = src.StructPointers[i].Clone() + } } } + dst.Slice = append(src.Slice[:0:0], src.Slice...) + dst.Prefixes = append(src.Prefixes[:0:0], src.Prefixes...) + dst.Data = append(src.Data[:0:0], src.Data...) if src.Structs != nil { dst.Structs = make([]StructWithPtrs, len(src.Structs)) for i := range dst.Structs { @@ -164,9 +185,6 @@ func (src *StructWithSlices) Clone() *StructWithSlices { } } } - dst.Slice = append(src.Slice[:0:0], src.Slice...) - dst.Prefixes = append(src.Prefixes[:0:0], src.Prefixes...) - dst.Data = append(src.Data[:0:0], src.Data...) return dst } @@ -175,11 +193,11 @@ func (src *StructWithSlices) Clone() *StructWithSlices { Values []StructWithoutPtrs ValuePointers []*StructWithoutPtrs StructPointers []*StructWithPtrs - Structs []StructWithPtrs - Ints []*int Slice []string Prefixes []netip.Prefix Data []byte + Structs []StructWithPtrs + Ints []*int }{}) // Clone makes a deep copy of OnlyGetClone. @@ -216,3 +234,185 @@ func (src *StructWithEmbedded) Clone() *StructWithEmbedded { A *StructWithPtrs StructWithSlices }{}) + +// Clone makes a deep copy of GenericIntStruct. +// The result aliases no memory with the original. +func (src *GenericIntStruct[T]) Clone() *GenericIntStruct[T] { + if src == nil { + return nil + } + dst := new(GenericIntStruct[T]) + *dst = *src + if dst.Pointer != nil { + dst.Pointer = ptr.To(*src.Pointer) + } + dst.Slice = append(src.Slice[:0:0], src.Slice...) + dst.Map = maps.Clone(src.Map) + if src.PtrSlice != nil { + dst.PtrSlice = make([]*T, len(src.PtrSlice)) + for i := range dst.PtrSlice { + if src.PtrSlice[i] == nil { + dst.PtrSlice[i] = nil + } else { + dst.PtrSlice[i] = ptr.To(*src.PtrSlice[i]) + } + } + } + dst.PtrKeyMap = maps.Clone(src.PtrKeyMap) + if dst.PtrValueMap != nil { + dst.PtrValueMap = map[string]*T{} + for k, v := range src.PtrValueMap { + if v == nil { + dst.PtrValueMap[k] = nil + } else { + dst.PtrValueMap[k] = ptr.To(*v) + } + } + } + if dst.SliceMap != nil { + dst.SliceMap = map[string][]T{} + for k := range src.SliceMap { + dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...) + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericIntStructCloneNeedsRegeneration[T constraints.Integer](GenericIntStruct[T]) { + _GenericIntStructCloneNeedsRegeneration(struct { + Value T + Pointer *T + Slice []T + Map map[string]T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} + +// Clone makes a deep copy of GenericNoPtrsStruct. +// The result aliases no memory with the original. +func (src *GenericNoPtrsStruct[T]) Clone() *GenericNoPtrsStruct[T] { + if src == nil { + return nil + } + dst := new(GenericNoPtrsStruct[T]) + *dst = *src + if dst.Pointer != nil { + dst.Pointer = ptr.To(*src.Pointer) + } + dst.Slice = append(src.Slice[:0:0], src.Slice...) + dst.Map = maps.Clone(src.Map) + if src.PtrSlice != nil { + dst.PtrSlice = make([]*T, len(src.PtrSlice)) + for i := range dst.PtrSlice { + if src.PtrSlice[i] == nil { + dst.PtrSlice[i] = nil + } else { + dst.PtrSlice[i] = ptr.To(*src.PtrSlice[i]) + } + } + } + dst.PtrKeyMap = maps.Clone(src.PtrKeyMap) + if dst.PtrValueMap != nil { + dst.PtrValueMap = map[string]*T{} + for k, v := range src.PtrValueMap { + if v == nil { + dst.PtrValueMap[k] = nil + } else { + dst.PtrValueMap[k] = ptr.To(*v) + } + } + } + if dst.SliceMap != nil { + dst.SliceMap = map[string][]T{} + for k := range src.SliceMap { + dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...) + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericNoPtrsStructCloneNeedsRegeneration[T StructWithoutPtrs | netip.Prefix | BasicType](GenericNoPtrsStruct[T]) { + _GenericNoPtrsStructCloneNeedsRegeneration(struct { + Value T + Pointer *T + Slice []T + Map map[string]T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} + +// Clone makes a deep copy of GenericCloneableStruct. +// The result aliases no memory with the original. +func (src *GenericCloneableStruct[T, V]) Clone() *GenericCloneableStruct[T, V] { + if src == nil { + return nil + } + dst := new(GenericCloneableStruct[T, V]) + *dst = *src + dst.Value = src.Value.Clone() + if src.Slice != nil { + dst.Slice = make([]T, len(src.Slice)) + for i := range dst.Slice { + dst.Slice[i] = src.Slice[i].Clone() + } + } + if dst.Map != nil { + dst.Map = map[string]T{} + for k, v := range src.Map { + dst.Map[k] = v.Clone() + } + } + if dst.Pointer != nil { + dst.Pointer = ptr.To((*src.Pointer).Clone()) + } + if src.PtrSlice != nil { + dst.PtrSlice = make([]*T, len(src.PtrSlice)) + for i := range dst.PtrSlice { + if src.PtrSlice[i] == nil { + dst.PtrSlice[i] = nil + } else { + dst.PtrSlice[i] = ptr.To((*src.PtrSlice[i]).Clone()) + } + } + } + dst.PtrKeyMap = maps.Clone(src.PtrKeyMap) + if dst.PtrValueMap != nil { + dst.PtrValueMap = map[string]*T{} + for k, v := range src.PtrValueMap { + if v == nil { + dst.PtrValueMap[k] = nil + } else { + dst.PtrValueMap[k] = ptr.To((*v).Clone()) + } + } + } + if dst.SliceMap != nil { + dst.SliceMap = map[string][]T{} + for k := range src.SliceMap { + dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...) + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericCloneableStructCloneNeedsRegeneration[T views.ViewCloner[T, V], V views.StructView[T]](GenericCloneableStruct[T, V]) { + _GenericCloneableStructCloneNeedsRegeneration(struct { + Value T + Slice []T + Map map[string]T + Pointer *T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index f8966d20e..9a337f5aa 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -10,10 +10,11 @@ "errors" "net/netip" + "golang.org/x/exp/constraints" "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct // View returns a readonly view of StructWithPtrs. func (p *StructWithPtrs) View() StructWithPtrsView { @@ -221,15 +222,15 @@ func (v MapView) SlicesWithoutPtrs() views.MapFn[string, []*StructWithoutPtrs, v func (v MapView) StructWithoutPtrKey() views.Map[StructWithoutPtrs, int] { return views.MapOf(v.ж.StructWithoutPtrKey) } -func (v MapView) SliceIntPtr() map[string][]*int { panic("unsupported") } -func (v MapView) PointerKey() map[*string]int { panic("unsupported") } -func (v MapView) StructWithPtrKey() map[StructWithPtrs]int { panic("unsupported") } func (v MapView) StructWithPtr() views.MapFn[string, StructWithPtrs, StructWithPtrsView] { return views.MapFnOf(v.ж.StructWithPtr, func(t StructWithPtrs) StructWithPtrsView { return t.View() }) } +func (v MapView) SliceIntPtr() map[string][]*int { panic("unsupported") } +func (v MapView) PointerKey() map[*string]int { panic("unsupported") } +func (v MapView) StructWithPtrKey() map[StructWithPtrs]int { panic("unsupported") } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _MapViewNeedsRegeneration = Map(struct { @@ -241,10 +242,10 @@ func (v MapView) StructWithPtr() views.MapFn[string, StructWithPtrs, StructWithP SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs StructWithoutPtrKey map[StructWithoutPtrs]int + StructWithPtr map[string]StructWithPtrs SliceIntPtr map[string][]*int PointerKey map[*string]int StructWithPtrKey map[StructWithPtrs]int - StructWithPtr map[string]StructWithPtrs }{}) // View returns a readonly view of StructWithSlices. @@ -301,24 +302,24 @@ func (v StructWithSlicesView) ValuePointers() views.SliceView[*StructWithoutPtrs func (v StructWithSlicesView) StructPointers() views.SliceView[*StructWithPtrs, StructWithPtrsView] { return views.SliceOfViews[*StructWithPtrs, StructWithPtrsView](v.ж.StructPointers) } -func (v StructWithSlicesView) Structs() StructWithPtrs { panic("unsupported") } -func (v StructWithSlicesView) Ints() *int { panic("unsupported") } func (v StructWithSlicesView) Slice() views.Slice[string] { return views.SliceOf(v.ж.Slice) } func (v StructWithSlicesView) Prefixes() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.Prefixes) } func (v StructWithSlicesView) Data() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Data) } +func (v StructWithSlicesView) Structs() StructWithPtrs { panic("unsupported") } +func (v StructWithSlicesView) Ints() *int { panic("unsupported") } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _StructWithSlicesViewNeedsRegeneration = StructWithSlices(struct { Values []StructWithoutPtrs ValuePointers []*StructWithoutPtrs StructPointers []*StructWithPtrs - Structs []StructWithPtrs - Ints []*int Slice []string Prefixes []netip.Prefix Data []byte + Structs []StructWithPtrs + Ints []*int }{}) // View returns a readonly view of StructWithEmbedded. @@ -376,3 +377,230 @@ func (v StructWithEmbeddedView) StructWithSlices() StructWithSlicesView { A *StructWithPtrs StructWithSlices }{}) + +// View returns a readonly view of GenericIntStruct. +func (p *GenericIntStruct[T]) View() GenericIntStructView[T] { + return GenericIntStructView[T]{ж: p} +} + +// GenericIntStructView[T] provides a read-only view over GenericIntStruct[T]. +// +// Its methods should only be called if `Valid()` returns true. +type GenericIntStructView[T constraints.Integer] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *GenericIntStruct[T] +} + +// Valid reports whether underlying value is non-nil. +func (v GenericIntStructView[T]) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v GenericIntStructView[T]) AsStruct() *GenericIntStruct[T] { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v GenericIntStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *GenericIntStructView[T]) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x GenericIntStruct[T] + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v GenericIntStructView[T]) Value() T { return v.ж.Value } +func (v GenericIntStructView[T]) Pointer() *T { + if v.ж.Pointer == nil { + return nil + } + x := *v.ж.Pointer + return &x +} + +func (v GenericIntStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) } + +func (v GenericIntStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.ж.Map) } +func (v GenericIntStructView[T]) PtrSlice() *T { panic("unsupported") } +func (v GenericIntStructView[T]) PtrKeyMap() map[*T]string { panic("unsupported") } +func (v GenericIntStructView[T]) PtrValueMap() map[string]*T { panic("unsupported") } +func (v GenericIntStructView[T]) SliceMap() map[string][]T { panic("unsupported") } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericIntStructViewNeedsRegeneration[T constraints.Integer](GenericIntStruct[T]) { + _GenericIntStructViewNeedsRegeneration(struct { + Value T + Pointer *T + Slice []T + Map map[string]T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} + +// View returns a readonly view of GenericNoPtrsStruct. +func (p *GenericNoPtrsStruct[T]) View() GenericNoPtrsStructView[T] { + return GenericNoPtrsStructView[T]{ж: p} +} + +// GenericNoPtrsStructView[T] provides a read-only view over GenericNoPtrsStruct[T]. +// +// Its methods should only be called if `Valid()` returns true. +type GenericNoPtrsStructView[T StructWithoutPtrs | netip.Prefix | BasicType] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *GenericNoPtrsStruct[T] +} + +// Valid reports whether underlying value is non-nil. +func (v GenericNoPtrsStructView[T]) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v GenericNoPtrsStructView[T]) AsStruct() *GenericNoPtrsStruct[T] { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v GenericNoPtrsStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *GenericNoPtrsStructView[T]) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x GenericNoPtrsStruct[T] + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v GenericNoPtrsStructView[T]) Value() T { return v.ж.Value } +func (v GenericNoPtrsStructView[T]) Pointer() *T { + if v.ж.Pointer == nil { + return nil + } + x := *v.ж.Pointer + return &x +} + +func (v GenericNoPtrsStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) } + +func (v GenericNoPtrsStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.ж.Map) } +func (v GenericNoPtrsStructView[T]) PtrSlice() *T { panic("unsupported") } +func (v GenericNoPtrsStructView[T]) PtrKeyMap() map[*T]string { panic("unsupported") } +func (v GenericNoPtrsStructView[T]) PtrValueMap() map[string]*T { panic("unsupported") } +func (v GenericNoPtrsStructView[T]) SliceMap() map[string][]T { panic("unsupported") } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericNoPtrsStructViewNeedsRegeneration[T StructWithoutPtrs | netip.Prefix | BasicType](GenericNoPtrsStruct[T]) { + _GenericNoPtrsStructViewNeedsRegeneration(struct { + Value T + Pointer *T + Slice []T + Map map[string]T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} + +// View returns a readonly view of GenericCloneableStruct. +func (p *GenericCloneableStruct[T, V]) View() GenericCloneableStructView[T, V] { + return GenericCloneableStructView[T, V]{ж: p} +} + +// GenericCloneableStructView[T, V] provides a read-only view over GenericCloneableStruct[T, V]. +// +// Its methods should only be called if `Valid()` returns true. +type GenericCloneableStructView[T views.ViewCloner[T, V], V views.StructView[T]] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *GenericCloneableStruct[T, V] +} + +// Valid reports whether underlying value is non-nil. +func (v GenericCloneableStructView[T, V]) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v GenericCloneableStructView[T, V]) AsStruct() *GenericCloneableStruct[T, V] { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v GenericCloneableStructView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *GenericCloneableStructView[T, V]) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x GenericCloneableStruct[T, V] + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v GenericCloneableStructView[T, V]) Value() V { return v.ж.Value.View() } +func (v GenericCloneableStructView[T, V]) Slice() views.SliceView[T, V] { + return views.SliceOfViews[T, V](v.ж.Slice) +} + +func (v GenericCloneableStructView[T, V]) Map() views.MapFn[string, T, V] { + return views.MapFnOf(v.ж.Map, func(t T) V { + return t.View() + }) +} +func (v GenericCloneableStructView[T, V]) Pointer() map[string]T { panic("unsupported") } +func (v GenericCloneableStructView[T, V]) PtrSlice() *T { panic("unsupported") } +func (v GenericCloneableStructView[T, V]) PtrKeyMap() map[*T]string { panic("unsupported") } +func (v GenericCloneableStructView[T, V]) PtrValueMap() map[string]*T { panic("unsupported") } +func (v GenericCloneableStructView[T, V]) SliceMap() map[string][]T { panic("unsupported") } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericCloneableStructViewNeedsRegeneration[T views.ViewCloner[T, V], V views.StructView[T]](GenericCloneableStruct[T, V]) { + _GenericCloneableStructViewNeedsRegeneration(struct { + Value T + Slice []T + Map map[string]T + Pointer *T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index f89210268..557b6e459 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -20,43 +20,43 @@ const viewTemplateStr = `{{define "common"}} // View returns a readonly view of {{.StructName}}. -func (p *{{.StructName}}) View() {{.ViewName}} { - return {{.ViewName}}{ж: p} +func (p *{{.StructName}}{{.TypeParamNames}}) View() {{.ViewName}}{{.TypeParamNames}} { + return {{.ViewName}}{{.TypeParamNames}}{ж: p} } -// {{.ViewName}} provides a read-only view over {{.StructName}}. +// {{.ViewName}}{{.TypeParamNames}} provides a read-only view over {{.StructName}}{{.TypeParamNames}}. // // Its methods should only be called if ` + "`Valid()`" + ` returns true. -type {{.ViewName}} struct { +type {{.ViewName}}{{.TypeParams}} struct { // ж is the underlying mutable value, named with a hard-to-type // character that looks pointy like a pointer. // It is named distinctively to make you think of how dangerous it is to escape // to callers. You must not let callers be able to mutate it. - ж *{{.StructName}} + ж *{{.StructName}}{{.TypeParamNames}} } // Valid reports whether underlying value is non-nil. -func (v {{.ViewName}}) Valid() bool { return v.ж != nil } +func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.ж != nil } // AsStruct returns a clone of the underlying value which aliases no memory with // the original. -func (v {{.ViewName}}) AsStruct() *{{.StructName}}{ +func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{ if v.ж == nil { return nil } return v.ж.Clone() } -func (v {{.ViewName}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } +func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } -func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error { +func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { if v.ж != nil { return errors.New("already initialized") } if len(b) == 0 { return nil } - var x {{.StructName}} + var x {{.StructName}}{{.TypeParamNames}} if err := json.Unmarshal(b, &x); err != nil { return err } @@ -65,17 +65,17 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error { } {{end}} -{{define "valueField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} } +{{define "valueField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} } {{end}} -{{define "byteSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) } +{{define "byteSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) } {{end}} -{{define "sliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) } +{{define "sliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) } {{end}} -{{define "viewSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) } +{{define "viewSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) } {{end}} -{{define "viewField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}}View { return v.ж.{{.FieldName}}.View() } +{{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() } {{end}} -{{define "valuePointerField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} { +{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { if v.ж.{{.FieldName}} == nil { return nil } @@ -85,21 +85,21 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error { {{end}} {{define "mapField"}} -func(v {{.ViewName}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})} +func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})} {{end}} {{define "mapFnField"}} -func(v {{.ViewName}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} { +func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} { return {{.MapFn}} })} {{end}} {{define "mapSliceField"}} -func(v {{.ViewName}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) } +func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) } {{end}} -{{define "unsupportedField"}}func(v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")} +{{define "unsupportedField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")} {{end}} -{{define "stringFunc"}}func(v {{.ViewName}}) String() string { return v.ж.String() } +{{define "stringFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) String() string { return v.ж.String() } {{end}} -{{define "equalFunc"}}func(v {{.ViewName}}) Equal(v2 {{.ViewName}}) bool { return v.ж.Equal(v2.ж) } +{{define "equalFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) Equal(v2 {{.ViewName}}{{.TypeParamNames}}) bool { return v.ж.Equal(v2.ж) } {{end}} ` @@ -131,8 +131,11 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi it.Import("errors") args := struct { - StructName string - ViewName string + StructName string + ViewName string + TypeParams string // e.g. [T constraints.Integer] + TypeParamNames string // e.g. [T] + FieldName string FieldType string FieldViewName string @@ -143,9 +146,12 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi MapFn string }{ StructName: typ.Obj().Name(), - ViewName: typ.Obj().Name() + "View", + ViewName: typ.Origin().Obj().Name() + "View", } + typeParams := typ.Origin().TypeParams() + args.TypeParams, args.TypeParamNames = codegen.FormatTypeParams(typeParams, it) + writeTemplate := func(name string) { if err := viewTemplate.ExecuteTemplate(buf, name, args); err != nil { log.Fatal(err) @@ -182,19 +188,35 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi it.Import("tailscale.com/types/views") shallow, deep, base := requiresCloning(elem) if deep { - if _, isPtr := elem.(*types.Pointer); isPtr { - args.FieldViewName = it.QualifiedName(base) + "View" - writeTemplate("viewSliceField") - } else { - writeTemplate("unsupportedField") + switch elem.Underlying().(type) { + case *types.Pointer: + if _, isIface := base.Underlying().(*types.Interface); !isIface { + args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View") + writeTemplate("viewSliceField") + } else { + writeTemplate("unsupportedField") + } + continue + case *types.Interface: + if viewType := viewTypeForValueType(elem); viewType != nil { + args.FieldViewName = it.QualifiedName(viewType) + writeTemplate("viewSliceField") + continue + } } + writeTemplate("unsupportedField") continue } else if shallow { - if _, isBasic := base.(*types.Basic); isBasic { + switch base.Underlying().(type) { + case *types.Basic, *types.Interface: writeTemplate("unsupportedField") - } else { - args.FieldViewName = it.QualifiedName(base) + "View" - writeTemplate("viewSliceField") + default: + if _, isIface := base.Underlying().(*types.Interface); !isIface { + args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View") + writeTemplate("viewSliceField") + } else { + writeTemplate("unsupportedField") + } } continue } @@ -205,6 +227,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi strucT := underlying args.FieldType = it.QualifiedName(fieldType) if codegen.ContainsPointers(strucT) { + args.FieldViewName = appendNameSuffix(args.FieldType, "View") writeTemplate("viewField") continue } @@ -229,7 +252,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi args.MapFn = "t.View()" template = "mapFnField" args.MapValueType = it.QualifiedName(mElem) - args.MapValueView = args.MapValueType + "View" + args.MapValueView = appendNameSuffix(args.MapValueType, "View") } else { template = "mapField" args.MapValueType = it.QualifiedName(mElem) @@ -249,15 +272,20 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi case *types.Pointer: ptr := x pElem := ptr.Elem() - switch pElem.(type) { - case *types.Struct, *types.Named: - ptrType := it.QualifiedName(ptr) - viewType := it.QualifiedName(pElem) + "View" - args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType) - args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType) - args.MapValueType = "[]" + ptrType - template = "mapFnField" - default: + template = "unsupportedField" + if _, isIface := pElem.Underlying().(*types.Interface); !isIface { + switch pElem.(type) { + case *types.Struct, *types.Named: + ptrType := it.QualifiedName(ptr) + viewType := appendNameSuffix(it.QualifiedName(pElem), "View") + args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType) + args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType) + args.MapValueType = "[]" + ptrType + template = "mapFnField" + default: + template = "unsupportedField" + } + } else { template = "unsupportedField" } default: @@ -266,13 +294,29 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi case *types.Pointer: ptr := u pElem := ptr.Elem() - switch pElem.(type) { - case *types.Struct, *types.Named: - args.MapValueType = it.QualifiedName(ptr) - args.MapValueView = it.QualifiedName(pElem) + "View" + if _, isIface := pElem.Underlying().(*types.Interface); !isIface { + switch pElem.(type) { + case *types.Struct, *types.Named: + args.MapValueType = it.QualifiedName(ptr) + args.MapValueView = appendNameSuffix(it.QualifiedName(pElem), "View") + args.MapFn = "t.View()" + template = "mapFnField" + default: + template = "unsupportedField" + } + } else { + template = "unsupportedField" + } + case *types.Interface, *types.TypeParam: + if viewType := viewTypeForValueType(u); viewType != nil { + args.MapValueType = it.QualifiedName(u) + args.MapValueView = it.QualifiedName(viewType) args.MapFn = "t.View()" template = "mapFnField" - default: + } else if !codegen.ContainsPointers(u) { + args.MapValueType = it.QualifiedName(mElem) + template = "mapField" + } else { template = "unsupportedField" } default: @@ -283,14 +327,28 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi case *types.Pointer: ptr := underlying _, deep, base := requiresCloning(ptr) + if deep { - args.FieldType = it.QualifiedName(base) - writeTemplate("viewField") + if _, isIface := base.Underlying().(*types.Interface); !isIface { + args.FieldType = it.QualifiedName(base) + args.FieldViewName = appendNameSuffix(args.FieldType, "View") + writeTemplate("viewField") + } else { + writeTemplate("unsupportedField") + } } else { args.FieldType = it.QualifiedName(ptr) writeTemplate("valuePointerField") } continue + case *types.Interface: + // If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field. + // This includes scenarios where fieldType is a constrained type parameter. + if viewType := viewTypeForValueType(underlying); viewType != nil { + args.FieldViewName = it.QualifiedName(viewType) + writeTemplate("viewField") + continue + } } writeTemplate("unsupportedField") } @@ -318,7 +376,27 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi } } fmt.Fprintf(buf, "\n") - buf.Write(codegen.AssertStructUnchanged(t, args.StructName, "View", it)) + buf.Write(codegen.AssertStructUnchanged(t, args.StructName, typeParams, "View", it)) +} + +func appendNameSuffix(name, suffix string) string { + if idx := strings.IndexRune(name, '['); idx != -1 { + // Insert suffix after the type name, but before type parameters. + return name[:idx] + suffix + name[idx:] + } + return name + suffix +} + +func viewTypeForValueType(typ types.Type) types.Type { + viewMethod := codegen.LookupMethod(typ, "View") + if viewMethod == nil { + return nil + } + sig, ok := viewMethod.Type().(*types.Signature) + if !ok || sig.Results().Len() != 1 { + return nil + } + return sig.Results().At(0).Type() } var ( diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index 9457c50f0..de35b60a7 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -14,6 +14,7 @@ "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/types/preftype" + "tailscale.com/types/ptr" ) // Clone makes a deep copy of Prefs. @@ -29,7 +30,11 @@ func (src *Prefs) Clone() *Prefs { if src.DriveShares != nil { dst.DriveShares = make([]*drive.Share, len(src.DriveShares)) for i := range dst.DriveShares { - dst.DriveShares[i] = src.DriveShares[i].Clone() + if src.DriveShares[i] == nil { + dst.DriveShares[i] = nil + } else { + dst.DriveShares[i] = src.DriveShares[i].Clone() + } } } dst.Persist = src.Persist.Clone() @@ -81,20 +86,32 @@ func (src *ServeConfig) Clone() *ServeConfig { if dst.TCP != nil { dst.TCP = map[uint16]*TCPPortHandler{} for k, v := range src.TCP { - dst.TCP[k] = v.Clone() + if v == nil { + dst.TCP[k] = nil + } else { + dst.TCP[k] = ptr.To(*v) + } } } if dst.Web != nil { dst.Web = map[HostPort]*WebServerConfig{} for k, v := range src.Web { - dst.Web[k] = v.Clone() + if v == nil { + dst.Web[k] = nil + } else { + dst.Web[k] = v.Clone() + } } } dst.AllowFunnel = maps.Clone(src.AllowFunnel) if dst.Foreground != nil { dst.Foreground = map[string]*ServeConfig{} for k, v := range src.Foreground { - dst.Foreground[k] = v.Clone() + if v == nil { + dst.Foreground[k] = nil + } else { + dst.Foreground[k] = v.Clone() + } } } return dst @@ -157,7 +174,11 @@ func (src *WebServerConfig) Clone() *WebServerConfig { if dst.Handlers != nil { dst.Handlers = map[string]*HTTPHandler{} for k, v := range src.Handlers { - dst.Handlers[k] = v.Clone() + if v == nil { + dst.Handlers[k] = nil + } else { + dst.Handlers[k] = ptr.To(*v) + } } } return dst diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 823fe6810..a98efe4d1 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -77,7 +77,11 @@ func (src *Node) Clone() *Node { if src.ExitNodeDNSResolvers != nil { dst.ExitNodeDNSResolvers = make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers)) for i := range dst.ExitNodeDNSResolvers { - dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() + if src.ExitNodeDNSResolvers[i] == nil { + dst.ExitNodeDNSResolvers[i] = nil + } else { + dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() + } } } return dst @@ -244,7 +248,11 @@ func (src *DNSConfig) Clone() *DNSConfig { if src.Resolvers != nil { dst.Resolvers = make([]*dnstype.Resolver, len(src.Resolvers)) for i := range dst.Resolvers { - dst.Resolvers[i] = src.Resolvers[i].Clone() + if src.Resolvers[i] == nil { + dst.Resolvers[i] = nil + } else { + dst.Resolvers[i] = src.Resolvers[i].Clone() + } } } if dst.Routes != nil { @@ -256,7 +264,11 @@ func (src *DNSConfig) Clone() *DNSConfig { if src.FallbackResolvers != nil { dst.FallbackResolvers = make([]*dnstype.Resolver, len(src.FallbackResolvers)) for i := range dst.FallbackResolvers { - dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() + if src.FallbackResolvers[i] == nil { + dst.FallbackResolvers[i] = nil + } else { + dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() + } } } dst.Domains = append(src.Domains[:0:0], src.Domains...) @@ -393,7 +405,11 @@ func (src *DERPRegion) Clone() *DERPRegion { if src.Nodes != nil { dst.Nodes = make([]*DERPNode, len(src.Nodes)) for i := range dst.Nodes { - dst.Nodes[i] = src.Nodes[i].Clone() + if src.Nodes[i] == nil { + dst.Nodes[i] = nil + } else { + dst.Nodes[i] = ptr.To(*src.Nodes[i]) + } } } return dst @@ -422,7 +438,11 @@ func (src *DERPMap) Clone() *DERPMap { if dst.Regions != nil { dst.Regions = map[int]*DERPRegion{} for k, v := range src.Regions { - dst.Regions[k] = v.Clone() + if v == nil { + dst.Regions[k] = nil + } else { + dst.Regions[k] = v.Clone() + } } } return dst @@ -476,7 +496,11 @@ func (src *SSHRule) Clone() *SSHRule { if src.Principals != nil { dst.Principals = make([]*SSHPrincipal, len(src.Principals)) for i := range dst.Principals { - dst.Principals[i] = src.Principals[i].Clone() + if src.Principals[i] == nil { + dst.Principals[i] = nil + } else { + dst.Principals[i] = src.Principals[i].Clone() + } } } dst.SSHUsers = maps.Clone(src.SSHUsers) diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 6c2b3c71e..13dbc94a4 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -27,9 +27,9 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) { cfg := &packages.Config{ Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName, - Tests: false, + Tests: buildTags == "test", } - if buildTags != "" { + if buildTags != "" && !cfg.Tests { cfg.BuildFlags = []string{"-tags=" + buildTags} } @@ -37,6 +37,9 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string] if err != nil { return nil, nil, err } + if cfg.Tests { + pkgs = testPackages(pkgs) + } if len(pkgs) != 1 { return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs)) } @@ -44,6 +47,17 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string] return pkg, namedTypes(pkg), nil } +func testPackages(pkgs []*packages.Package) []*packages.Package { + var testPackages []*packages.Package + for _, pkg := range pkgs { + testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath) + if pkg.ID == testPackageID { + testPackages = append(testPackages, pkg) + } + } + return testPackages +} + // HasNoClone reports whether the provided tag has `codegen:noclone`. func HasNoClone(structTag string) bool { val := reflect.StructTag(structTag).Get("codegen") @@ -193,13 +207,21 @@ func namedTypes(pkg *packages.Package) map[string]*types.Named { // ctx is a single-word context for this assertion, such as "Clone". // If non-nil, AssertStructUnchanged will add elements to imports // for each package path that the caller must import for the returned code to compile. -func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker) []byte { +func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte { buf := new(bytes.Buffer) w := func(format string, args ...any) { fmt.Fprintf(buf, format+"\n", args...) } w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.") - w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname) + + hasTypeParams := params != nil && params.Len() > 0 + if hasTypeParams { + constraints, identifiers := FormatTypeParams(params, it) + w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers) + w("_%s%sNeedsRegeneration(struct {", tname, ctx) + } else { + w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname) + } for i := range t.NumFields() { st := t.Field(i) @@ -209,14 +231,25 @@ func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker continue } qname := it.QualifiedName(ft) + var tag string + if hasTypeParams { + tag = t.Tag(i) + if tag != "" { + tag = "`" + tag + "`" + } + } if st.Anonymous() { - w("\t%s ", fname) + w("\t%s %s", fname, tag) } else { - w("\t%s %s", fname, qname) + w("\t%s %s %s", fname, qname, tag) } } - w("}{})\n") + if hasTypeParams { + w("}{})\n}") + } else { + w("}{})") + } return buf.Bytes() } @@ -242,10 +275,21 @@ func ContainsPointers(typ types.Type) bool { switch ft := typ.Underlying().(type) { case *types.Array: return ContainsPointers(ft.Elem()) + case *types.Basic: + if ft.Kind() == types.UnsafePointer { + return true + } case *types.Chan: return true case *types.Interface: - return true // a little too broad + if ft.Empty() || ft.IsMethodSet() { + return true + } + for i := 0; i < ft.NumEmbeddeds(); i++ { + if ContainsPointers(ft.EmbeddedType(i)) { + return true + } + } case *types.Map: return true case *types.Pointer: @@ -258,6 +302,12 @@ func ContainsPointers(typ types.Type) bool { return true } } + case *types.Union: + for i := range ft.Len() { + if ContainsPointers(ft.Term(i).Type()) { + return true + } + } } return false } @@ -273,3 +323,44 @@ func IsViewType(typ types.Type) bool { } return t.Field(0).Name() == "ж" } + +// FormatTypeParams formats the specified params and returns two strings: +// - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer]) +// - names are comma-separated type parameter names in square brackets (e.g. [T, V]) +// +// If params is nil or empty, both return values are empty strings. +func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) { + if params == nil || params.Len() == 0 { + return "", "" + } + var constraintList, nameList []string + for i := range params.Len() { + param := params.At(i) + name := param.Obj().Name() + constraint := it.QualifiedName(param.Constraint()) + nameList = append(nameList, name) + constraintList = append(constraintList, name+" "+constraint) + } + constraints = "[" + strings.Join(constraintList, ", ") + "]" + names = "[" + strings.Join(nameList, ", ") + "]" + return constraints, names +} + +// LookupMethod returns the method with the specified name in t, or nil if the method does not exist. +func LookupMethod(t types.Type, name string) *types.Func { + if t, ok := t.(*types.Named); ok { + for i := 0; i < t.NumMethods(); i++ { + if method := t.Method(i); method.Name() == name { + return method + } + } + } + if t, ok := t.Underlying().(*types.Interface); ok { + for i := 0; i < t.NumMethods(); i++ { + if method := t.Method(i); method.Name() == name { + return method + } + } + } + return nil +} diff --git a/util/codegen/codegen_test.go b/util/codegen/codegen_test.go new file mode 100644 index 000000000..5f4a13979 --- /dev/null +++ b/util/codegen/codegen_test.go @@ -0,0 +1,176 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package codegen + +import ( + "log" + "net/netip" + "testing" + "unsafe" + + "golang.org/x/exp/constraints" +) + +type AnyParam[T any] struct { + V T +} + +type AnyParamPhantom[T any] struct { +} + +type IntegerParam[T constraints.Integer] struct { + V T +} + +type FloatParam[T constraints.Float] struct { + V T +} + +type StringLikeParam[T ~string] struct { + V T +} + +type BasicType interface { + ~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string +} + +type BasicTypeParam[T BasicType] struct { + V T +} + +type IntPtr *int + +type IntPtrParam[T IntPtr] struct { + V T +} + +type IntegerPtr interface { + *int | *int32 | *int64 +} + +type IntegerPtrParam[T IntegerPtr] struct { + V T +} + +type IntegerParamPtr[T constraints.Integer] struct { + V *T +} + +type IntegerSliceParam[T constraints.Integer] struct { + V []T +} + +type IntegerMapParam[T constraints.Integer] struct { + V []T +} + +type UnsafePointerParam[T unsafe.Pointer] struct { + V T +} + +type ValueUnionParam[T netip.Prefix | BasicType] struct { + V T +} + +type ValueUnionParamPtr[T netip.Prefix | BasicType] struct { + V *T +} + +type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct { + V T +} + +type Interface interface { + Method() +} + +type InterfaceParam[T Interface] struct { + V T +} + +func TestGenericContainsPointers(t *testing.T) { + tests := []struct { + typ string + wantPointer bool + }{ + { + typ: "AnyParam", + wantPointer: true, + }, + { + typ: "AnyParamPhantom", + wantPointer: false, // has a pointer type parameter, but no pointer fields + }, + { + typ: "IntegerParam", + wantPointer: false, + }, + { + typ: "FloatParam", + wantPointer: false, + }, + { + typ: "StringLikeParam", + wantPointer: false, + }, + { + typ: "BasicTypeParam", + wantPointer: false, + }, + { + typ: "IntPtrParam", + wantPointer: true, + }, + { + typ: "IntegerPtrParam", + wantPointer: true, + }, + { + typ: "IntegerParamPtr", + wantPointer: true, + }, + { + typ: "IntegerSliceParam", + wantPointer: true, + }, + { + typ: "IntegerMapParam", + wantPointer: true, + }, + { + typ: "UnsafePointerParam", + wantPointer: true, + }, + { + typ: "InterfaceParam", + wantPointer: true, + }, + { + typ: "ValueUnionParam", + wantPointer: false, + }, + { + typ: "ValueUnionParamPtr", + wantPointer: true, + }, + { + typ: "PointerUnionParam", + wantPointer: true, + }, + } + + _, namedTypes, err := LoadTypes("test", ".") + if err != nil { + log.Fatal(err) + } + + for _, tt := range tests { + t.Run(tt.typ, func(t *testing.T) { + typ := namedTypes[tt.typ] + if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer { + t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer) + } + }) + } +}