diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 2d30cc2eb..c0092fa85 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -42,7 +42,7 @@ func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.ж != nil } // AsStruct returns a clone of the underlying value which aliases no memory with // the original. -func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{ +func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{ if v.ж == nil { return nil } @@ -51,6 +51,8 @@ func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypePara func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) } +{{if .JSONV2}}func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSONV2(e *jsontext.Encoder, opt jsonexpv2.Options) error { return jsonexpv2.MarshalEncode(e, v.ж, opt) }{{end}} + func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { if v.ж != nil { return errors.New("already initialized") @@ -120,13 +122,17 @@ func requiresCloning(t types.Type) (shallow, deep bool, base types.Type) { return p, p, t } -func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ *types.Package) { +func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ *types.Package, jsonv2 bool) { t, ok := typ.Underlying().(*types.Struct) if !ok || codegen.IsViewType(t) { return } it.Import("encoding/json") it.Import("errors") + if jsonv2 { + it.ImportAs("github.com/go-json-experiment/json", "jsonexpv2") + it.Import("github.com/go-json-experiment/json/jsontext") + } args := struct { StructName string @@ -145,9 +151,14 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ * // MakeViewFnName is the name of the function that accepts a value and returns a read-only view of it. MakeViewFnName string + + // JSONV2 enables the addition of MarshalJSONV2 methods which depend on + // github.com/go-json-experiment/json. + JSONV2 bool }{ StructName: typ.Obj().Name(), ViewName: typ.Origin().Obj().Name() + "View", + JSONV2: jsonv2, } typeParams := typ.Origin().TypeParams() @@ -574,6 +585,7 @@ var ( flagTypes = flag.String("type", "", "comma-separated list of types; required") flagBuildTags = flag.String("tags", "", "compiler build tags to apply") flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func") + flagJSONV2 = flag.Bool("jsonv2", false, "add jsonv2 Marshal methods") flagCloneOnlyTypes = flag.String("clone-only-type", "", "comma-separated list of types (a subset of --type) that should only generate a go:generate clone line and not actual views") @@ -630,7 +642,7 @@ func main() { if !hasClone { runCloner = true } - genView(buf, it, typ, pkg.Types) + genView(buf, it, typ, pkg.Types, *flagJSONV2) } out := pkg.Name + "_view" if *flagBuildTags == "test" { diff --git a/cmd/viewer/viewer_test.go b/cmd/viewer/viewer_test.go index cd5f3d95f..9e3a9b2c3 100644 --- a/cmd/viewer/viewer_test.go +++ b/cmd/viewer/viewer_test.go @@ -17,10 +17,12 @@ import ( func TestViewerImports(t *testing.T) { tests := []struct { - name string - content string - typeNames []string - wantImports []string + name string + content string + jsonv2 bool + typeNames []string + wantImports []string + wantNoImports []string }{ { name: "Map", @@ -34,6 +36,20 @@ func TestViewerImports(t *testing.T) { typeNames: []string{"Test"}, wantImports: []string{"tailscale.com/types/views"}, }, + { + name: "withJSONV2", + content: `type Test struct { }`, + jsonv2: true, + typeNames: []string{"Test"}, + wantImports: []string{"github.com/go-json-experiment/json"}, + }, + { + name: "withoutJSONV2", + content: `type Test struct { }`, + jsonv2: false, + typeNames: []string{"Test"}, + wantNoImports: []string{"github.com/go-json-experiment/json"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -65,7 +81,7 @@ func TestViewerImports(t *testing.T) { if !ok { t.Fatalf("%q is not a named type", tt.typeNames[i]) } - genView(&output, tracker, namedType, pkg) + genView(&output, tracker, namedType, pkg, tt.jsonv2) } for _, pkgName := range tt.wantImports { @@ -73,6 +89,11 @@ func TestViewerImports(t *testing.T) { t.Errorf("missing import %q", pkgName) } } + for _, pkgName := range tt.wantNoImports { + if tracker.Has(pkgName) { + t.Errorf("unwanted import %q", pkgName) + } + } }) } } diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 1b3af10e0..963e3a04c 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -6,6 +6,7 @@ package codegen import ( "bytes" + "cmp" "flag" "fmt" "go/ast" @@ -88,27 +89,34 @@ func NewImportTracker(thisPkg *types.Package) *ImportTracker { // ImportTracker provides a mechanism to track and build import paths. type ImportTracker struct { thisPkg *types.Package - packages map[string]bool + packages map[string]string // "github.com/go-json-experiment/json" => "jsonv2", or "encoding/json" => "" (default to basename "json"). } func (it *ImportTracker) Import(pkg string) { - if pkg != "" && !it.packages[pkg] { - mak.Set(&it.packages, pkg, true) + if pkg != "" && !it.Has(pkg) { + mak.Set(&it.packages, pkg, "") + } +} + +func (it *ImportTracker) ImportAs(pkg, as string) { + if pkg != "" && !it.Has(pkg) { + mak.Set(&it.packages, pkg, as) } } // Has reports whether the specified package has been imported. func (it *ImportTracker) Has(pkg string) bool { - return it.packages[pkg] + _, ok := it.packages[pkg] + return ok } func (it *ImportTracker) qualifier(pkg *types.Package) string { if it.thisPkg == pkg { return "" } - it.Import(pkg.Path()) - // TODO(maisem): handle conflicts? - return pkg.Name() + path := pkg.Path() + it.Import(path) + return cmp.Or(it.packages[path], pkg.Name()) } // QualifiedName returns the string representation of t in the package. @@ -127,8 +135,12 @@ func (it *ImportTracker) PackagePrefix(pkg *types.Package) string { // Write prints all the tracked imports in a single import block to w. func (it *ImportTracker) Write(w io.Writer) { fmt.Fprintf(w, "import (\n") - for s := range it.packages { - fmt.Fprintf(w, "\t%q\n", s) + for s, q := range it.packages { + if q != "" { + fmt.Fprintf(w, "\t%s %q\n", q, s) + } else { + fmt.Fprintf(w, "\t%q\n", s) + } } fmt.Fprintf(w, ")\n\n") }