// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Cloner is a tool to automate the creation of a Clone method. // // The result of the Clone method aliases no memory that can be edited // with the original. // // This tool makes lots of implicit assumptions about the types you feed it. // In particular, it can only write relatively "shallow" Clone methods. // That is, if a type contains another named struct type, cloner assumes that // named type will also have a Clone method. package main import ( "bytes" "flag" "fmt" "go/types" "log" "os" "strings" "golang.org/x/tools/go/packages" "tailscale.com/util/codegen" ) var ( flagTypes = flag.String("type", "", "comma-separated list of types; required") flagOutput = flag.String("output", "", "output file; required") flagBuildTags = flag.String("tags", "", "compiler build tags to apply") flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func") ) func main() { log.SetFlags(0) log.SetPrefix("cloner: ") flag.Parse() if len(*flagTypes) == 0 { flag.Usage() os.Exit(2) } typeNames := strings.Split(*flagTypes, ",") cfg := &packages.Config{ Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName, Tests: false, } if *flagBuildTags != "" { cfg.BuildFlags = []string{"-tags=" + *flagBuildTags} } pkgs, err := packages.Load(cfg, ".") if err != nil { log.Fatal(err) } if len(pkgs) != 1 { log.Fatalf("wrong number of packages: %d", len(pkgs)) } pkg := pkgs[0] buf := new(bytes.Buffer) imports := make(map[string]struct{}) namedTypes := codegen.NamedTypes(pkg) for _, typeName := range typeNames { typ, ok := namedTypes[typeName] if !ok { log.Fatalf("could not find type %s", typeName) } gen(buf, imports, typ, pkg.Types) } w := func(format string, args ...interface{}) { fmt.Fprintf(buf, format+"\n", args...) } if *flagCloneFunc { w("// Clone duplicates src into dst and reports whether it succeeded.") w("// To succeed, must be of types <*T, *T> or <*T, **T>,") w("// where T is one of %s.", *flagTypes) w("func Clone(dst, src interface{}) bool {") w(" switch src := src.(type) {") for _, typeName := range typeNames { w(" case *%s:", typeName) w(" switch dst := dst.(type) {") w(" case *%s:", typeName) w(" *dst = *src.Clone()") w(" return true") w(" case **%s:", typeName) w(" *dst = src.Clone()") w(" return true") w(" }") } w(" }") w(" return false") w("}") } contents := new(bytes.Buffer) var flagArgs []string if *flagTypes != "" { flagArgs = append(flagArgs, "-type="+*flagTypes) } if *flagOutput != "" { flagArgs = append(flagArgs, "-output="+*flagOutput) } if *flagBuildTags != "" { flagArgs = append(flagArgs, "-tags="+*flagBuildTags) } if *flagCloneFunc { flagArgs = append(flagArgs, "-clonefunc") } fmt.Fprintf(contents, header, strings.Join(flagArgs, " "), pkg.Name) fmt.Fprintf(contents, "import (\n") for s := range imports { fmt.Fprintf(contents, "\t%q\n", s) } fmt.Fprintf(contents, ")\n\n") contents.Write(buf.Bytes()) output := *flagOutput if output == "" { flag.Usage() os.Exit(2) } if err := codegen.WriteFormatted(contents.Bytes(), output); err != nil { log.Fatal(err) } } const header = `// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Code generated by tailscale.com/cmd/cloner; DO NOT EDIT. //go:generate go run tailscale.com/cmd/cloner %s package %s ` func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisPkg *types.Package) { pkgQual := func(pkg *types.Package) string { if thisPkg == pkg { return "" } imports[pkg.Path()] = struct{}{} return pkg.Name() } importedName := func(t types.Type) string { return types.TypeString(t, pkgQual) } t, ok := typ.Underlying().(*types.Struct) if !ok { return } name := typ.Obj().Name() fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name) fmt.Fprintf(buf, "// The result aliases no memory with the original.\n") fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", name, name) writef := func(format string, args ...interface{}) { fmt.Fprintf(buf, "\t"+format+"\n", args...) } writef("if src == nil {") writef("\treturn nil") writef("}") writef("dst := new(%s)", name) writef("*dst = *src") for i := 0; i < t.NumFields(); i++ { fname := t.Field(i).Name() ft := t.Field(i).Type() if !codegen.ContainsPointers(ft) { continue } if named, _ := ft.(*types.Named); named != nil && !hasBasicUnderlying(ft) { writef("dst.%s = *src.%s.Clone()", fname, fname) continue } switch ft := ft.Underlying().(type) { case *types.Slice: if codegen.ContainsPointers(ft.Elem()) { n := importedName(ft.Elem()) writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) writef("for i := range dst.%s {", fname) if _, isPtr := ft.Elem().(*types.Pointer); isPtr { writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) } else { writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) } writef("}") } else { writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname) } case *types.Pointer: if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) { writef("dst.%s = src.%s.Clone()", fname, fname) continue } n := importedName(ft.Elem()) writef("if dst.%s != nil {", fname) writef("\tdst.%s = new(%s)", fname, n) writef("\t*dst.%s = *src.%s", fname, fname) if codegen.ContainsPointers(ft.Elem()) { writef("\t" + `panic("TODO pointers in pointers")`) } writef("}") case *types.Map: writef("if dst.%s != nil {", fname) writef("\tdst.%s = map[%s]%s{}", fname, importedName(ft.Key()), importedName(ft.Elem())) if sliceType, isSlice := ft.Elem().(*types.Slice); isSlice { n := importedName(sliceType.Elem()) writef("\tfor k := range src.%s {", fname) // use zero-length slice instead of nil to ensure // the key is always copied. writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) writef("\t}") } else if codegen.ContainsPointers(ft.Elem()) { writef("\tfor k, v := range src.%s {", fname) writef("\t\tdst.%s[k] = v.Clone()", fname) writef("\t}") } else { writef("\tfor k, v := range src.%s {", fname) writef("\t\tdst.%s[k] = v", fname) writef("\t}") } writef("}") default: writef(`panic("TODO: %s (%T)")`, fname, ft) } } writef("return dst") fmt.Fprintf(buf, "}\n\n") buf.Write(codegen.AssertStructUnchanged(t, thisPkg, name, "Clone", imports)) } func hasBasicUnderlying(typ types.Type) bool { switch typ.Underlying().(type) { case *types.Slice, *types.Map: return true default: return false } }