cmd/cloner, cmd/viewer, util/codegen: add support for generic types and interfaces

This adds support for generic types and interfaces to our cloner and viewer codegens.
It updates these packages to determine whether to make shallow or deep copies based
on the type parameter constraints. Additionally, if a template parameter or an interface
type has View() and Clone() methods, we'll use them for getters and the cloner of the
owning structure.

Updates #12736

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl
2024-07-08 10:11:00 -05:00
committed by Nick Khyl
parent b7c3cfe049
commit fc28c8e7f3
9 changed files with 1039 additions and 114 deletions

View File

@@ -27,9 +27,9 @@ var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to gen
func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) {
cfg := &packages.Config{
Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
Tests: false,
Tests: buildTags == "test",
}
if buildTags != "" {
if buildTags != "" && !cfg.Tests {
cfg.BuildFlags = []string{"-tags=" + buildTags}
}
@@ -37,6 +37,9 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]
if err != nil {
return nil, nil, err
}
if cfg.Tests {
pkgs = testPackages(pkgs)
}
if len(pkgs) != 1 {
return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
}
@@ -44,6 +47,17 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]
return pkg, namedTypes(pkg), nil
}
func testPackages(pkgs []*packages.Package) []*packages.Package {
var testPackages []*packages.Package
for _, pkg := range pkgs {
testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath)
if pkg.ID == testPackageID {
testPackages = append(testPackages, pkg)
}
}
return testPackages
}
// HasNoClone reports whether the provided tag has `codegen:noclone`.
func HasNoClone(structTag string) bool {
val := reflect.StructTag(structTag).Get("codegen")
@@ -193,13 +207,21 @@ func namedTypes(pkg *packages.Package) map[string]*types.Named {
// ctx is a single-word context for this assertion, such as "Clone".
// If non-nil, AssertStructUnchanged will add elements to imports
// for each package path that the caller must import for the returned code to compile.
func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker) []byte {
func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte {
buf := new(bytes.Buffer)
w := func(format string, args ...any) {
fmt.Fprintf(buf, format+"\n", args...)
}
w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
hasTypeParams := params != nil && params.Len() > 0
if hasTypeParams {
constraints, identifiers := FormatTypeParams(params, it)
w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers)
w("_%s%sNeedsRegeneration(struct {", tname, ctx)
} else {
w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
}
for i := range t.NumFields() {
st := t.Field(i)
@@ -209,14 +231,25 @@ func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker
continue
}
qname := it.QualifiedName(ft)
var tag string
if hasTypeParams {
tag = t.Tag(i)
if tag != "" {
tag = "`" + tag + "`"
}
}
if st.Anonymous() {
w("\t%s ", fname)
w("\t%s %s", fname, tag)
} else {
w("\t%s %s", fname, qname)
w("\t%s %s %s", fname, qname, tag)
}
}
w("}{})\n")
if hasTypeParams {
w("}{})\n}")
} else {
w("}{})")
}
return buf.Bytes()
}
@@ -242,10 +275,21 @@ func ContainsPointers(typ types.Type) bool {
switch ft := typ.Underlying().(type) {
case *types.Array:
return ContainsPointers(ft.Elem())
case *types.Basic:
if ft.Kind() == types.UnsafePointer {
return true
}
case *types.Chan:
return true
case *types.Interface:
return true // a little too broad
if ft.Empty() || ft.IsMethodSet() {
return true
}
for i := 0; i < ft.NumEmbeddeds(); i++ {
if ContainsPointers(ft.EmbeddedType(i)) {
return true
}
}
case *types.Map:
return true
case *types.Pointer:
@@ -258,6 +302,12 @@ func ContainsPointers(typ types.Type) bool {
return true
}
}
case *types.Union:
for i := range ft.Len() {
if ContainsPointers(ft.Term(i).Type()) {
return true
}
}
}
return false
}
@@ -273,3 +323,44 @@ func IsViewType(typ types.Type) bool {
}
return t.Field(0).Name() == "ж"
}
// FormatTypeParams formats the specified params and returns two strings:
// - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer])
// - names are comma-separated type parameter names in square brackets (e.g. [T, V])
//
// If params is nil or empty, both return values are empty strings.
func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) {
if params == nil || params.Len() == 0 {
return "", ""
}
var constraintList, nameList []string
for i := range params.Len() {
param := params.At(i)
name := param.Obj().Name()
constraint := it.QualifiedName(param.Constraint())
nameList = append(nameList, name)
constraintList = append(constraintList, name+" "+constraint)
}
constraints = "[" + strings.Join(constraintList, ", ") + "]"
names = "[" + strings.Join(nameList, ", ") + "]"
return constraints, names
}
// LookupMethod returns the method with the specified name in t, or nil if the method does not exist.
func LookupMethod(t types.Type, name string) *types.Func {
if t, ok := t.(*types.Named); ok {
for i := 0; i < t.NumMethods(); i++ {
if method := t.Method(i); method.Name() == name {
return method
}
}
}
if t, ok := t.Underlying().(*types.Interface); ok {
for i := 0; i < t.NumMethods(); i++ {
if method := t.Method(i); method.Name() == name {
return method
}
}
}
return nil
}

