diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index db354af3b..23f3e219c 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -47,7 +47,7 @@ func main() { it := codegen.NewImportTracker(pkg.Types) buf := new(bytes.Buffer) for _, typeName := range typeNames { - typ, ok := namedTypes[typeName] + typ, ok := namedTypes[typeName].(*types.Named) if !ok { log.Fatalf("could not find type %s", typeName) } diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 1f1ec0557..8f5dc23af 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -13,7 +13,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers --clone-only-type=OnlyGetClone +//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct --clone-only-type=OnlyGetClone type StructWithoutPtrs struct { Int int @@ -202,3 +202,20 @@ type StructWithContainers struct { CloneableMap MapContainer[int, *StructWithPtrs] CloneableGenericMap MapContainer[int, *GenericNoPtrsStruct[int]] } + +type ( + StructWithPtrsAlias = StructWithPtrs + StructWithoutPtrsAlias = StructWithoutPtrs +) + +type StructWithTypeAliasFields struct { + WithPtr StructWithPtrsAlias + WithoutPtr StructWithoutPtrsAlias +} + +type integer = constraints.Integer + +type GenericTypeAliasStruct[T integer, T2 views.ViewCloner[T2, V2], V2 views.StructView[T2]] struct { + NonCloneable T + Cloneable T2 +} diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 53e6bacfb..542512787 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -441,3 +441,41 @@ var _StructWithContainersCloneNeedsRegeneration = StructWithContainers(struct { CloneableMap MapContainer[int, *StructWithPtrs] CloneableGenericMap MapContainer[int, *GenericNoPtrsStruct[int]] }{}) + +// Clone makes a deep copy of StructWithTypeAliasFields. +// The result aliases no memory with the original. +func (src *StructWithTypeAliasFields) Clone() *StructWithTypeAliasFields { + if src == nil { + return nil + } + dst := new(StructWithTypeAliasFields) + *dst = *src + panic("TODO: WithPtr (*types.Struct)") + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithTypeAliasFieldsCloneNeedsRegeneration = StructWithTypeAliasFields(struct { + WithPtr StructWithPtrsAlias + WithoutPtr StructWithoutPtrsAlias +}{}) + +// Clone makes a deep copy of GenericTypeAliasStruct. +// The result aliases no memory with the original. +func (src *GenericTypeAliasStruct[T, T2, V2]) Clone() *GenericTypeAliasStruct[T, T2, V2] { + if src == nil { + return nil + } + dst := new(GenericTypeAliasStruct[T, T2, V2]) + *dst = *src + dst.Cloneable = src.Cloneable.Clone() + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericTypeAliasStructCloneNeedsRegeneration[T integer, T2 views.ViewCloner[T2, V2], V2 views.StructView[T2]](GenericTypeAliasStruct[T, T2, V2]) { + _GenericTypeAliasStructCloneNeedsRegeneration(struct { + NonCloneable T + Cloneable T2 + }{}) +} diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index cf07dc663..f0cbf1564 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -14,7 +14,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct // View returns a readonly view of StructWithPtrs. func (p *StructWithPtrs) View() StructWithPtrsView { @@ -676,3 +676,115 @@ var _StructWithContainersViewNeedsRegeneration = StructWithContainers(struct { CloneableMap MapContainer[int, *StructWithPtrs] CloneableGenericMap MapContainer[int, *GenericNoPtrsStruct[int]] }{}) + +// View returns a readonly view of StructWithTypeAliasFields. +func (p *StructWithTypeAliasFields) View() StructWithTypeAliasFieldsView { + return StructWithTypeAliasFieldsView{ж: p} +} + +// StructWithTypeAliasFieldsView provides a read-only view over StructWithTypeAliasFields. +// +// Its methods should only be called if `Valid()` returns true. +type StructWithTypeAliasFieldsView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *StructWithTypeAliasFields +} + +// Valid reports whether underlying value is non-nil. +func (v StructWithTypeAliasFieldsView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v StructWithTypeAliasFieldsView) AsStruct() *StructWithTypeAliasFields { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v StructWithTypeAliasFieldsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } + +func (v *StructWithTypeAliasFieldsView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x StructWithTypeAliasFields + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsView { return v.ж.WithPtr.View() } +func (v StructWithTypeAliasFieldsView) WithoutPtr() StructWithoutPtrsAlias { return v.ж.WithoutPtr } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithTypeAliasFieldsViewNeedsRegeneration = StructWithTypeAliasFields(struct { + WithPtr StructWithPtrsAlias + WithoutPtr StructWithoutPtrsAlias +}{}) + +// View returns a readonly view of GenericTypeAliasStruct. +func (p *GenericTypeAliasStruct[T, T2, V2]) View() GenericTypeAliasStructView[T, T2, V2] { + return GenericTypeAliasStructView[T, T2, V2]{ж: p} +} + +// GenericTypeAliasStructView[T, T2, V2] provides a read-only view over GenericTypeAliasStruct[T, T2, V2]. +// +// Its methods should only be called if `Valid()` returns true. +type GenericTypeAliasStructView[T integer, T2 views.ViewCloner[T2, V2], V2 views.StructView[T2]] struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *GenericTypeAliasStruct[T, T2, V2] +} + +// Valid reports whether underlying value is non-nil. +func (v GenericTypeAliasStructView[T, T2, V2]) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v GenericTypeAliasStructView[T, T2, V2]) AsStruct() *GenericTypeAliasStruct[T, T2, V2] { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +func (v GenericTypeAliasStructView[T, T2, V2]) MarshalJSON() ([]byte, error) { + return json.Marshal(v.ж) +} + +func (v *GenericTypeAliasStructView[T, T2, V2]) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x GenericTypeAliasStruct[T, T2, V2] + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v GenericTypeAliasStructView[T, T2, V2]) NonCloneable() T { return v.ж.NonCloneable } +func (v GenericTypeAliasStructView[T, T2, V2]) Cloneable() V2 { return v.ж.Cloneable.View() } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +func _GenericTypeAliasStructViewNeedsRegeneration[T integer, T2 views.ViewCloner[T2, V2], V2 views.StructView[T2]](GenericTypeAliasStruct[T, T2, V2]) { + _GenericTypeAliasStructViewNeedsRegeneration(struct { + NonCloneable T + Cloneable T2 + }{}) +} diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 2e122a128..e4e56163d 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -567,7 +567,7 @@ func main() { if cloneOnlyType[typeName] { continue } - typ, ok := namedTypes[typeName] + typ, ok := namedTypes[typeName].(*types.Named) if !ok { log.Fatalf("could not find type %s", typeName) } diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 3ef4b9cc1..4e2c86909 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -24,7 +24,7 @@ import ( var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to generated file headers") // LoadTypes returns all named types in pkgName, keyed by their type name. -func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) { +func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]types.Type, error) { cfg := &packages.Config{ Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName, Tests: buildTags == "test", @@ -181,8 +181,8 @@ func writeFormatted(code []byte, path string) error { } // namedTypes returns all named types in pkg, keyed by their type name. -func namedTypes(pkg *packages.Package) map[string]*types.Named { - nt := make(map[string]*types.Named) +func namedTypes(pkg *packages.Package) map[string]types.Type { + nt := make(map[string]types.Type) for _, file := range pkg.Syntax { for _, d := range file.Decls { decl, ok := d.(*ast.GenDecl) @@ -198,11 +198,10 @@ func namedTypes(pkg *packages.Package) map[string]*types.Named { if !ok { continue } - typ, ok := typeNameObj.Type().(*types.Named) - if !ok { - continue + switch typ := typeNameObj.Type(); typ.(type) { + case *types.Alias, *types.Named: + nt[spec.Name.Name] = typ } - nt[spec.Name.Name] = typ } } } @@ -356,14 +355,25 @@ func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constrain // LookupMethod returns the method with the specified name in t, or nil if the method does not exist. func LookupMethod(t types.Type, name string) *types.Func { - if t, ok := t.(*types.Named); ok { - for i := 0; i < t.NumMethods(); i++ { - if method := t.Method(i); method.Name() == name { - return method + switch t := t.(type) { + case *types.Alias: + return LookupMethod(t.Rhs(), name) + case *types.TypeParam: + return LookupMethod(t.Constraint(), name) + case *types.Pointer: + return LookupMethod(t.Elem(), name) + case *types.Named: + switch u := t.Underlying().(type) { + case *types.Interface: + return LookupMethod(u, name) + default: + for i := 0; i < t.NumMethods(); i++ { + if method := t.Method(i); method.Name() == name { + return method + } } } - } - if t, ok := t.Underlying().(*types.Interface); ok { + case *types.Interface: for i := 0; i < t.NumMethods(); i++ { if method := t.Method(i); method.Name() == name { return method diff --git a/util/codegen/codegen_test.go b/util/codegen/codegen_test.go index 9c61da51d..28ddaed2b 100644 --- a/util/codegen/codegen_test.go +++ b/util/codegen/codegen_test.go @@ -4,10 +4,11 @@ package codegen import ( + "cmp" "go/types" - "log" "net/netip" "strings" + "sync" "testing" "unsafe" @@ -162,14 +163,9 @@ func TestGenericContainsPointers(t *testing.T) { }, } - _, namedTypes, err := LoadTypes("test", ".") - if err != nil { - log.Fatal(err) - } - for _, tt := range tests { t.Run(tt.typ, func(t *testing.T) { - typ := namedTypes[tt.typ] + typ := lookupTestType(t, tt.typ) if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer { t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer) } @@ -252,3 +248,199 @@ func TestAssertStructUnchanged(t *testing.T) { }) } } + +type NamedType struct{} + +func (NamedType) Method() {} + +type NamedTypeAlias = NamedType + +type NamedInterface interface { + Method() +} + +type NamedInterfaceAlias = NamedInterface + +type GenericType[T NamedInterface] struct { + TypeParamField T + TypeParamPtrField *T +} + +type GenericTypeWithAliasConstraint[T NamedInterfaceAlias] struct { + TypeParamField T + TypeParamPtrField *T +} + +func TestLookupMethod(t *testing.T) { + tests := []struct { + name string + typ types.Type + methodName string + wantHasMethod bool + wantReceiver types.Type + }{ + { + name: "NamedType/HasMethod", + typ: lookupTestType(t, "NamedType"), + methodName: "Method", + wantHasMethod: true, + }, + { + name: "NamedType/NoMethod", + typ: lookupTestType(t, "NamedType"), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "NamedTypeAlias/HasMethod", + typ: lookupTestType(t, "NamedTypeAlias"), + methodName: "Method", + wantHasMethod: true, + wantReceiver: lookupTestType(t, "NamedType"), + }, + { + name: "NamedTypeAlias/NoMethod", + typ: lookupTestType(t, "NamedTypeAlias"), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "PtrToNamedType/HasMethod", + typ: types.NewPointer(lookupTestType(t, "NamedType")), + methodName: "Method", + wantHasMethod: true, + wantReceiver: lookupTestType(t, "NamedType"), + }, + { + name: "PtrToNamedType/NoMethod", + typ: types.NewPointer(lookupTestType(t, "NamedType")), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "PtrToNamedTypeAlias/HasMethod", + typ: types.NewPointer(lookupTestType(t, "NamedTypeAlias")), + methodName: "Method", + wantHasMethod: true, + wantReceiver: lookupTestType(t, "NamedType"), + }, + { + name: "PtrToNamedTypeAlias/NoMethod", + typ: types.NewPointer(lookupTestType(t, "NamedTypeAlias")), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "NamedInterface/HasMethod", + typ: lookupTestType(t, "NamedInterface"), + methodName: "Method", + wantHasMethod: true, + }, + { + name: "NamedInterface/NoMethod", + typ: lookupTestType(t, "NamedInterface"), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "Interface/HasMethod", + typ: types.NewInterfaceType([]*types.Func{types.NewFunc(0, nil, "Method", types.NewSignatureType(nil, nil, nil, nil, nil, false))}, nil), + methodName: "Method", + wantHasMethod: true, + }, + { + name: "Interface/NoMethod", + typ: types.NewInterfaceType(nil, nil), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "TypeParam/HasMethod", + typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(0).Type(), + methodName: "Method", + wantHasMethod: true, + wantReceiver: lookupTestType(t, "NamedInterface"), + }, + { + name: "TypeParam/NoMethod", + typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(0).Type(), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "TypeParamPtr/HasMethod", + typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(1).Type(), + methodName: "Method", + wantHasMethod: true, + wantReceiver: lookupTestType(t, "NamedInterface"), + }, + { + name: "TypeParamPtr/NoMethod", + typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(1).Type(), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "TypeParamWithAlias/HasMethod", + typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(0).Type(), + methodName: "Method", + wantHasMethod: true, + wantReceiver: lookupTestType(t, "NamedInterface"), + }, + { + name: "TypeParamWithAlias/NoMethod", + typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(0).Type(), + methodName: "NoMethod", + wantHasMethod: false, + }, + { + name: "TypeParamWithAliasPtr/HasMethod", + typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(1).Type(), + methodName: "Method", + wantHasMethod: true, + wantReceiver: lookupTestType(t, "NamedInterface"), + }, + { + name: "TypeParamWithAliasPtr/NoMethod", + typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(1).Type(), + methodName: "NoMethod", + wantHasMethod: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotMethod := LookupMethod(tt.typ, tt.methodName) + if gotHasMethod := gotMethod != nil; gotHasMethod != tt.wantHasMethod { + t.Fatalf("HasMethod: got %v; want %v", gotMethod, tt.wantHasMethod) + } + if gotMethod == nil { + return + } + if gotMethod.Name() != tt.methodName { + t.Errorf("Name: got %v; want %v", gotMethod.Name(), tt.methodName) + } + if gotRecv, wantRecv := gotMethod.Signature().Recv().Type(), cmp.Or(tt.wantReceiver, tt.typ); !types.Identical(gotRecv, wantRecv) { + t.Errorf("Recv: got %v; want %v", gotRecv, wantRecv) + } + }) + } +} + +var namedTestTypes = sync.OnceValues(func() (map[string]types.Type, error) { + _, namedTypes, err := LoadTypes("test", ".") + return namedTypes, err +}) + +func lookupTestType(t *testing.T, name string) types.Type { + t.Helper() + types, err := namedTestTypes() + if err != nil { + t.Fatal(err) + } + typ, ok := types[name] + if !ok { + t.Fatalf("type %q is not declared in the current package", name) + } + return typ +}