diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index 937fd9059..b4e940b2d 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -91,16 +91,19 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { } name := typ.Obj().Name() + typeParams := typ.Origin().TypeParams() + _, typeParamNames := codegen.FormatTypeParams(typeParams, it) + nameWithParams := name + typeParamNames fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name) fmt.Fprintf(buf, "// The result aliases no memory with the original.\n") - fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", name, name) + fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", nameWithParams, nameWithParams) writef := func(format string, args ...any) { fmt.Fprintf(buf, "\t"+format+"\n", args...) } writef("if src == nil {") writef("\treturn nil") writef("}") - writef("dst := new(%s)", name) + writef("dst := new(%s)", nameWithParams) writef("*dst = *src") for i := range t.NumFields() { fname := t.Field(i).Name() @@ -126,16 +129,23 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) writef("for i := range dst.%s {", fname) if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr { - if _, isBasic := ptr.Elem().Underlying().(*types.Basic); isBasic { - it.Import("tailscale.com/types/ptr") - writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname) - writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname) - writef("}") + writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname) + if codegen.ContainsPointers(ptr.Elem()) { + if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface { + it.Import("tailscale.com/types/ptr") + writef("\tdst.%s[i] = ptr.To((*src.%s[i]).Clone())", fname, fname) + } else { + writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) + } } else { - writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) + it.Import("tailscale.com/types/ptr") + writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname) } + writef("}") } else if ft.Elem().String() == "encoding/json.RawMessage" { writef("\tdst.%s[i] = append(src.%s[i][:0:0], src.%s[i]...)", fname, fname, fname) + } else if _, isIface := ft.Elem().Underlying().(*types.Interface); isIface { + writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) } else { writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) } @@ -145,14 +155,19 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname) } case *types.Pointer: - if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) { + base := ft.Elem() + hasPtrs := codegen.ContainsPointers(base) + if named, _ := base.(*types.Named); named != nil && hasPtrs { writef("dst.%s = src.%s.Clone()", fname, fname) continue } it.Import("tailscale.com/types/ptr") writef("if dst.%s != nil {", fname) - writef("\tdst.%s = ptr.To(*src.%s)", fname, fname) - if codegen.ContainsPointers(ft.Elem()) { + if _, isIface := base.Underlying().(*types.Interface); isIface && hasPtrs { + writef("\tdst.%s = ptr.To((*src.%s).Clone())", fname, fname) + } else if !hasPtrs { + writef("\tdst.%s = ptr.To(*src.%s)", fname, fname) + } else { writef("\t" + `panic("TODO pointers in pointers")`) } writef("}") @@ -172,18 +187,50 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("if dst.%s != nil {", fname) writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) writef("\tfor k, v := range src.%s {", fname) - switch elem.(type) { + + switch elem := elem.Underlying().(type) { case *types.Pointer: - writef("\t\tdst.%s[k] = v.Clone()", fname) + writef("\t\tif v == nil { dst.%s[k] = nil } else {", fname) + if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) { + if _, isIface := base.(*types.Interface); isIface { + it.Import("tailscale.com/types/ptr") + writef("\t\t\tdst.%s[k] = ptr.To((*v).Clone())", fname) + } else { + writef("\t\t\tdst.%s[k] = v.Clone()", fname) + } + } else { + it.Import("tailscale.com/types/ptr") + writef("\t\t\tdst.%s[k] = ptr.To(*v)", fname) + } + writef("}") + case *types.Interface: + if cloneResultType := methodResultType(elem, "Clone"); cloneResultType != nil { + if _, isPtr := cloneResultType.(*types.Pointer); isPtr { + writef("\t\tdst.%s[k] = *(v.Clone())", fname) + } else { + writef("\t\tdst.%s[k] = v.Clone()", fname) + } + } else { + writef(`panic("%s (%v) does not have a Clone method")`, fname, elem) + } default: writef("\t\tdst.%s[k] = *(v.Clone())", fname) } + writef("\t}") writef("}") } else { it.Import("maps") writef("\tdst.%s = maps.Clone(src.%s)", fname, fname) } + case *types.Interface: + // If ft is an interface with a "Clone() ft" method, it can be used to clone the field. + // This includes scenarios where ft is a constrained type parameter. + if cloneResultType := methodResultType(ft, "Clone"); cloneResultType.Underlying() == ft { + writef("dst.%s = src.%s.Clone()", fname, fname) + continue + } + writef(`panic("%s (%v) does not have a compatible Clone method")`, fname, ft) default: writef(`panic("TODO: %s (%T)")`, fname, ft) } @@ -191,7 +238,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("return dst") fmt.Fprintf(buf, "}\n\n") - buf.Write(codegen.AssertStructUnchanged(t, name, "Clone", it)) + buf.Write(codegen.AssertStructUnchanged(t, name, typeParams, "Clone", it)) } // hasBasicUnderlying reports true when typ.Underlying() is a slice or a map. @@ -203,3 +250,15 @@ func hasBasicUnderlying(typ types.Type) bool { return false } } + +func methodResultType(typ types.Type, method string) types.Type { + viewMethod := codegen.LookupMethod(typ, method) + if viewMethod == nil { + return nil + } + sig, ok := viewMethod.Type().(*types.Signature) + if !ok || sig.Results().Len() != 1 { + return nil + } + return sig.Results().At(0).Type() +} diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 55413403b..ed4d6914a 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -7,9 +7,12 @@ import ( "fmt" "net/netip" + + "golang.org/x/exp/constraints" + "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded --clone-only-type=OnlyGetClone +//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct --clone-only-type=OnlyGetClone type StructWithoutPtrs struct { Int int @@ -25,12 +28,12 @@ type Map struct { SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs StructWithoutPtrKey map[StructWithoutPtrs]int `json:"-"` + StructWithPtr map[string]StructWithPtrs // Unsupported views. SliceIntPtr map[string][]*int PointerKey map[*string]int `json:"-"` StructWithPtrKey map[StructWithPtrs]int `json:"-"` - StructWithPtr map[string]StructWithPtrs } type StructWithPtrs struct { @@ -50,12 +53,14 @@ type StructWithSlices struct { Values []StructWithoutPtrs ValuePointers []*StructWithoutPtrs StructPointers []*StructWithPtrs - Structs []StructWithPtrs - Ints []*int Slice []string Prefixes []netip.Prefix Data []byte + + // Unsupported views. + Structs []StructWithPtrs + Ints []*int } type OnlyGetClone struct { @@ -66,3 +71,46 @@ type StructWithEmbedded struct { A *StructWithPtrs StructWithSlices } + +type GenericIntStruct[T constraints.Integer] struct { + Value T + Pointer *T + Slice []T + Map map[string]T + + // Unsupported views. + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T +} + +type BasicType interface { + ~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string +} + +type GenericNoPtrsStruct[T StructWithoutPtrs | netip.Prefix | BasicType] struct { + Value T + Pointer *T + Slice []T + Map map[string]T + + // Unsupported views. + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T +} + +type GenericCloneableStruct[T views.ViewCloner[T, V], V views.StructView[T]] struct { + Value T + Slice []T + Map map[string]T + + // Unsupported views. + Pointer *T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T +} diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 2d8c1ba31..ec5631da9 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -9,7 +9,9 @@ "maps" "net/netip" + "golang.org/x/exp/constraints" "tailscale.com/types/ptr" + "tailscale.com/types/views" ) // Clone makes a deep copy of StructWithPtrs. @@ -71,13 +73,21 @@ func (src *Map) Clone() *Map { if dst.StructPtrWithPtr != nil { dst.StructPtrWithPtr = map[string]*StructWithPtrs{} for k, v := range src.StructPtrWithPtr { - dst.StructPtrWithPtr[k] = v.Clone() + if v == nil { + dst.StructPtrWithPtr[k] = nil + } else { + dst.StructPtrWithPtr[k] = v.Clone() + } } } if dst.StructPtrWithoutPtr != nil { dst.StructPtrWithoutPtr = map[string]*StructWithoutPtrs{} for k, v := range src.StructPtrWithoutPtr { - dst.StructPtrWithoutPtr[k] = v.Clone() + if v == nil { + dst.StructPtrWithoutPtr[k] = nil + } else { + dst.StructPtrWithoutPtr[k] = ptr.To(*v) + } } } dst.StructWithoutPtr = maps.Clone(src.StructWithoutPtr) @@ -94,6 +104,12 @@ func (src *Map) Clone() *Map { } } dst.StructWithoutPtrKey = maps.Clone(src.StructWithoutPtrKey) + if dst.StructWithPtr != nil { + dst.StructWithPtr = map[string]StructWithPtrs{} + for k, v := range src.StructWithPtr { + dst.StructWithPtr[k] = *(v.Clone()) + } + } if dst.SliceIntPtr != nil { dst.SliceIntPtr = map[string][]*int{} for k := range src.SliceIntPtr { @@ -102,12 +118,6 @@ func (src *Map) Clone() *Map { } dst.PointerKey = maps.Clone(src.PointerKey) dst.StructWithPtrKey = maps.Clone(src.StructWithPtrKey) - if dst.StructWithPtr != nil { - dst.StructWithPtr = map[string]StructWithPtrs{} - for k, v := range src.StructWithPtr { - dst.StructWithPtr[k] = *(v.Clone()) - } - } return dst } @@ -121,10 +131,10 @@ func (src *Map) Clone() *Map { SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs StructWithoutPtrKey map[StructWithoutPtrs]int + StructWithPtr map[string]StructWithPtrs SliceIntPtr map[string][]*int PointerKey map[*string]int StructWithPtrKey map[StructWithPtrs]int - StructWithPtr map[string]StructWithPtrs }{}) // Clone makes a deep copy of StructWithSlices. @@ -139,15 +149,26 @@ func (src *StructWithSlices) Clone() *StructWithSlices { if src.ValuePointers != nil { dst.ValuePointers = make([]*StructWithoutPtrs, len(src.ValuePointers)) for i := range dst.ValuePointers { - dst.ValuePointers[i] = src.ValuePointers[i].Clone() + if src.ValuePointers[i] == nil { + dst.ValuePointers[i] = nil + } else { + dst.ValuePointers[i] = ptr.To(*src.ValuePointers[i]) + } } } if src.StructPointers != nil { dst.StructPointers = make([]*StructWithPtrs, len(src.StructPointers)) for i := range dst.StructPointers { - dst.StructPointers[i] = src.StructPointers[i].Clone() + if src.StructPointers[i] == nil { + dst.StructPointers[i] = nil + } else { + dst.StructPointers[i] = src.StructPointers[i].Clone() + } } } + dst.Slice = append(src.Slice[:0:0], src.Slice...) + dst.Prefixes = append(src.Prefixes[:0:0], src.Prefixes...) + dst.Data = append(src.Data[:0:0], src.Data...) if src.Structs != nil { dst.Structs = make([]StructWithPtrs, len(src.Structs)) for i := range dst.Structs { @@ -164,9 +185,6 @@ func (src *StructWithSlices) Clone() *StructWithSlices { } } } - dst.Slice = append(src.Slice[:0:0], src.Slice...) - dst.Prefixes = append(src.Prefixes[:0:0], src.Prefixes...) - dst.Data = append(src.Data[:0:0], src.Data...) return dst } @@ -175,11 +193,11 @@ func (src *StructWithSlices) Clone() *StructWithSlices { Values []StructWithoutPtrs ValuePointers []*StructWithoutPtrs StructPointers []*StructWithPtrs - Structs []StructWithPtrs - Ints []*int Slice []string Prefixes []netip.Prefix Data []byte + Structs []StructWithPtrs + Ints []*int }{}) // Clone makes a deep copy of OnlyGetClone. @@ -216,3 +234,185 @@ func (src *StructWithEmbedded) Clone() *StructWithEmbedded { A *StructWithPtrs StructWithSlices }{}) + +// Clone makes a deep copy of GenericIntStruct. +// The result aliases no memory with the original. +func (src *GenericIntStruct[T]) Clone() *GenericIntStruct[T] { + if src == nil { + return nil + } + dst := new(GenericIntStruct[T]) + *dst = *src + if dst.Pointer != nil { + dst.Pointer = ptr.To(*src.Pointer) + } + dst.Slice = append(src.Slice[:0:0], src.Slice...) + dst.Map = maps.Clone(src.Map) + if src.PtrSlice != nil { + dst.PtrSlice = make([]*T, len(src.PtrSlice)) + for i := range dst.PtrSlice { + if src.PtrSlice[i] == nil { + dst.PtrSlice[i] = nil + } else { + dst.PtrSlice[i] = ptr.To(*src.PtrSlice[i]) + } + } + } + dst.PtrKeyMap = maps.Clone(src.PtrKeyMap) + if dst.PtrValueMap != nil { + dst.PtrValueMap = map[string]*T{} + for k, v := range src.PtrValueMap { + if v == nil { + dst.PtrValueMap[k] = nil + } else { + dst.PtrValueMap[k] = ptr.To(*v) + } + } + } + if dst.SliceMap != nil { + dst.SliceMap = map[string][]T{} + for k := range src.SliceMap { + dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...) + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericIntStructCloneNeedsRegeneration[T constraints.Integer](GenericIntStruct[T]) { + _GenericIntStructCloneNeedsRegeneration(struct { + Value T + Pointer *T + Slice []T + Map map[string]T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} + +// Clone makes a deep copy of GenericNoPtrsStruct. +// The result aliases no memory with the original. +func (src *GenericNoPtrsStruct[T]) Clone() *GenericNoPtrsStruct[T] { + if src == nil { + return nil + } + dst := new(GenericNoPtrsStruct[T]) + *dst = *src + if dst.Pointer != nil { + dst.Pointer = ptr.To(*src.Pointer) + } + dst.Slice = append(src.Slice[:0:0], src.Slice...) + dst.Map = maps.Clone(src.Map) + if src.PtrSlice != nil { + dst.PtrSlice = make([]*T, len(src.PtrSlice)) + for i := range dst.PtrSlice { + if src.PtrSlice[i] == nil { + dst.PtrSlice[i] = nil + } else { + dst.PtrSlice[i] = ptr.To(*src.PtrSlice[i]) + } + } + } + dst.PtrKeyMap = maps.Clone(src.PtrKeyMap) + if dst.PtrValueMap != nil { + dst.PtrValueMap = map[string]*T{} + for k, v := range src.PtrValueMap { + if v == nil { + dst.PtrValueMap[k] = nil + } else { + dst.PtrValueMap[k] = ptr.To(*v) + } + } + } + if dst.SliceMap != nil { + dst.SliceMap = map[string][]T{} + for k := range src.SliceMap { + dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...) + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericNoPtrsStructCloneNeedsRegeneration[T StructWithoutPtrs | netip.Prefix | BasicType](GenericNoPtrsStruct[T]) { + _GenericNoPtrsStructCloneNeedsRegeneration(struct { + Value T + Pointer *T + Slice []T + Map map[string]T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} + +// Clone makes a deep copy of GenericCloneableStruct. +// The result aliases no memory with the original. +func (src *GenericCloneableStruct[T, V]) Clone() *GenericCloneableStruct[T, V] { + if src == nil { + return nil + } + dst := new(GenericCloneableStruct[T, V]) + *dst = *src + dst.Value = src.Value.Clone() + if src.Slice != nil { + dst.Slice = make([]T, len(src.Slice)) + for i := range dst.Slice { + dst.Slice[i] = src.Slice[i].Clone() + } + } + if dst.Map != nil { + dst.Map = map[string]T{} + for k, v := range src.Map { + dst.Map[k] = v.Clone() + } + } + if dst.Pointer != nil { + dst.Pointer = ptr.To((*src.Pointer).Clone()) + } + if src.PtrSlice != nil { + dst.PtrSlice = make([]*T, len(src.PtrSlice)) + for i := range dst.PtrSlice { + if src.PtrSlice[i] == nil { + dst.PtrSlice[i] = nil + } else { + dst.PtrSlice[i] = ptr.To((*src.PtrSlice[i]).Clone()) + } + } + } + dst.PtrKeyMap = maps.Clone(src.PtrKeyMap) + if dst.PtrValueMap != nil { + dst.PtrValueMap = map[string]*T{} + for k, v := range src.PtrValueMap { + if v == nil { + dst.PtrValueMap[k] = nil + } else { + dst.PtrValueMap[k] = ptr.To((*v).Clone()) + } + } + } + if dst.SliceMap != nil { + dst.SliceMap = map[string][]T{} + for k := range src.SliceMap { + dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...) + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericCloneableStructCloneNeedsRegeneration[T views.ViewCloner[T, V], V views.StructView[T]](GenericCloneableStruct[T, V]) { + _GenericCloneableStructCloneNeedsRegeneration(struct { + Value T + Slice []T + Map map[string]T + Pointer *T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index f8966d20e..9a337f5aa 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -10,10 +10,11 @@ "errors" "net/netip" + "golang.org/x/exp/constraints" "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct // View returns a readonly view of StructWithPtrs. func (p *StructWithPtrs) View() StructWithPtrsView { @@ -221,15 +222,15 @@ func (v MapView) SlicesWithoutPtrs() views.MapFn[string, []*StructWithoutPtrs, v func (v MapView) StructWithoutPtrKey() views.Map[StructWithoutPtrs, int] { return views.MapOf(v.ж.StructWithoutPtrKey) } -func (v MapView) SliceIntPtr() map[string][]*int { panic("unsupported") } -func (v MapView) PointerKey() map[*string]int { panic("unsupported") } -func (v MapView) StructWithPtrKey() map[StructWithPtrs]int { panic("unsupported") } func (v MapView) StructWithPtr() views.MapFn[string, StructWithPtrs, StructWithPtrsView] { return views.MapFnOf(v.ж.StructWithPtr, func(t StructWithPtrs) StructWithPtrsView { return t.View() }) } +func (v MapView) SliceIntPtr() map[string][]*int { panic("unsupported") } +func (v MapView) PointerKey() map[*string]int { panic("unsupported") } +func (v MapView) StructWithPtrKey() map[StructWithPtrs]int { panic("unsupported") } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _MapViewNeedsRegeneration = Map(struct { @@ -241,10 +242,10 @@ func (v MapView) StructWithPtr() views.MapFn[string, StructWithPtrs, StructWithP SlicesWithPtrs map[string][]*StructWithPtrs SlicesWithoutPtrs map[string][]*StructWithoutPtrs StructWithoutPtrKey map[StructWithoutPtrs]int + StructWithPtr map[string]StructWithPtrs SliceIntPtr map[string][]*int PointerKey map[*string]int StructWithPtrKey map[StructWithPtrs]int - StructWithPtr map[string]StructWithPtrs }{}) // View returns a readonly view of StructWithSlices. @@ -301,24 +302,24 @@ func (v StructWithSlicesView) ValuePointers() views.SliceView[*StructWithoutPtrs func (v StructWithSlicesView) StructPointers() views.SliceView[*StructWithPtrs, StructWithPtrsView] { return views.SliceOfViews[*StructWithPtrs, StructWithPtrsView](v.ж.StructPointers) } -func (v StructWithSlicesView) Structs() StructWithPtrs { panic("unsupported") } -func (v StructWithSlicesView) Ints() *int { panic("unsupported") } func (v StructWithSlicesView) Slice() views.Slice[string] { return views.SliceOf(v.ж.Slice) } func (v StructWithSlicesView) Prefixes() views.Slice[netip.Prefix] { return views.SliceOf(v.ж.Prefixes) } func (v StructWithSlicesView) Data() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Data) } +func (v StructWithSlicesView) Structs() StructWithPtrs { panic("unsupported") } +func (v StructWithSlicesView) Ints() *int { panic("unsupported") } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _StructWithSlicesViewNeedsRegeneration = StructWithSlices(struct { Values []StructWithoutPtrs ValuePointers []*StructWithoutPtrs StructPointers []*StructWithPtrs - Structs []StructWithPtrs - Ints []*int Slice []string Prefixes []netip.Prefix Data []byte + Structs []StructWithPtrs + Ints []*int }{}) // View returns a readonly view of StructWithEmbedded. @@ -376,3 +377,230 @@ func (v StructWithEmbeddedView) StructWithSlices() StructWithSlicesView { A *StructWithPtrs StructWithSlices }{}) + +// View returns a readonly view of GenericIntStruct. +func (p *GenericIntStruct[T]) View() GenericIntStructView[T] { + return GenericIntStructView[T]{ж: p} +} + +// GenericIntStructView[T] provides a read-only view over GenericIntStruct[T]. +// +// Its methods should only be called if `Valid()` returns true. +type GenericIntStructView[T constraints.Integer] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *GenericIntStruct[T] +} + +// Valid reports whether underlying value is non-nil. +func (v GenericIntStructView[T]) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v GenericIntStructView[T]) AsStruct() *GenericIntStruct[T] { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v GenericIntStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *GenericIntStructView[T]) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x GenericIntStruct[T] + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v GenericIntStructView[T]) Value() T { return v.ж.Value } +func (v GenericIntStructView[T]) Pointer() *T { + if v.ж.Pointer == nil { + return nil + } + x := *v.ж.Pointer + return &x +} + +func (v GenericIntStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) } + +func (v GenericIntStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.ж.Map) } +func (v GenericIntStructView[T]) PtrSlice() *T { panic("unsupported") } +func (v GenericIntStructView[T]) PtrKeyMap() map[*T]string { panic("unsupported") } +func (v GenericIntStructView[T]) PtrValueMap() map[string]*T { panic("unsupported") } +func (v GenericIntStructView[T]) SliceMap() map[string][]T { panic("unsupported") } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericIntStructViewNeedsRegeneration[T constraints.Integer](GenericIntStruct[T]) { + _GenericIntStructViewNeedsRegeneration(struct { + Value T + Pointer *T + Slice []T + Map map[string]T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} + +// View returns a readonly view of GenericNoPtrsStruct. +func (p *GenericNoPtrsStruct[T]) View() GenericNoPtrsStructView[T] { + return GenericNoPtrsStructView[T]{ж: p} +} + +// GenericNoPtrsStructView[T] provides a read-only view over GenericNoPtrsStruct[T]. +// +// Its methods should only be called if `Valid()` returns true. +type GenericNoPtrsStructView[T StructWithoutPtrs | netip.Prefix | BasicType] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *GenericNoPtrsStruct[T] +} + +// Valid reports whether underlying value is non-nil. +func (v GenericNoPtrsStructView[T]) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v GenericNoPtrsStructView[T]) AsStruct() *GenericNoPtrsStruct[T] { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v GenericNoPtrsStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *GenericNoPtrsStructView[T]) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x GenericNoPtrsStruct[T] + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v GenericNoPtrsStructView[T]) Value() T { return v.ж.Value } +func (v GenericNoPtrsStructView[T]) Pointer() *T { + if v.ж.Pointer == nil { + return nil + } + x := *v.ж.Pointer + return &x +} + +func (v GenericNoPtrsStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.ж.Slice) } + +func (v GenericNoPtrsStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.ж.Map) } +func (v GenericNoPtrsStructView[T]) PtrSlice() *T { panic("unsupported") } +func (v GenericNoPtrsStructView[T]) PtrKeyMap() map[*T]string { panic("unsupported") } +func (v GenericNoPtrsStructView[T]) PtrValueMap() map[string]*T { panic("unsupported") } +func (v GenericNoPtrsStructView[T]) SliceMap() map[string][]T { panic("unsupported") } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericNoPtrsStructViewNeedsRegeneration[T StructWithoutPtrs | netip.Prefix | BasicType](GenericNoPtrsStruct[T]) { + _GenericNoPtrsStructViewNeedsRegeneration(struct { + Value T + Pointer *T + Slice []T + Map map[string]T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} + +// View returns a readonly view of GenericCloneableStruct. +func (p *GenericCloneableStruct[T, V]) View() GenericCloneableStructView[T, V] { + return GenericCloneableStructView[T, V]{ж: p} +} + +// GenericCloneableStructView[T, V] provides a read-only view over GenericCloneableStruct[T, V]. +// +// Its methods should only be called if `Valid()` returns true. +type GenericCloneableStructView[T views.ViewCloner[T, V], V views.StructView[T]] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *GenericCloneableStruct[T, V] +} + +// Valid reports whether underlying value is non-nil. +func (v GenericCloneableStructView[T, V]) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v GenericCloneableStructView[T, V]) AsStruct() *GenericCloneableStruct[T, V] { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v GenericCloneableStructView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *GenericCloneableStructView[T, V]) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x GenericCloneableStruct[T, V] + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v GenericCloneableStructView[T, V]) Value() V { return v.ж.Value.View() } +func (v GenericCloneableStructView[T, V]) Slice() views.SliceView[T, V] { + return views.SliceOfViews[T, V](v.ж.Slice) +} + +func (v GenericCloneableStructView[T, V]) Map() views.MapFn[string, T, V] { + return views.MapFnOf(v.ж.Map, func(t T) V { + return t.View() + }) +} +func (v GenericCloneableStructView[T, V]) Pointer() map[string]T { panic("unsupported") } +func (v GenericCloneableStructView[T, V]) PtrSlice() *T { panic("unsupported") } +func (v GenericCloneableStructView[T, V]) PtrKeyMap() map[*T]string { panic("unsupported") } +func (v GenericCloneableStructView[T, V]) PtrValueMap() map[string]*T { panic("unsupported") } +func (v GenericCloneableStructView[T, V]) SliceMap() map[string][]T { panic("unsupported") } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericCloneableStructViewNeedsRegeneration[T views.ViewCloner[T, V], V views.StructView[T]](GenericCloneableStruct[T, V]) { + _GenericCloneableStructViewNeedsRegeneration(struct { + Value T + Slice []T + Map map[string]T + Pointer *T + PtrSlice []*T + PtrKeyMap map[*T]string `json:"-"` + PtrValueMap map[string]*T + SliceMap map[string][]T + }{}) +} diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index f89210268..557b6e459 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -20,43 +20,43 @@ const viewTemplateStr = `{{define "common"}} // View returns a readonly view of {{.StructName}}. -func (p *{{.StructName}}) View() {{.ViewName}} { - return {{.ViewName}}{ж: p} +func (p *{{.StructName}}{{.TypeParamNames}}) View() {{.ViewName}}{{.TypeParamNames}} { + return {{.ViewName}}{{.TypeParamNames}}{ж: p} } -// {{.ViewName}} provides a read-only view over {{.StructName}}. +// {{.ViewName}}{{.TypeParamNames}} provides a read-only view over {{.StructName}}{{.TypeParamNames}}. // // Its methods should only be called if ` + "`Valid()`" + ` returns true. -type {{.ViewName}} struct { +type {{.ViewName}}{{.TypeParams}} struct { // ж is the underlying mutable value, named with a hard-to-type // character that looks pointy like a pointer. // It is named distinctively to make you think of how dangerous it is to escape // to callers. You must not let callers be able to mutate it. - ж *{{.StructName}} + ж *{{.StructName}}{{.TypeParamNames}} } // Valid reports whether underlying value is non-nil. -func (v {{.ViewName}}) Valid() bool { return v.ж != nil } +func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.ж != nil } // AsStruct returns a clone of the underlying value which aliases no memory with // the original. -func (v {{.ViewName}}) AsStruct() *{{.StructName}}{ +func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{ if v.ж == nil { return nil } return v.ж.Clone() } -func (v {{.ViewName}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } +func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } -func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error { +func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { if v.ж != nil { return errors.New("already initialized") } if len(b) == 0 { return nil } - var x {{.StructName}} + var x {{.StructName}}{{.TypeParamNames}} if err := json.Unmarshal(b, &x); err != nil { return err } @@ -65,17 +65,17 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error { } {{end}} -{{define "valueField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} } +{{define "valueField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} } {{end}} -{{define "byteSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) } +{{define "byteSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) } {{end}} -{{define "sliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) } +{{define "sliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) } {{end}} -{{define "viewSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) } +{{define "viewSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) } {{end}} -{{define "viewField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}}View { return v.ж.{{.FieldName}}.View() } +{{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() } {{end}} -{{define "valuePointerField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} { +{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { if v.ж.{{.FieldName}} == nil { return nil } @@ -85,21 +85,21 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error { {{end}} {{define "mapField"}} -func(v {{.ViewName}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})} +func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})} {{end}} {{define "mapFnField"}} -func(v {{.ViewName}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} { +func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} { return {{.MapFn}} })} {{end}} {{define "mapSliceField"}} -func(v {{.ViewName}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) } +func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) } {{end}} -{{define "unsupportedField"}}func(v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")} +{{define "unsupportedField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")} {{end}} -{{define "stringFunc"}}func(v {{.ViewName}}) String() string { return v.ж.String() } +{{define "stringFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) String() string { return v.ж.String() } {{end}} -{{define "equalFunc"}}func(v {{.ViewName}}) Equal(v2 {{.ViewName}}) bool { return v.ж.Equal(v2.ж) } +{{define "equalFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) Equal(v2 {{.ViewName}}{{.TypeParamNames}}) bool { return v.ж.Equal(v2.ж) } {{end}} ` @@ -131,8 +131,11 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi it.Import("errors") args := struct { - StructName string - ViewName string + StructName string + ViewName string + TypeParams string // e.g. [T constraints.Integer] + TypeParamNames string // e.g. [T] + FieldName string FieldType string FieldViewName string @@ -143,9 +146,12 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi MapFn string }{ StructName: typ.Obj().Name(), - ViewName: typ.Obj().Name() + "View", + ViewName: typ.Origin().Obj().Name() + "View", } + typeParams := typ.Origin().TypeParams() + args.TypeParams, args.TypeParamNames = codegen.FormatTypeParams(typeParams, it) + writeTemplate := func(name string) { if err := viewTemplate.ExecuteTemplate(buf, name, args); err != nil { log.Fatal(err) @@ -182,19 +188,35 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi it.Import("tailscale.com/types/views") shallow, deep, base := requiresCloning(elem) if deep { - if _, isPtr := elem.(*types.Pointer); isPtr { - args.FieldViewName = it.QualifiedName(base) + "View" - writeTemplate("viewSliceField") - } else { - writeTemplate("unsupportedField") + switch elem.Underlying().(type) { + case *types.Pointer: + if _, isIface := base.Underlying().(*types.Interface); !isIface { + args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View") + writeTemplate("viewSliceField") + } else { + writeTemplate("unsupportedField") + } + continue + case *types.Interface: + if viewType := viewTypeForValueType(elem); viewType != nil { + args.FieldViewName = it.QualifiedName(viewType) + writeTemplate("viewSliceField") + continue + } } + writeTemplate("unsupportedField") continue } else if shallow { - if _, isBasic := base.(*types.Basic); isBasic { + switch base.Underlying().(type) { + case *types.Basic, *types.Interface: writeTemplate("unsupportedField") - } else { - args.FieldViewName = it.QualifiedName(base) + "View" - writeTemplate("viewSliceField") + default: + if _, isIface := base.Underlying().(*types.Interface); !isIface { + args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View") + writeTemplate("viewSliceField") + } else { + writeTemplate("unsupportedField") + } } continue } @@ -205,6 +227,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi strucT := underlying args.FieldType = it.QualifiedName(fieldType) if codegen.ContainsPointers(strucT) { + args.FieldViewName = appendNameSuffix(args.FieldType, "View") writeTemplate("viewField") continue } @@ -229,7 +252,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi args.MapFn = "t.View()" template = "mapFnField" args.MapValueType = it.QualifiedName(mElem) - args.MapValueView = args.MapValueType + "View" + args.MapValueView = appendNameSuffix(args.MapValueType, "View") } else { template = "mapField" args.MapValueType = it.QualifiedName(mElem) @@ -249,15 +272,20 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi case *types.Pointer: ptr := x pElem := ptr.Elem() - switch pElem.(type) { - case *types.Struct, *types.Named: - ptrType := it.QualifiedName(ptr) - viewType := it.QualifiedName(pElem) + "View" - args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType) - args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType) - args.MapValueType = "[]" + ptrType - template = "mapFnField" - default: + template = "unsupportedField" + if _, isIface := pElem.Underlying().(*types.Interface); !isIface { + switch pElem.(type) { + case *types.Struct, *types.Named: + ptrType := it.QualifiedName(ptr) + viewType := appendNameSuffix(it.QualifiedName(pElem), "View") + args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType) + args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType) + args.MapValueType = "[]" + ptrType + template = "mapFnField" + default: + template = "unsupportedField" + } + } else { template = "unsupportedField" } default: @@ -266,13 +294,29 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi case *types.Pointer: ptr := u pElem := ptr.Elem() - switch pElem.(type) { - case *types.Struct, *types.Named: - args.MapValueType = it.QualifiedName(ptr) - args.MapValueView = it.QualifiedName(pElem) + "View" + if _, isIface := pElem.Underlying().(*types.Interface); !isIface { + switch pElem.(type) { + case *types.Struct, *types.Named: + args.MapValueType = it.QualifiedName(ptr) + args.MapValueView = appendNameSuffix(it.QualifiedName(pElem), "View") + args.MapFn = "t.View()" + template = "mapFnField" + default: + template = "unsupportedField" + } + } else { + template = "unsupportedField" + } + case *types.Interface, *types.TypeParam: + if viewType := viewTypeForValueType(u); viewType != nil { + args.MapValueType = it.QualifiedName(u) + args.MapValueView = it.QualifiedName(viewType) args.MapFn = "t.View()" template = "mapFnField" - default: + } else if !codegen.ContainsPointers(u) { + args.MapValueType = it.QualifiedName(mElem) + template = "mapField" + } else { template = "unsupportedField" } default: @@ -283,14 +327,28 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi case *types.Pointer: ptr := underlying _, deep, base := requiresCloning(ptr) + if deep { - args.FieldType = it.QualifiedName(base) - writeTemplate("viewField") + if _, isIface := base.Underlying().(*types.Interface); !isIface { + args.FieldType = it.QualifiedName(base) + args.FieldViewName = appendNameSuffix(args.FieldType, "View") + writeTemplate("viewField") + } else { + writeTemplate("unsupportedField") + } } else { args.FieldType = it.QualifiedName(ptr) writeTemplate("valuePointerField") } continue + case *types.Interface: + // If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field. + // This includes scenarios where fieldType is a constrained type parameter. + if viewType := viewTypeForValueType(underlying); viewType != nil { + args.FieldViewName = it.QualifiedName(viewType) + writeTemplate("viewField") + continue + } } writeTemplate("unsupportedField") } @@ -318,7 +376,27 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi } } fmt.Fprintf(buf, "\n") - buf.Write(codegen.AssertStructUnchanged(t, args.StructName, "View", it)) + buf.Write(codegen.AssertStructUnchanged(t, args.StructName, typeParams, "View", it)) +} + +func appendNameSuffix(name, suffix string) string { + if idx := strings.IndexRune(name, '['); idx != -1 { + // Insert suffix after the type name, but before type parameters. + return name[:idx] + suffix + name[idx:] + } + return name + suffix +} + +func viewTypeForValueType(typ types.Type) types.Type { + viewMethod := codegen.LookupMethod(typ, "View") + if viewMethod == nil { + return nil + } + sig, ok := viewMethod.Type().(*types.Signature) + if !ok || sig.Results().Len() != 1 { + return nil + } + return sig.Results().At(0).Type() } var ( diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index 9457c50f0..de35b60a7 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -14,6 +14,7 @@ "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/types/preftype" + "tailscale.com/types/ptr" ) // Clone makes a deep copy of Prefs. @@ -29,7 +30,11 @@ func (src *Prefs) Clone() *Prefs { if src.DriveShares != nil { dst.DriveShares = make([]*drive.Share, len(src.DriveShares)) for i := range dst.DriveShares { - dst.DriveShares[i] = src.DriveShares[i].Clone() + if src.DriveShares[i] == nil { + dst.DriveShares[i] = nil + } else { + dst.DriveShares[i] = src.DriveShares[i].Clone() + } } } dst.Persist = src.Persist.Clone() @@ -81,20 +86,32 @@ func (src *ServeConfig) Clone() *ServeConfig { if dst.TCP != nil { dst.TCP = map[uint16]*TCPPortHandler{} for k, v := range src.TCP { - dst.TCP[k] = v.Clone() + if v == nil { + dst.TCP[k] = nil + } else { + dst.TCP[k] = ptr.To(*v) + } } } if dst.Web != nil { dst.Web = map[HostPort]*WebServerConfig{} for k, v := range src.Web { - dst.Web[k] = v.Clone() + if v == nil { + dst.Web[k] = nil + } else { + dst.Web[k] = v.Clone() + } } } dst.AllowFunnel = maps.Clone(src.AllowFunnel) if dst.Foreground != nil { dst.Foreground = map[string]*ServeConfig{} for k, v := range src.Foreground { - dst.Foreground[k] = v.Clone() + if v == nil { + dst.Foreground[k] = nil + } else { + dst.Foreground[k] = v.Clone() + } } } return dst @@ -157,7 +174,11 @@ func (src *WebServerConfig) Clone() *WebServerConfig { if dst.Handlers != nil { dst.Handlers = map[string]*HTTPHandler{} for k, v := range src.Handlers { - dst.Handlers[k] = v.Clone() + if v == nil { + dst.Handlers[k] = nil + } else { + dst.Handlers[k] = ptr.To(*v) + } } } return dst diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 823fe6810..a98efe4d1 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -77,7 +77,11 @@ func (src *Node) Clone() *Node { if src.ExitNodeDNSResolvers != nil { dst.ExitNodeDNSResolvers = make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers)) for i := range dst.ExitNodeDNSResolvers { - dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() + if src.ExitNodeDNSResolvers[i] == nil { + dst.ExitNodeDNSResolvers[i] = nil + } else { + dst.ExitNodeDNSResolvers[i] = src.ExitNodeDNSResolvers[i].Clone() + } } } return dst @@ -244,7 +248,11 @@ func (src *DNSConfig) Clone() *DNSConfig { if src.Resolvers != nil { dst.Resolvers = make([]*dnstype.Resolver, len(src.Resolvers)) for i := range dst.Resolvers { - dst.Resolvers[i] = src.Resolvers[i].Clone() + if src.Resolvers[i] == nil { + dst.Resolvers[i] = nil + } else { + dst.Resolvers[i] = src.Resolvers[i].Clone() + } } } if dst.Routes != nil { @@ -256,7 +264,11 @@ func (src *DNSConfig) Clone() *DNSConfig { if src.FallbackResolvers != nil { dst.FallbackResolvers = make([]*dnstype.Resolver, len(src.FallbackResolvers)) for i := range dst.FallbackResolvers { - dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() + if src.FallbackResolvers[i] == nil { + dst.FallbackResolvers[i] = nil + } else { + dst.FallbackResolvers[i] = src.FallbackResolvers[i].Clone() + } } } dst.Domains = append(src.Domains[:0:0], src.Domains...) @@ -393,7 +405,11 @@ func (src *DERPRegion) Clone() *DERPRegion { if src.Nodes != nil { dst.Nodes = make([]*DERPNode, len(src.Nodes)) for i := range dst.Nodes { - dst.Nodes[i] = src.Nodes[i].Clone() + if src.Nodes[i] == nil { + dst.Nodes[i] = nil + } else { + dst.Nodes[i] = ptr.To(*src.Nodes[i]) + } } } return dst @@ -422,7 +438,11 @@ func (src *DERPMap) Clone() *DERPMap { if dst.Regions != nil { dst.Regions = map[int]*DERPRegion{} for k, v := range src.Regions { - dst.Regions[k] = v.Clone() + if v == nil { + dst.Regions[k] = nil + } else { + dst.Regions[k] = v.Clone() + } } } return dst @@ -476,7 +496,11 @@ func (src *SSHRule) Clone() *SSHRule { if src.Principals != nil { dst.Principals = make([]*SSHPrincipal, len(src.Principals)) for i := range dst.Principals { - dst.Principals[i] = src.Principals[i].Clone() + if src.Principals[i] == nil { + dst.Principals[i] = nil + } else { + dst.Principals[i] = src.Principals[i].Clone() + } } } dst.SSHUsers = maps.Clone(src.SSHUsers) diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 6c2b3c71e..13dbc94a4 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -27,9 +27,9 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) { cfg := &packages.Config{ Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName, - Tests: false, + Tests: buildTags == "test", } - if buildTags != "" { + if buildTags != "" && !cfg.Tests { cfg.BuildFlags = []string{"-tags=" + buildTags} } @@ -37,6 +37,9 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string] if err != nil { return nil, nil, err } + if cfg.Tests { + pkgs = testPackages(pkgs) + } if len(pkgs) != 1 { return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs)) } @@ -44,6 +47,17 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string] return pkg, namedTypes(pkg), nil } +func testPackages(pkgs []*packages.Package) []*packages.Package { + var testPackages []*packages.Package + for _, pkg := range pkgs { + testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath) + if pkg.ID == testPackageID { + testPackages = append(testPackages, pkg) + } + } + return testPackages +} + // HasNoClone reports whether the provided tag has `codegen:noclone`. func HasNoClone(structTag string) bool { val := reflect.StructTag(structTag).Get("codegen") @@ -193,13 +207,21 @@ func namedTypes(pkg *packages.Package) map[string]*types.Named { // ctx is a single-word context for this assertion, such as "Clone". // If non-nil, AssertStructUnchanged will add elements to imports // for each package path that the caller must import for the returned code to compile. -func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker) []byte { +func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte { buf := new(bytes.Buffer) w := func(format string, args ...any) { fmt.Fprintf(buf, format+"\n", args...) } w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.") - w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname) + + hasTypeParams := params != nil && params.Len() > 0 + if hasTypeParams { + constraints, identifiers := FormatTypeParams(params, it) + w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers) + w("_%s%sNeedsRegeneration(struct {", tname, ctx) + } else { + w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname) + } for i := range t.NumFields() { st := t.Field(i) @@ -209,14 +231,25 @@ func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker continue } qname := it.QualifiedName(ft) + var tag string + if hasTypeParams { + tag = t.Tag(i) + if tag != "" { + tag = "`" + tag + "`" + } + } if st.Anonymous() { - w("\t%s ", fname) + w("\t%s %s", fname, tag) } else { - w("\t%s %s", fname, qname) + w("\t%s %s %s", fname, qname, tag) } } - w("}{})\n") + if hasTypeParams { + w("}{})\n}") + } else { + w("}{})") + } return buf.Bytes() } @@ -242,10 +275,21 @@ func ContainsPointers(typ types.Type) bool { switch ft := typ.Underlying().(type) { case *types.Array: return ContainsPointers(ft.Elem()) + case *types.Basic: + if ft.Kind() == types.UnsafePointer { + return true + } case *types.Chan: return true case *types.Interface: - return true // a little too broad + if ft.Empty() || ft.IsMethodSet() { + return true + } + for i := 0; i < ft.NumEmbeddeds(); i++ { + if ContainsPointers(ft.EmbeddedType(i)) { + return true + } + } case *types.Map: return true case *types.Pointer: @@ -258,6 +302,12 @@ func ContainsPointers(typ types.Type) bool { return true } } + case *types.Union: + for i := range ft.Len() { + if ContainsPointers(ft.Term(i).Type()) { + return true + } + } } return false } @@ -273,3 +323,44 @@ func IsViewType(typ types.Type) bool { } return t.Field(0).Name() == "ж" } + +// FormatTypeParams formats the specified params and returns two strings: +// - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer]) +// - names are comma-separated type parameter names in square brackets (e.g. [T, V]) +// +// If params is nil or empty, both return values are empty strings. +func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) { + if params == nil || params.Len() == 0 { + return "", "" + } + var constraintList, nameList []string + for i := range params.Len() { + param := params.At(i) + name := param.Obj().Name() + constraint := it.QualifiedName(param.Constraint()) + nameList = append(nameList, name) + constraintList = append(constraintList, name+" "+constraint) + } + constraints = "[" + strings.Join(constraintList, ", ") + "]" + names = "[" + strings.Join(nameList, ", ") + "]" + return constraints, names +} + +// LookupMethod returns the method with the specified name in t, or nil if the method does not exist. +func LookupMethod(t types.Type, name string) *types.Func { + if t, ok := t.(*types.Named); ok { + for i := 0; i < t.NumMethods(); i++ { + if method := t.Method(i); method.Name() == name { + return method + } + } + } + if t, ok := t.Underlying().(*types.Interface); ok { + for i := 0; i < t.NumMethods(); i++ { + if method := t.Method(i); method.Name() == name { + return method + } + } + } + return nil +} diff --git a/util/codegen/codegen_test.go b/util/codegen/codegen_test.go new file mode 100644 index 000000000..5f4a13979 --- /dev/null +++ b/util/codegen/codegen_test.go @@ -0,0 +1,176 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package codegen + +import ( + "log" + "net/netip" + "testing" + "unsafe" + + "golang.org/x/exp/constraints" +) + +type AnyParam[T any] struct { + V T +} + +type AnyParamPhantom[T any] struct { +} + +type IntegerParam[T constraints.Integer] struct { + V T +} + +type FloatParam[T constraints.Float] struct { + V T +} + +type StringLikeParam[T ~string] struct { + V T +} + +type BasicType interface { + ~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string +} + +type BasicTypeParam[T BasicType] struct { + V T +} + +type IntPtr *int + +type IntPtrParam[T IntPtr] struct { + V T +} + +type IntegerPtr interface { + *int | *int32 | *int64 +} + +type IntegerPtrParam[T IntegerPtr] struct { + V T +} + +type IntegerParamPtr[T constraints.Integer] struct { + V *T +} + +type IntegerSliceParam[T constraints.Integer] struct { + V []T +} + +type IntegerMapParam[T constraints.Integer] struct { + V []T +} + +type UnsafePointerParam[T unsafe.Pointer] struct { + V T +} + +type ValueUnionParam[T netip.Prefix | BasicType] struct { + V T +} + +type ValueUnionParamPtr[T netip.Prefix | BasicType] struct { + V *T +} + +type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct { + V T +} + +type Interface interface { + Method() +} + +type InterfaceParam[T Interface] struct { + V T +} + +func TestGenericContainsPointers(t *testing.T) { + tests := []struct { + typ string + wantPointer bool + }{ + { + typ: "AnyParam", + wantPointer: true, + }, + { + typ: "AnyParamPhantom", + wantPointer: false, // has a pointer type parameter, but no pointer fields + }, + { + typ: "IntegerParam", + wantPointer: false, + }, + { + typ: "FloatParam", + wantPointer: false, + }, + { + typ: "StringLikeParam", + wantPointer: false, + }, + { + typ: "BasicTypeParam", + wantPointer: false, + }, + { + typ: "IntPtrParam", + wantPointer: true, + }, + { + typ: "IntegerPtrParam", + wantPointer: true, + }, + { + typ: "IntegerParamPtr", + wantPointer: true, + }, + { + typ: "IntegerSliceParam", + wantPointer: true, + }, + { + typ: "IntegerMapParam", + wantPointer: true, + }, + { + typ: "UnsafePointerParam", + wantPointer: true, + }, + { + typ: "InterfaceParam", + wantPointer: true, + }, + { + typ: "ValueUnionParam", + wantPointer: false, + }, + { + typ: "ValueUnionParamPtr", + wantPointer: true, + }, + { + typ: "PointerUnionParam", + wantPointer: true, + }, + } + + _, namedTypes, err := LoadTypes("test", ".") + if err != nil { + log.Fatal(err) + } + + for _, tt := range tests { + t.Run(tt.typ, func(t *testing.T) { + typ := namedTypes[tt.typ] + if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer { + t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer) + } + }) + } +}