View File

@@ -0,0 +1,176 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package codegen
import (
"log"
"net/netip"
"testing"
"unsafe"
"golang.org/x/exp/constraints"
)
type AnyParam[T any] struct {
V T
}
type AnyParamPhantom[T any] struct {
}
type IntegerParam[T constraints.Integer] struct {
V T
}
type FloatParam[T constraints.Float] struct {
V T
}
type StringLikeParam[T ~string] struct {
V T
}
type BasicType interface {
~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string
}
type BasicTypeParam[T BasicType] struct {
V T
}
type IntPtr *int
type IntPtrParam[T IntPtr] struct {
V T
}
type IntegerPtr interface {
*int | *int32 | *int64
}
type IntegerPtrParam[T IntegerPtr] struct {
V T
}
type IntegerParamPtr[T constraints.Integer] struct {
V *T
}
type IntegerSliceParam[T constraints.Integer] struct {
V []T
}
type IntegerMapParam[T constraints.Integer] struct {
V []T
}
type UnsafePointerParam[T unsafe.Pointer] struct {
V T
}
type ValueUnionParam[T netip.Prefix | BasicType] struct {
V T
}
type ValueUnionParamPtr[T netip.Prefix | BasicType] struct {
V *T
}
type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct {
V T
}
type Interface interface {
Method()
}
type InterfaceParam[T Interface] struct {
V T
}
func TestGenericContainsPointers(t *testing.T) {
tests := []struct {
typ string
wantPointer bool
}{
{
typ: "AnyParam",
wantPointer: true,
},
{
typ: "AnyParamPhantom",
wantPointer: false, // has a pointer type parameter, but no pointer fields
},
{
typ: "IntegerParam",
wantPointer: false,
},
{
typ: "FloatParam",
wantPointer: false,
},
{
typ: "StringLikeParam",
wantPointer: false,
},
{
typ: "BasicTypeParam",
wantPointer: false,
},
{
typ: "IntPtrParam",
wantPointer: true,
},
{
typ: "IntegerPtrParam",
wantPointer: true,
},
{
typ: "IntegerParamPtr",
wantPointer: true,
},
{
typ: "IntegerSliceParam",
wantPointer: true,
},
{
typ: "IntegerMapParam",
wantPointer: true,
},
{
typ: "UnsafePointerParam",
wantPointer: true,
},
{
typ: "InterfaceParam",
wantPointer: true,
},
{
typ: "ValueUnionParam",
wantPointer: false,
},
{
typ: "ValueUnionParamPtr",
wantPointer: true,
},
{
typ: "PointerUnionParam",
wantPointer: true,
},
}
_, namedTypes, err := LoadTypes("test", ".")
if err != nil {
log.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.typ, func(t *testing.T) {
typ := namedTypes[tt.typ]
if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer {
t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer)
}
})
}
}