mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-11 13:18:53 +00:00
cmd/cloner, cmd/viewer, util/codegen: add support for generic types and interfaces
This adds support for generic types and interfaces to our cloner and viewer codegens. It updates these packages to determine whether to make shallow or deep copies based on the type parameter constraints. Additionally, if a template parameter or an interface type has View() and Clone() methods, we'll use them for getters and the cloner of the owning structure. Updates #12736 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
@@ -27,9 +27,9 @@ var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to gen
|
||||
func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) {
|
||||
cfg := &packages.Config{
|
||||
Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
|
||||
Tests: false,
|
||||
Tests: buildTags == "test",
|
||||
}
|
||||
if buildTags != "" {
|
||||
if buildTags != "" && !cfg.Tests {
|
||||
cfg.BuildFlags = []string{"-tags=" + buildTags}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,9 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if cfg.Tests {
|
||||
pkgs = testPackages(pkgs)
|
||||
}
|
||||
if len(pkgs) != 1 {
|
||||
return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
|
||||
}
|
||||
@@ -44,6 +47,17 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]
|
||||
return pkg, namedTypes(pkg), nil
|
||||
}
|
||||
|
||||
func testPackages(pkgs []*packages.Package) []*packages.Package {
|
||||
var testPackages []*packages.Package
|
||||
for _, pkg := range pkgs {
|
||||
testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath)
|
||||
if pkg.ID == testPackageID {
|
||||
testPackages = append(testPackages, pkg)
|
||||
}
|
||||
}
|
||||
return testPackages
|
||||
}
|
||||
|
||||
// HasNoClone reports whether the provided tag has `codegen:noclone`.
|
||||
func HasNoClone(structTag string) bool {
|
||||
val := reflect.StructTag(structTag).Get("codegen")
|
||||
@@ -193,13 +207,21 @@ func namedTypes(pkg *packages.Package) map[string]*types.Named {
|
||||
// ctx is a single-word context for this assertion, such as "Clone".
|
||||
// If non-nil, AssertStructUnchanged will add elements to imports
|
||||
// for each package path that the caller must import for the returned code to compile.
|
||||
func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker) []byte {
|
||||
func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
w := func(format string, args ...any) {
|
||||
fmt.Fprintf(buf, format+"\n", args...)
|
||||
}
|
||||
w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
|
||||
w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
|
||||
|
||||
hasTypeParams := params != nil && params.Len() > 0
|
||||
if hasTypeParams {
|
||||
constraints, identifiers := FormatTypeParams(params, it)
|
||||
w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers)
|
||||
w("_%s%sNeedsRegeneration(struct {", tname, ctx)
|
||||
} else {
|
||||
w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
|
||||
}
|
||||
|
||||
for i := range t.NumFields() {
|
||||
st := t.Field(i)
|
||||
@@ -209,14 +231,25 @@ func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker
|
||||
continue
|
||||
}
|
||||
qname := it.QualifiedName(ft)
|
||||
var tag string
|
||||
if hasTypeParams {
|
||||
tag = t.Tag(i)
|
||||
if tag != "" {
|
||||
tag = "`" + tag + "`"
|
||||
}
|
||||
}
|
||||
if st.Anonymous() {
|
||||
w("\t%s ", fname)
|
||||
w("\t%s %s", fname, tag)
|
||||
} else {
|
||||
w("\t%s %s", fname, qname)
|
||||
w("\t%s %s %s", fname, qname, tag)
|
||||
}
|
||||
}
|
||||
|
||||
w("}{})\n")
|
||||
if hasTypeParams {
|
||||
w("}{})\n}")
|
||||
} else {
|
||||
w("}{})")
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
@@ -242,10 +275,21 @@ func ContainsPointers(typ types.Type) bool {
|
||||
switch ft := typ.Underlying().(type) {
|
||||
case *types.Array:
|
||||
return ContainsPointers(ft.Elem())
|
||||
case *types.Basic:
|
||||
if ft.Kind() == types.UnsafePointer {
|
||||
return true
|
||||
}
|
||||
case *types.Chan:
|
||||
return true
|
||||
case *types.Interface:
|
||||
return true // a little too broad
|
||||
if ft.Empty() || ft.IsMethodSet() {
|
||||
return true
|
||||
}
|
||||
for i := 0; i < ft.NumEmbeddeds(); i++ {
|
||||
if ContainsPointers(ft.EmbeddedType(i)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case *types.Map:
|
||||
return true
|
||||
case *types.Pointer:
|
||||
@@ -258,6 +302,12 @@ func ContainsPointers(typ types.Type) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case *types.Union:
|
||||
for i := range ft.Len() {
|
||||
if ContainsPointers(ft.Term(i).Type()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -273,3 +323,44 @@ func IsViewType(typ types.Type) bool {
|
||||
}
|
||||
return t.Field(0).Name() == "ж"
|
||||
}
|
||||
|
||||
// FormatTypeParams formats the specified params and returns two strings:
|
||||
// - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer])
|
||||
// - names are comma-separated type parameter names in square brackets (e.g. [T, V])
|
||||
//
|
||||
// If params is nil or empty, both return values are empty strings.
|
||||
func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) {
|
||||
if params == nil || params.Len() == 0 {
|
||||
return "", ""
|
||||
}
|
||||
var constraintList, nameList []string
|
||||
for i := range params.Len() {
|
||||
param := params.At(i)
|
||||
name := param.Obj().Name()
|
||||
constraint := it.QualifiedName(param.Constraint())
|
||||
nameList = append(nameList, name)
|
||||
constraintList = append(constraintList, name+" "+constraint)
|
||||
}
|
||||
constraints = "[" + strings.Join(constraintList, ", ") + "]"
|
||||
names = "[" + strings.Join(nameList, ", ") + "]"
|
||||
return constraints, names
|
||||
}
|
||||
|
||||
// LookupMethod returns the method with the specified name in t, or nil if the method does not exist.
|
||||
func LookupMethod(t types.Type, name string) *types.Func {
|
||||
if t, ok := t.(*types.Named); ok {
|
||||
for i := 0; i < t.NumMethods(); i++ {
|
||||
if method := t.Method(i); method.Name() == name {
|
||||
return method
|
||||
}
|
||||
}
|
||||
}
|
||||
if t, ok := t.Underlying().(*types.Interface); ok {
|
||||
for i := 0; i < t.NumMethods(); i++ {
|
||||
if method := t.Method(i); method.Name() == name {
|
||||
return method
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
176
util/codegen/codegen_test.go
Normal file
176
util/codegen/codegen_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package codegen
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type AnyParam[T any] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type AnyParamPhantom[T any] struct {
|
||||
}
|
||||
|
||||
type IntegerParam[T constraints.Integer] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type FloatParam[T constraints.Float] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type StringLikeParam[T ~string] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type BasicType interface {
|
||||
~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string
|
||||
}
|
||||
|
||||
type BasicTypeParam[T BasicType] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type IntPtr *int
|
||||
|
||||
type IntPtrParam[T IntPtr] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type IntegerPtr interface {
|
||||
*int | *int32 | *int64
|
||||
}
|
||||
|
||||
type IntegerPtrParam[T IntegerPtr] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type IntegerParamPtr[T constraints.Integer] struct {
|
||||
V *T
|
||||
}
|
||||
|
||||
type IntegerSliceParam[T constraints.Integer] struct {
|
||||
V []T
|
||||
}
|
||||
|
||||
type IntegerMapParam[T constraints.Integer] struct {
|
||||
V []T
|
||||
}
|
||||
|
||||
type UnsafePointerParam[T unsafe.Pointer] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type ValueUnionParam[T netip.Prefix | BasicType] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type ValueUnionParamPtr[T netip.Prefix | BasicType] struct {
|
||||
V *T
|
||||
}
|
||||
|
||||
type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
type Interface interface {
|
||||
Method()
|
||||
}
|
||||
|
||||
type InterfaceParam[T Interface] struct {
|
||||
V T
|
||||
}
|
||||
|
||||
func TestGenericContainsPointers(t *testing.T) {
|
||||
tests := []struct {
|
||||
typ string
|
||||
wantPointer bool
|
||||
}{
|
||||
{
|
||||
typ: "AnyParam",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "AnyParamPhantom",
|
||||
wantPointer: false, // has a pointer type parameter, but no pointer fields
|
||||
},
|
||||
{
|
||||
typ: "IntegerParam",
|
||||
wantPointer: false,
|
||||
},
|
||||
{
|
||||
typ: "FloatParam",
|
||||
wantPointer: false,
|
||||
},
|
||||
{
|
||||
typ: "StringLikeParam",
|
||||
wantPointer: false,
|
||||
},
|
||||
{
|
||||
typ: "BasicTypeParam",
|
||||
wantPointer: false,
|
||||
},
|
||||
{
|
||||
typ: "IntPtrParam",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "IntegerPtrParam",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "IntegerParamPtr",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "IntegerSliceParam",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "IntegerMapParam",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "UnsafePointerParam",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "InterfaceParam",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "ValueUnionParam",
|
||||
wantPointer: false,
|
||||
},
|
||||
{
|
||||
typ: "ValueUnionParamPtr",
|
||||
wantPointer: true,
|
||||
},
|
||||
{
|
||||
typ: "PointerUnionParam",
|
||||
wantPointer: true,
|
||||
},
|
||||
}
|
||||
|
||||
_, namedTypes, err := LoadTypes("test", ".")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.typ, func(t *testing.T) {
|
||||
typ := namedTypes[tt.typ]
|
||||
if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer {
|
||||
t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user