cmd/cloner, cmd/viewer, util/codegen: add support for generic types and interfaces

This adds support for generic types and interfaces to our cloner and viewer codegens.
It updates these packages to determine whether to make shallow or deep copies based
on the type parameter constraints. Additionally, if a template parameter or an interface
type has View() and Clone() methods, we'll use them for getters and the cloner of the
owning structure.

Updates #12736

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl
2024-07-08 10:11:00 -05:00
committed by Nick Khyl
parent b7c3cfe049
commit fc28c8e7f3
9 changed files with 1039 additions and 114 deletions

View File

@@ -20,43 +20,43 @@ import (
const viewTemplateStr = `{{define "common"}}
// View returns a readonly view of {{.StructName}}.
func (p *{{.StructName}}) View() {{.ViewName}} {
return {{.ViewName}}{ж: p}
func (p *{{.StructName}}{{.TypeParamNames}}) View() {{.ViewName}}{{.TypeParamNames}} {
return {{.ViewName}}{{.TypeParamNames}}{ж: p}
}
// {{.ViewName}} provides a read-only view over {{.StructName}}.
// {{.ViewName}}{{.TypeParamNames}} provides a read-only view over {{.StructName}}{{.TypeParamNames}}.
//
// Its methods should only be called if ` + "`Valid()`" + ` returns true.
type {{.ViewName}} struct {
type {{.ViewName}}{{.TypeParams}} struct {
// ж is the underlying mutable value, named with a hard-to-type
// character that looks pointy like a pointer.
// It is named distinctively to make you think of how dangerous it is to escape
// to callers. You must not let callers be able to mutate it.
ж *{{.StructName}}
ж *{{.StructName}}{{.TypeParamNames}}
}
// Valid reports whether underlying value is non-nil.
func (v {{.ViewName}}) Valid() bool { return v.ж != nil }
func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.ж != nil }
// AsStruct returns a clone of the underlying value which aliases no memory with
// the original.
func (v {{.ViewName}}) AsStruct() *{{.StructName}}{
func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{
if v.ж == nil {
return nil
}
return v.ж.Clone()
}
func (v {{.ViewName}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error {
func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error {
if v.ж != nil {
return errors.New("already initialized")
}
if len(b) == 0 {
return nil
}
var x {{.StructName}}
var x {{.StructName}}{{.TypeParamNames}}
if err := json.Unmarshal(b, &x); err != nil {
return err
}
@@ -65,17 +65,17 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error {
}
{{end}}
{{define "valueField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} }
{{define "valueField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} }
{{end}}
{{define "byteSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) }
{{define "byteSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) }
{{end}}
{{define "sliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) }
{{define "sliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) }
{{end}}
{{define "viewSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) }
{{define "viewSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) }
{{end}}
{{define "viewField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}}View { return v.ж.{{.FieldName}}.View() }
{{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() }
{{end}}
{{define "valuePointerField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} {
{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {
if v.ж.{{.FieldName}} == nil {
return nil
}
@@ -85,21 +85,21 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error {
{{end}}
{{define "mapField"}}
func(v {{.ViewName}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})}
func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})}
{{end}}
{{define "mapFnField"}}
func(v {{.ViewName}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} {
func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} {
return {{.MapFn}}
})}
{{end}}
{{define "mapSliceField"}}
func(v {{.ViewName}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) }
func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) }
{{end}}
{{define "unsupportedField"}}func(v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")}
{{define "unsupportedField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")}
{{end}}
{{define "stringFunc"}}func(v {{.ViewName}}) String() string { return v.ж.String() }
{{define "stringFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) String() string { return v.ж.String() }
{{end}}
{{define "equalFunc"}}func(v {{.ViewName}}) Equal(v2 {{.ViewName}}) bool { return v.ж.Equal(v2.ж) }
{{define "equalFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) Equal(v2 {{.ViewName}}{{.TypeParamNames}}) bool { return v.ж.Equal(v2.ж) }
{{end}}
`
@@ -131,8 +131,11 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
it.Import("errors")
args := struct {
StructName string
ViewName string
StructName string
ViewName string
TypeParams string // e.g. [T constraints.Integer]
TypeParamNames string // e.g. [T]
FieldName string
FieldType string
FieldViewName string
@@ -143,9 +146,12 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
MapFn string
}{
StructName: typ.Obj().Name(),
ViewName: typ.Obj().Name() + "View",
ViewName: typ.Origin().Obj().Name() + "View",
}
typeParams := typ.Origin().TypeParams()
args.TypeParams, args.TypeParamNames = codegen.FormatTypeParams(typeParams, it)
writeTemplate := func(name string) {
if err := viewTemplate.ExecuteTemplate(buf, name, args); err != nil {
log.Fatal(err)
@@ -182,19 +188,35 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
it.Import("tailscale.com/types/views")
shallow, deep, base := requiresCloning(elem)
if deep {
if _, isPtr := elem.(*types.Pointer); isPtr {
args.FieldViewName = it.QualifiedName(base) + "View"
writeTemplate("viewSliceField")
} else {
writeTemplate("unsupportedField")
switch elem.Underlying().(type) {
case *types.Pointer:
if _, isIface := base.Underlying().(*types.Interface); !isIface {
args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View")
writeTemplate("viewSliceField")
} else {
writeTemplate("unsupportedField")
}
continue
case *types.Interface:
if viewType := viewTypeForValueType(elem); viewType != nil {
args.FieldViewName = it.QualifiedName(viewType)
writeTemplate("viewSliceField")
continue
}
}
writeTemplate("unsupportedField")
continue
} else if shallow {
if _, isBasic := base.(*types.Basic); isBasic {
switch base.Underlying().(type) {
case *types.Basic, *types.Interface:
writeTemplate("unsupportedField")
} else {
args.FieldViewName = it.QualifiedName(base) + "View"
writeTemplate("viewSliceField")
default:
if _, isIface := base.Underlying().(*types.Interface); !isIface {
args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View")
writeTemplate("viewSliceField")
} else {
writeTemplate("unsupportedField")
}
}
continue
}
@@ -205,6 +227,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
strucT := underlying
args.FieldType = it.QualifiedName(fieldType)
if codegen.ContainsPointers(strucT) {
args.FieldViewName = appendNameSuffix(args.FieldType, "View")
writeTemplate("viewField")
continue
}
@@ -229,7 +252,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
args.MapFn = "t.View()"
template = "mapFnField"
args.MapValueType = it.QualifiedName(mElem)
args.MapValueView = args.MapValueType + "View"
args.MapValueView = appendNameSuffix(args.MapValueType, "View")
} else {
template = "mapField"
args.MapValueType = it.QualifiedName(mElem)
@@ -249,15 +272,20 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
case *types.Pointer:
ptr := x
pElem := ptr.Elem()
switch pElem.(type) {
case *types.Struct, *types.Named:
ptrType := it.QualifiedName(ptr)
viewType := it.QualifiedName(pElem) + "View"
args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType)
args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType)
args.MapValueType = "[]" + ptrType
template = "mapFnField"
default:
template = "unsupportedField"
if _, isIface := pElem.Underlying().(*types.Interface); !isIface {
switch pElem.(type) {
case *types.Struct, *types.Named:
ptrType := it.QualifiedName(ptr)
viewType := appendNameSuffix(it.QualifiedName(pElem), "View")
args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType)
args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType)
args.MapValueType = "[]" + ptrType
template = "mapFnField"
default:
template = "unsupportedField"
}
} else {
template = "unsupportedField"
}
default:
@@ -266,13 +294,29 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
case *types.Pointer:
ptr := u
pElem := ptr.Elem()
switch pElem.(type) {
case *types.Struct, *types.Named:
args.MapValueType = it.QualifiedName(ptr)
args.MapValueView = it.QualifiedName(pElem) + "View"
if _, isIface := pElem.Underlying().(*types.Interface); !isIface {
switch pElem.(type) {
case *types.Struct, *types.Named:
args.MapValueType = it.QualifiedName(ptr)
args.MapValueView = appendNameSuffix(it.QualifiedName(pElem), "View")
args.MapFn = "t.View()"
template = "mapFnField"
default:
template = "unsupportedField"
}
} else {
template = "unsupportedField"
}
case *types.Interface, *types.TypeParam:
if viewType := viewTypeForValueType(u); viewType != nil {
args.MapValueType = it.QualifiedName(u)
args.MapValueView = it.QualifiedName(viewType)
args.MapFn = "t.View()"
template = "mapFnField"
default:
} else if !codegen.ContainsPointers(u) {
args.MapValueType = it.QualifiedName(mElem)
template = "mapField"
} else {
template = "unsupportedField"
}
default:
@@ -283,14 +327,28 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
case *types.Pointer:
ptr := underlying
_, deep, base := requiresCloning(ptr)
if deep {
args.FieldType = it.QualifiedName(base)
writeTemplate("viewField")
if _, isIface := base.Underlying().(*types.Interface); !isIface {
args.FieldType = it.QualifiedName(base)
args.FieldViewName = appendNameSuffix(args.FieldType, "View")
writeTemplate("viewField")
} else {
writeTemplate("unsupportedField")
}
} else {
args.FieldType = it.QualifiedName(ptr)
writeTemplate("valuePointerField")
}
continue
case *types.Interface:
// If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field.
// This includes scenarios where fieldType is a constrained type parameter.
if viewType := viewTypeForValueType(underlying); viewType != nil {
args.FieldViewName = it.QualifiedName(viewType)
writeTemplate("viewField")
continue
}
}
writeTemplate("unsupportedField")
}
@@ -318,7 +376,27 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
}
}
fmt.Fprintf(buf, "\n")
buf.Write(codegen.AssertStructUnchanged(t, args.StructName, "View", it))
buf.Write(codegen.AssertStructUnchanged(t, args.StructName, typeParams, "View", it))
}
func appendNameSuffix(name, suffix string) string {
if idx := strings.IndexRune(name, '['); idx != -1 {
// Insert suffix after the type name, but before type parameters.
return name[:idx] + suffix + name[idx:]
}
return name + suffix
}
func viewTypeForValueType(typ types.Type) types.Type {
viewMethod := codegen.LookupMethod(typ, "View")
if viewMethod == nil {
return nil
}
sig, ok := viewMethod.Type().(*types.Signature)
if !ok || sig.Results().Len() != 1 {
return nil
}
return sig.Results().At(0).Type()
}
var (