mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-30 07:43:42 +00:00
cmd/viewer,util/codegen: add MarshalJSONV2 methods to views with --jsonv2
Updates tailscale/corp#791 Updates tailscale/corp#26353 Signed-off-by: Paul Scott <paul@tailscale.com>
This commit is contained in:
parent
820bdb870a
commit
09285ead78
@ -42,7 +42,7 @@ func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.ж != nil }
|
||||
|
||||
// AsStruct returns a clone of the underlying value which aliases no memory with
|
||||
// the original.
|
||||
func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{
|
||||
func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{
|
||||
if v.ж == nil {
|
||||
return nil
|
||||
}
|
||||
@ -51,6 +51,8 @@ func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypePara
|
||||
|
||||
func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
|
||||
|
||||
{{if .JSONV2}}func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSONV2(e *jsontext.Encoder, opt jsonexpv2.Options) error { return jsonexpv2.MarshalEncode(e, v.ж, opt) }{{end}}
|
||||
|
||||
func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error {
|
||||
if v.ж != nil {
|
||||
return errors.New("already initialized")
|
||||
@ -120,13 +122,17 @@ func requiresCloning(t types.Type) (shallow, deep bool, base types.Type) {
|
||||
return p, p, t
|
||||
}
|
||||
|
||||
func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ *types.Package) {
|
||||
func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ *types.Package, jsonv2 bool) {
|
||||
t, ok := typ.Underlying().(*types.Struct)
|
||||
if !ok || codegen.IsViewType(t) {
|
||||
return
|
||||
}
|
||||
it.Import("encoding/json")
|
||||
it.Import("errors")
|
||||
if jsonv2 {
|
||||
it.ImportAs("github.com/go-json-experiment/json", "jsonexpv2")
|
||||
it.Import("github.com/go-json-experiment/json/jsontext")
|
||||
}
|
||||
|
||||
args := struct {
|
||||
StructName string
|
||||
@ -145,9 +151,14 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ *
|
||||
|
||||
// MakeViewFnName is the name of the function that accepts a value and returns a read-only view of it.
|
||||
MakeViewFnName string
|
||||
|
||||
// JSONV2 enables the addition of MarshalJSONV2 methods which depend on
|
||||
// github.com/go-json-experiment/json.
|
||||
JSONV2 bool
|
||||
}{
|
||||
StructName: typ.Obj().Name(),
|
||||
ViewName: typ.Origin().Obj().Name() + "View",
|
||||
JSONV2: jsonv2,
|
||||
}
|
||||
|
||||
typeParams := typ.Origin().TypeParams()
|
||||
@ -574,6 +585,7 @@ var (
|
||||
flagTypes = flag.String("type", "", "comma-separated list of types; required")
|
||||
flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
|
||||
flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func")
|
||||
flagJSONV2 = flag.Bool("jsonv2", false, "add jsonv2 Marshal methods")
|
||||
|
||||
flagCloneOnlyTypes = flag.String("clone-only-type", "", "comma-separated list of types (a subset of --type) that should only generate a go:generate clone line and not actual views")
|
||||
|
||||
@ -630,7 +642,7 @@ func main() {
|
||||
if !hasClone {
|
||||
runCloner = true
|
||||
}
|
||||
genView(buf, it, typ, pkg.Types)
|
||||
genView(buf, it, typ, pkg.Types, *flagJSONV2)
|
||||
}
|
||||
out := pkg.Name + "_view"
|
||||
if *flagBuildTags == "test" {
|
||||
|
@ -17,10 +17,12 @@ import (
|
||||
|
||||
func TestViewerImports(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
typeNames []string
|
||||
wantImports []string
|
||||
name string
|
||||
content string
|
||||
jsonv2 bool
|
||||
typeNames []string
|
||||
wantImports []string
|
||||
wantNoImports []string
|
||||
}{
|
||||
{
|
||||
name: "Map",
|
||||
@ -34,6 +36,20 @@ func TestViewerImports(t *testing.T) {
|
||||
typeNames: []string{"Test"},
|
||||
wantImports: []string{"tailscale.com/types/views"},
|
||||
},
|
||||
{
|
||||
name: "withJSONV2",
|
||||
content: `type Test struct { }`,
|
||||
jsonv2: true,
|
||||
typeNames: []string{"Test"},
|
||||
wantImports: []string{"github.com/go-json-experiment/json"},
|
||||
},
|
||||
{
|
||||
name: "withoutJSONV2",
|
||||
content: `type Test struct { }`,
|
||||
jsonv2: false,
|
||||
typeNames: []string{"Test"},
|
||||
wantNoImports: []string{"github.com/go-json-experiment/json"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@ -65,7 +81,7 @@ func TestViewerImports(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatalf("%q is not a named type", tt.typeNames[i])
|
||||
}
|
||||
genView(&output, tracker, namedType, pkg)
|
||||
genView(&output, tracker, namedType, pkg, tt.jsonv2)
|
||||
}
|
||||
|
||||
for _, pkgName := range tt.wantImports {
|
||||
@ -73,6 +89,11 @@ func TestViewerImports(t *testing.T) {
|
||||
t.Errorf("missing import %q", pkgName)
|
||||
}
|
||||
}
|
||||
for _, pkgName := range tt.wantNoImports {
|
||||
if tracker.Has(pkgName) {
|
||||
t.Errorf("unwanted import %q", pkgName)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ package codegen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
@ -88,27 +89,34 @@ func NewImportTracker(thisPkg *types.Package) *ImportTracker {
|
||||
// ImportTracker provides a mechanism to track and build import paths.
|
||||
type ImportTracker struct {
|
||||
thisPkg *types.Package
|
||||
packages map[string]bool
|
||||
packages map[string]string // "github.com/go-json-experiment/json" => "jsonv2", or "encoding/json" => "" (default to basename "json").
|
||||
}
|
||||
|
||||
func (it *ImportTracker) Import(pkg string) {
|
||||
if pkg != "" && !it.packages[pkg] {
|
||||
mak.Set(&it.packages, pkg, true)
|
||||
if pkg != "" && !it.Has(pkg) {
|
||||
mak.Set(&it.packages, pkg, "")
|
||||
}
|
||||
}
|
||||
|
||||
func (it *ImportTracker) ImportAs(pkg, as string) {
|
||||
if pkg != "" && !it.Has(pkg) {
|
||||
mak.Set(&it.packages, pkg, as)
|
||||
}
|
||||
}
|
||||
|
||||
// Has reports whether the specified package has been imported.
|
||||
func (it *ImportTracker) Has(pkg string) bool {
|
||||
return it.packages[pkg]
|
||||
_, ok := it.packages[pkg]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (it *ImportTracker) qualifier(pkg *types.Package) string {
|
||||
if it.thisPkg == pkg {
|
||||
return ""
|
||||
}
|
||||
it.Import(pkg.Path())
|
||||
// TODO(maisem): handle conflicts?
|
||||
return pkg.Name()
|
||||
path := pkg.Path()
|
||||
it.Import(path)
|
||||
return cmp.Or(it.packages[path], pkg.Name())
|
||||
}
|
||||
|
||||
// QualifiedName returns the string representation of t in the package.
|
||||
@ -127,8 +135,12 @@ func (it *ImportTracker) PackagePrefix(pkg *types.Package) string {
|
||||
// Write prints all the tracked imports in a single import block to w.
|
||||
func (it *ImportTracker) Write(w io.Writer) {
|
||||
fmt.Fprintf(w, "import (\n")
|
||||
for s := range it.packages {
|
||||
fmt.Fprintf(w, "\t%q\n", s)
|
||||
for s, q := range it.packages {
|
||||
if q != "" {
|
||||
fmt.Fprintf(w, "\t%s %q\n", q, s)
|
||||
} else {
|
||||
fmt.Fprintf(w, "\t%q\n", s)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(w, ")\n\n")
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user