util/codegen, cmd/cloner, cmd/viewer: update codegen.LookupMethod to support alias type nodes

Go 1.23 updates the go/types package to produce Alias type nodes for type aliases, unless disabled with gotypesalias=0.
This new default behavior breaks codegen.LookupMethod, which uses checked type assertions to types.Named and
types.Interface, as only named types and interfaces have methods.

In this PR, we update codegen.LookupMethod to perform method lookup on the right-hand side of the alias declaration
and clearly switch on the supported type nodes types. We also improve support for various edge cases, such as when an alias
is used as a type parameter constraint, and add tests for the LookupMethod function.

Additionally, we update cmd/viewer/tests to include types with aliases used in type fields and generic type constraints.

Updates #13224
Updates #12912

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl
2024-08-22 16:33:35 -05:00
committed by Nick Khyl
parent aa42ae9058
commit a9dc6e07ad
7 changed files with 393 additions and 24 deletions

View File

@@ -24,7 +24,7 @@ import (
var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to generated file headers")
// LoadTypes returns all named types in pkgName, keyed by their type name.
func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) {
func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]types.Type, error) {
cfg := &packages.Config{
Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
Tests: buildTags == "test",
@@ -181,8 +181,8 @@ func writeFormatted(code []byte, path string) error {
}
// namedTypes returns all named types in pkg, keyed by their type name.
func namedTypes(pkg *packages.Package) map[string]*types.Named {
nt := make(map[string]*types.Named)
func namedTypes(pkg *packages.Package) map[string]types.Type {
nt := make(map[string]types.Type)
for _, file := range pkg.Syntax {
for _, d := range file.Decls {
decl, ok := d.(*ast.GenDecl)
@@ -198,11 +198,10 @@ func namedTypes(pkg *packages.Package) map[string]*types.Named {
if !ok {
continue
}
typ, ok := typeNameObj.Type().(*types.Named)
if !ok {
continue
switch typ := typeNameObj.Type(); typ.(type) {
case *types.Alias, *types.Named:
nt[spec.Name.Name] = typ
}
nt[spec.Name.Name] = typ
}
}
}
@@ -356,14 +355,25 @@ func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constrain
// LookupMethod returns the method with the specified name in t, or nil if the method does not exist.
func LookupMethod(t types.Type, name string) *types.Func {
if t, ok := t.(*types.Named); ok {
for i := 0; i < t.NumMethods(); i++ {
if method := t.Method(i); method.Name() == name {
return method
switch t := t.(type) {
case *types.Alias:
return LookupMethod(t.Rhs(), name)
case *types.TypeParam:
return LookupMethod(t.Constraint(), name)
case *types.Pointer:
return LookupMethod(t.Elem(), name)
case *types.Named:
switch u := t.Underlying().(type) {
case *types.Interface:
return LookupMethod(u, name)
default:
for i := 0; i < t.NumMethods(); i++ {
if method := t.Method(i); method.Name() == name {
return method
}
}
}
}
if t, ok := t.Underlying().(*types.Interface); ok {
case *types.Interface:
for i := 0; i < t.NumMethods(); i++ {
if method := t.Method(i); method.Name() == name {
return method

View File

@@ -4,10 +4,11 @@
package codegen
import (
"cmp"
"go/types"
"log"
"net/netip"
"strings"
"sync"
"testing"
"unsafe"
@@ -162,14 +163,9 @@ func TestGenericContainsPointers(t *testing.T) {
},
}
_, namedTypes, err := LoadTypes("test", ".")
if err != nil {
log.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.typ, func(t *testing.T) {
typ := namedTypes[tt.typ]
typ := lookupTestType(t, tt.typ)
if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer {
t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer)
}
@@ -252,3 +248,199 @@ func TestAssertStructUnchanged(t *testing.T) {
})
}
}
type NamedType struct{}
func (NamedType) Method() {}
type NamedTypeAlias = NamedType
type NamedInterface interface {
Method()
}
type NamedInterfaceAlias = NamedInterface
type GenericType[T NamedInterface] struct {
TypeParamField T
TypeParamPtrField *T
}
type GenericTypeWithAliasConstraint[T NamedInterfaceAlias] struct {
TypeParamField T
TypeParamPtrField *T
}
func TestLookupMethod(t *testing.T) {
tests := []struct {
name string
typ types.Type
methodName string
wantHasMethod bool
wantReceiver types.Type
}{
{
name: "NamedType/HasMethod",
typ: lookupTestType(t, "NamedType"),
methodName: "Method",
wantHasMethod: true,
},
{
name: "NamedType/NoMethod",
typ: lookupTestType(t, "NamedType"),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "NamedTypeAlias/HasMethod",
typ: lookupTestType(t, "NamedTypeAlias"),
methodName: "Method",
wantHasMethod: true,
wantReceiver: lookupTestType(t, "NamedType"),
},
{
name: "NamedTypeAlias/NoMethod",
typ: lookupTestType(t, "NamedTypeAlias"),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "PtrToNamedType/HasMethod",
typ: types.NewPointer(lookupTestType(t, "NamedType")),
methodName: "Method",
wantHasMethod: true,
wantReceiver: lookupTestType(t, "NamedType"),
},
{
name: "PtrToNamedType/NoMethod",
typ: types.NewPointer(lookupTestType(t, "NamedType")),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "PtrToNamedTypeAlias/HasMethod",
typ: types.NewPointer(lookupTestType(t, "NamedTypeAlias")),
methodName: "Method",
wantHasMethod: true,
wantReceiver: lookupTestType(t, "NamedType"),
},
{
name: "PtrToNamedTypeAlias/NoMethod",
typ: types.NewPointer(lookupTestType(t, "NamedTypeAlias")),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "NamedInterface/HasMethod",
typ: lookupTestType(t, "NamedInterface"),
methodName: "Method",
wantHasMethod: true,
},
{
name: "NamedInterface/NoMethod",
typ: lookupTestType(t, "NamedInterface"),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "Interface/HasMethod",
typ: types.NewInterfaceType([]*types.Func{types.NewFunc(0, nil, "Method", types.NewSignatureType(nil, nil, nil, nil, nil, false))}, nil),
methodName: "Method",
wantHasMethod: true,
},
{
name: "Interface/NoMethod",
typ: types.NewInterfaceType(nil, nil),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "TypeParam/HasMethod",
typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(0).Type(),
methodName: "Method",
wantHasMethod: true,
wantReceiver: lookupTestType(t, "NamedInterface"),
},
{
name: "TypeParam/NoMethod",
typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(0).Type(),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "TypeParamPtr/HasMethod",
typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(1).Type(),
methodName: "Method",
wantHasMethod: true,
wantReceiver: lookupTestType(t, "NamedInterface"),
},
{
name: "TypeParamPtr/NoMethod",
typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(1).Type(),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "TypeParamWithAlias/HasMethod",
typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(0).Type(),
methodName: "Method",
wantHasMethod: true,
wantReceiver: lookupTestType(t, "NamedInterface"),
},
{
name: "TypeParamWithAlias/NoMethod",
typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(0).Type(),
methodName: "NoMethod",
wantHasMethod: false,
},
{
name: "TypeParamWithAliasPtr/HasMethod",
typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(1).Type(),
methodName: "Method",
wantHasMethod: true,
wantReceiver: lookupTestType(t, "NamedInterface"),
},
{
name: "TypeParamWithAliasPtr/NoMethod",
typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(1).Type(),
methodName: "NoMethod",
wantHasMethod: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotMethod := LookupMethod(tt.typ, tt.methodName)
if gotHasMethod := gotMethod != nil; gotHasMethod != tt.wantHasMethod {
t.Fatalf("HasMethod: got %v; want %v", gotMethod, tt.wantHasMethod)
}
if gotMethod == nil {
return
}
if gotMethod.Name() != tt.methodName {
t.Errorf("Name: got %v; want %v", gotMethod.Name(), tt.methodName)
}
if gotRecv, wantRecv := gotMethod.Signature().Recv().Type(), cmp.Or(tt.wantReceiver, tt.typ); !types.Identical(gotRecv, wantRecv) {
t.Errorf("Recv: got %v; want %v", gotRecv, wantRecv)
}
})
}
}
var namedTestTypes = sync.OnceValues(func() (map[string]types.Type, error) {
_, namedTypes, err := LoadTypes("test", ".")
return namedTypes, err
})
func lookupTestType(t *testing.T, name string) types.Type {
t.Helper()
types, err := namedTestTypes()
if err != nil {
t.Fatal(err)
}
typ, ok := types[name]
if !ok {
t.Fatalf("type %q is not declared in the current package", name)
}
return typ
}