From 03acab2639c03c83a5507f95ef0128d893fc405e Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Thu, 22 Aug 2024 17:59:45 -0500 Subject: [PATCH] cmd/cloner, cmd/viewer, util/codegen: add support for aliases of cloneable types We have several checked type assertions to *types.Named in both cmd/cloner and cmd/viewer. As Go 1.23 updates the go/types package to produce Alias type nodes for type aliases, these type assertions no longer work as expected unless the new behavior is disabled with gotypesalias=0. In this PR, we add codegen.NamedTypeOf(t types.Type), which functions like t.(*types.Named) but also unrolls type aliases. We then use it in place of type assertions in the cmd/cloner and cmd/viewer packages where appropriate. We also update type switches to include *types.Alias alongside *types.Named in relevant cases, remove *types.Struct cases when switching on types.Type.Underlying and update the tests with more cases where type aliases can be used. Updates #13224 Updates #12912 Signed-off-by: Nick Khyl --- cmd/cloner/cloner.go | 4 +- cmd/viewer/tests/tests.go | 18 ++++++++- cmd/viewer/tests/tests_clone.go | 70 +++++++++++++++++++++++++++++++-- cmd/viewer/tests/tests_view.go | 53 ++++++++++++++++++++++++- cmd/viewer/viewer.go | 18 ++++----- util/codegen/codegen.go | 9 +++++ 6 files changed, 154 insertions(+), 18 deletions(-) diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index 23f3e219c..a1ffc30fe 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -115,7 +115,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { if !codegen.ContainsPointers(ft) || codegen.HasNoClone(t.Tag(i)) { continue } - if named, _ := ft.(*types.Named); named != nil { + if named, _ := codegen.NamedTypeOf(ft); named != nil { if codegen.IsViewType(ft) { writef("dst.%s = src.%s", fname, fname) continue @@ -161,7 +161,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { case *types.Pointer: base := ft.Elem() hasPtrs := codegen.ContainsPointers(base) - if named, _ := base.(*types.Named); named != nil && hasPtrs { + if named, _ := codegen.NamedTypeOf(base); named != nil && hasPtrs { writef("dst.%s = src.%s.Clone()", fname, fname) continue } diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 8f5dc23af..14a488861 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -204,13 +204,27 @@ type StructWithContainers struct { } type ( - StructWithPtrsAlias = StructWithPtrs - StructWithoutPtrsAlias = StructWithoutPtrs + StructWithPtrsAlias = StructWithPtrs + StructWithoutPtrsAlias = StructWithoutPtrs + StructWithPtrsAliasView = StructWithPtrsView + StructWithoutPtrsAliasView = StructWithoutPtrsView ) type StructWithTypeAliasFields struct { WithPtr StructWithPtrsAlias WithoutPtr StructWithoutPtrsAlias + + WithPtrByPtr *StructWithPtrsAlias + WithoutPtrByPtr *StructWithoutPtrsAlias + + SliceWithPtrs []*StructWithPtrsAlias + SliceWithoutPtrs []*StructWithoutPtrsAlias + + MapWithPtrs map[string]*StructWithPtrsAlias + MapWithoutPtrs map[string]*StructWithoutPtrsAlias + + MapOfSlicesWithPtrs map[string][]*StructWithPtrsAlias + MapOfSlicesWithoutPtrs map[string][]*StructWithoutPtrsAlias } type integer = constraints.Integer diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 542512787..9131f5040 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -450,14 +450,78 @@ func (src *StructWithTypeAliasFields) Clone() *StructWithTypeAliasFields { } dst := new(StructWithTypeAliasFields) *dst = *src - panic("TODO: WithPtr (*types.Struct)") + dst.WithPtr = *src.WithPtr.Clone() + dst.WithPtrByPtr = src.WithPtrByPtr.Clone() + if dst.WithoutPtrByPtr != nil { + dst.WithoutPtrByPtr = ptr.To(*src.WithoutPtrByPtr) + } + if src.SliceWithPtrs != nil { + dst.SliceWithPtrs = make([]*StructWithPtrsAlias, len(src.SliceWithPtrs)) + for i := range dst.SliceWithPtrs { + if src.SliceWithPtrs[i] == nil { + dst.SliceWithPtrs[i] = nil + } else { + dst.SliceWithPtrs[i] = src.SliceWithPtrs[i].Clone() + } + } + } + if src.SliceWithoutPtrs != nil { + dst.SliceWithoutPtrs = make([]*StructWithoutPtrsAlias, len(src.SliceWithoutPtrs)) + for i := range dst.SliceWithoutPtrs { + if src.SliceWithoutPtrs[i] == nil { + dst.SliceWithoutPtrs[i] = nil + } else { + dst.SliceWithoutPtrs[i] = ptr.To(*src.SliceWithoutPtrs[i]) + } + } + } + if dst.MapWithPtrs != nil { + dst.MapWithPtrs = map[string]*StructWithPtrsAlias{} + for k, v := range src.MapWithPtrs { + if v == nil { + dst.MapWithPtrs[k] = nil + } else { + dst.MapWithPtrs[k] = v.Clone() + } + } + } + if dst.MapWithoutPtrs != nil { + dst.MapWithoutPtrs = map[string]*StructWithoutPtrsAlias{} + for k, v := range src.MapWithoutPtrs { + if v == nil { + dst.MapWithoutPtrs[k] = nil + } else { + dst.MapWithoutPtrs[k] = ptr.To(*v) + } + } + } + if dst.MapOfSlicesWithPtrs != nil { + dst.MapOfSlicesWithPtrs = map[string][]*StructWithPtrsAlias{} + for k := range src.MapOfSlicesWithPtrs { + dst.MapOfSlicesWithPtrs[k] = append([]*StructWithPtrsAlias{}, src.MapOfSlicesWithPtrs[k]...) + } + } + if dst.MapOfSlicesWithoutPtrs != nil { + dst.MapOfSlicesWithoutPtrs = map[string][]*StructWithoutPtrsAlias{} + for k := range src.MapOfSlicesWithoutPtrs { + dst.MapOfSlicesWithoutPtrs[k] = append([]*StructWithoutPtrsAlias{}, src.MapOfSlicesWithoutPtrs[k]...) + } + } return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _StructWithTypeAliasFieldsCloneNeedsRegeneration = StructWithTypeAliasFields(struct { - WithPtr StructWithPtrsAlias - WithoutPtr StructWithoutPtrsAlias + WithPtr StructWithPtrsAlias + WithoutPtr StructWithoutPtrsAlias + WithPtrByPtr *StructWithPtrsAlias + WithoutPtrByPtr *StructWithoutPtrsAlias + SliceWithPtrs []*StructWithPtrsAlias + SliceWithoutPtrs []*StructWithoutPtrsAlias + MapWithPtrs map[string]*StructWithPtrsAlias + MapWithoutPtrs map[string]*StructWithoutPtrsAlias + MapOfSlicesWithPtrs map[string][]*StructWithPtrsAlias + MapOfSlicesWithoutPtrs map[string][]*StructWithoutPtrsAlias }{}) // Clone makes a deep copy of GenericTypeAliasStruct. diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index f0cbf1564..9c74c9426 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -724,11 +724,60 @@ func (v *StructWithTypeAliasFieldsView) UnmarshalJSON(b []byte) error { func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsView { return v.ж.WithPtr.View() } func (v StructWithTypeAliasFieldsView) WithoutPtr() StructWithoutPtrsAlias { return v.ж.WithoutPtr } +func (v StructWithTypeAliasFieldsView) WithPtrByPtr() StructWithPtrsAliasView { + return v.ж.WithPtrByPtr.View() +} +func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() *StructWithoutPtrsAlias { + if v.ж.WithoutPtrByPtr == nil { + return nil + } + x := *v.ж.WithoutPtrByPtr + return &x +} + +func (v StructWithTypeAliasFieldsView) SliceWithPtrs() views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] { + return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](v.ж.SliceWithPtrs) +} +func (v StructWithTypeAliasFieldsView) SliceWithoutPtrs() views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView] { + return views.SliceOfViews[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView](v.ж.SliceWithoutPtrs) +} + +func (v StructWithTypeAliasFieldsView) MapWithPtrs() views.MapFn[string, *StructWithPtrsAlias, StructWithPtrsAliasView] { + return views.MapFnOf(v.ж.MapWithPtrs, func(t *StructWithPtrsAlias) StructWithPtrsAliasView { + return t.View() + }) +} + +func (v StructWithTypeAliasFieldsView) MapWithoutPtrs() views.MapFn[string, *StructWithoutPtrsAlias, StructWithoutPtrsAliasView] { + return views.MapFnOf(v.ж.MapWithoutPtrs, func(t *StructWithoutPtrsAlias) StructWithoutPtrsAliasView { + return t.View() + }) +} + +func (v StructWithTypeAliasFieldsView) MapOfSlicesWithPtrs() views.MapFn[string, []*StructWithPtrsAlias, views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView]] { + return views.MapFnOf(v.ж.MapOfSlicesWithPtrs, func(t []*StructWithPtrsAlias) views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] { + return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](t) + }) +} + +func (v StructWithTypeAliasFieldsView) MapOfSlicesWithoutPtrs() views.MapFn[string, []*StructWithoutPtrsAlias, views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView]] { + return views.MapFnOf(v.ж.MapOfSlicesWithoutPtrs, func(t []*StructWithoutPtrsAlias) views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView] { + return views.SliceOfViews[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView](t) + }) +} // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _StructWithTypeAliasFieldsViewNeedsRegeneration = StructWithTypeAliasFields(struct { - WithPtr StructWithPtrsAlias - WithoutPtr StructWithoutPtrsAlias + WithPtr StructWithPtrsAlias + WithoutPtr StructWithoutPtrsAlias + WithPtrByPtr *StructWithPtrsAlias + WithoutPtrByPtr *StructWithoutPtrsAlias + SliceWithPtrs []*StructWithPtrsAlias + SliceWithoutPtrs []*StructWithoutPtrsAlias + MapWithPtrs map[string]*StructWithPtrsAlias + MapWithoutPtrs map[string]*StructWithoutPtrsAlias + MapOfSlicesWithPtrs map[string][]*StructWithPtrsAlias + MapOfSlicesWithoutPtrs map[string][]*StructWithoutPtrsAlias }{}) // View returns a readonly view of GenericTypeAliasStruct. diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index e4e56163d..96223297b 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -230,7 +230,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi writeTemplate("sliceField") } continue - case *types.Struct, *types.Named: + case *types.Struct: strucT := underlying args.FieldType = it.QualifiedName(fieldType) if codegen.ContainsPointers(strucT) { @@ -262,7 +262,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi mElem := m.Elem() var template string switch u := mElem.(type) { - case *types.Struct, *types.Named: + case *types.Struct, *types.Named, *types.Alias: strucT := u args.FieldType = it.QualifiedName(fieldType) if codegen.ContainsPointers(strucT) { @@ -281,7 +281,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi slice := u sElem := slice.Elem() switch x := sElem.(type) { - case *types.Basic, *types.Named: + case *types.Basic, *types.Named, *types.Alias: sElem := it.QualifiedName(sElem) args.MapValueView = fmt.Sprintf("views.Slice[%v]", sElem) args.MapValueType = sElem @@ -292,7 +292,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi template = "unsupportedField" if _, isIface := pElem.Underlying().(*types.Interface); !isIface { switch pElem.(type) { - case *types.Struct, *types.Named: + case *types.Struct, *types.Named, *types.Alias: ptrType := it.QualifiedName(ptr) viewType := appendNameSuffix(it.QualifiedName(pElem), "View") args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType) @@ -313,7 +313,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi pElem := ptr.Elem() if _, isIface := pElem.Underlying().(*types.Interface); !isIface { switch pElem.(type) { - case *types.Struct, *types.Named: + case *types.Struct, *types.Named, *types.Alias: args.MapValueType = it.QualifiedName(ptr) args.MapValueView = appendNameSuffix(it.QualifiedName(pElem), "View") args.MapFn = "t.View()" @@ -422,7 +422,7 @@ func viewTypeForValueType(typ types.Type) types.Type { func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { // The container type should be an instantiated generic type, // with its first type parameter specifying the element type. - containerType, ok := typ.(*types.Named) + containerType, ok := codegen.NamedTypeOf(typ) if !ok || containerType.TypeArgs().Len() == 0 { return nil, nil } @@ -435,7 +435,7 @@ func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { if !ok { return nil, nil } - containerViewGenericType, ok := containerViewTypeObj.Type().(*types.Named) + containerViewGenericType, ok := codegen.NamedTypeOf(containerViewTypeObj.Type()) if !ok || containerViewGenericType.TypeParams().Len() != containerType.TypeArgs().Len()+1 { return nil, nil } @@ -448,7 +448,7 @@ func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { } // ...and add the element view type. // For that, we need to first determine the named elem type... - elemType, ok := baseType(containerType.TypeArgs().At(containerType.TypeArgs().Len() - 1)).(*types.Named) + elemType, ok := codegen.NamedTypeOf(baseType(containerType.TypeArgs().At(containerType.TypeArgs().Len() - 1))) if !ok { return nil, nil } @@ -473,7 +473,7 @@ func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { } // If elemType is an instantiated generic type, instantiate the elemViewType as well. if elemTypeArgs := elemType.TypeArgs(); elemTypeArgs != nil { - elemViewType = must.Get(types.Instantiate(nil, elemViewType, collectTypes(elemTypeArgs), false)).(*types.Named) + elemViewType, _ = codegen.NamedTypeOf(must.Get(types.Instantiate(nil, elemViewType, collectTypes(elemTypeArgs), false))) } // And finally set the elemViewType as the last type argument. containerViewTypeArgs[len(containerViewTypeArgs)-1] = elemViewType diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 4e2c86909..d998d925d 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -382,3 +382,12 @@ func LookupMethod(t types.Type, name string) *types.Func { } return nil } + +// NamedTypeOf is like t.(*types.Named), but also works with type aliases. +func NamedTypeOf(t types.Type) (named *types.Named, ok bool) { + if a, ok := t.(*types.Alias); ok { + return NamedTypeOf(types.Unalias(a)) + } + named, ok = t.(*types.Named) + return +}