From 0f4c9c0ecb133f2e7e3df2626e2a6a114d6dc251 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 21 Oct 2024 12:28:41 -0500 Subject: [PATCH] cmd/viewer: import types/views when generating a getter for a map field Fixes #13873 Signed-off-by: Nick Khyl --- cmd/viewer/viewer.go | 1 + cmd/viewer/viewer_test.go | 78 +++++++++++++++++++++++++++++++++++++++ util/codegen/codegen.go | 5 +++ 3 files changed, 84 insertions(+) create mode 100644 cmd/viewer/viewer_test.go diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 96223297b..0c5868f3a 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -258,6 +258,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi writeTemplate("unsupportedField") continue } + it.Import("tailscale.com/types/views") args.MapKeyType = it.QualifiedName(key) mElem := m.Elem() var template string diff --git a/cmd/viewer/viewer_test.go b/cmd/viewer/viewer_test.go new file mode 100644 index 000000000..cd5f3d95f --- /dev/null +++ b/cmd/viewer/viewer_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "tailscale.com/util/codegen" +) + +func TestViewerImports(t *testing.T) { + tests := []struct { + name string + content string + typeNames []string + wantImports []string + }{ + { + name: "Map", + content: `type Test struct { Map map[string]int }`, + typeNames: []string{"Test"}, + wantImports: []string{"tailscale.com/types/views"}, + }, + { + name: "Slice", + content: `type Test struct { Slice []int }`, + typeNames: []string{"Test"}, + wantImports: []string{"tailscale.com/types/views"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", "package test\n\n"+tt.content, 0) + if err != nil { + fmt.Println("Error parsing:", err) + return + } + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + } + + conf := types.Config{} + pkg, err := conf.Check("", fset, []*ast.File{f}, info) + if err != nil { + t.Fatal(err) + } + + var output bytes.Buffer + tracker := codegen.NewImportTracker(pkg) + for i := range tt.typeNames { + typeName, ok := pkg.Scope().Lookup(tt.typeNames[i]).(*types.TypeName) + if !ok { + t.Fatalf("type %q does not exist", tt.typeNames[i]) + } + namedType, ok := typeName.Type().(*types.Named) + if !ok { + t.Fatalf("%q is not a named type", tt.typeNames[i]) + } + genView(&output, tracker, namedType, pkg) + } + + for _, pkgName := range tt.wantImports { + if !tracker.Has(pkgName) { + t.Errorf("missing import %q", pkgName) + } + } + }) + } +} diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index d998d925d..2f7781b68 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -97,6 +97,11 @@ func (it *ImportTracker) Import(pkg string) { } } +// Has reports whether the specified package has been imported. +func (it *ImportTracker) Has(pkg string) bool { + return it.packages[pkg] +} + func (it *ImportTracker) qualifier(pkg *types.Package) string { if it.thisPkg == pkg { return ""