| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | // Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. | 
					
						
							|  |  |  | // Use of this source code is governed by a BSD-style | 
					
						
							|  |  |  | // license that can be found in the LICENSE file. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Cloner is a tool to automate the creation of a Clone method. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The result of the Clone method aliases no memory that can be edited | 
					
						
							|  |  |  | // with the original. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This tool makes lots of implicit assumptions about the types you feed it. | 
					
						
							|  |  |  | // In particular, it can only write relatively "shallow" Clone methods. | 
					
						
							|  |  |  | // That is, if a type contains another named struct type, cloner assumes that | 
					
						
							|  |  |  | // named type will also have a Clone method. | 
					
						
							|  |  |  | package main | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"bytes" | 
					
						
							|  |  |  | 	"flag" | 
					
						
							|  |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"go/types" | 
					
						
							|  |  |  | 	"log" | 
					
						
							|  |  |  | 	"os" | 
					
						
							|  |  |  | 	"strings" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"golang.org/x/tools/go/packages" | 
					
						
							| 
									
										
										
										
											2021-09-16 15:41:57 -07:00
										 |  |  | 	"tailscale.com/util/codegen" | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | var ( | 
					
						
							|  |  |  | 	flagTypes     = flag.String("type", "", "comma-separated list of types; required") | 
					
						
							|  |  |  | 	flagOutput    = flag.String("output", "", "output file; required") | 
					
						
							|  |  |  | 	flagBuildTags = flag.String("tags", "", "compiler build tags to apply") | 
					
						
							| 
									
										
										
										
											2020-10-19 10:46:30 -07:00
										 |  |  | 	flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func") | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func main() { | 
					
						
							|  |  |  | 	log.SetFlags(0) | 
					
						
							|  |  |  | 	log.SetPrefix("cloner: ") | 
					
						
							|  |  |  | 	flag.Parse() | 
					
						
							|  |  |  | 	if len(*flagTypes) == 0 { | 
					
						
							|  |  |  | 		flag.Usage() | 
					
						
							|  |  |  | 		os.Exit(2) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	typeNames := strings.Split(*flagTypes, ",") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	cfg := &packages.Config{ | 
					
						
							|  |  |  | 		Mode:  packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName, | 
					
						
							|  |  |  | 		Tests: false, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if *flagBuildTags != "" { | 
					
						
							|  |  |  | 		cfg.BuildFlags = []string{"-tags=" + *flagBuildTags} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	pkgs, err := packages.Load(cfg, ".") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		log.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if len(pkgs) != 1 { | 
					
						
							|  |  |  | 		log.Fatalf("wrong number of packages: %d", len(pkgs)) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	pkg := pkgs[0] | 
					
						
							|  |  |  | 	buf := new(bytes.Buffer) | 
					
						
							|  |  |  | 	imports := make(map[string]struct{}) | 
					
						
							| 
									
										
										
										
											2021-09-16 16:24:50 -07:00
										 |  |  | 	namedTypes := codegen.NamedTypes(pkg) | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 	for _, typeName := range typeNames { | 
					
						
							| 
									
										
										
										
											2021-09-16 16:24:50 -07:00
										 |  |  | 		typ, ok := namedTypes[typeName] | 
					
						
							|  |  |  | 		if !ok { | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 			log.Fatalf("could not find type %s", typeName) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-09-16 16:24:50 -07:00
										 |  |  | 		gen(buf, imports, typ, pkg.Types) | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-04 15:19:12 -07:00
										 |  |  | 	w := func(format string, args ...interface{}) { | 
					
						
							|  |  |  | 		fmt.Fprintf(buf, format+"\n", args...) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2020-10-19 10:46:30 -07:00
										 |  |  | 	if *flagCloneFunc { | 
					
						
							|  |  |  | 		w("// Clone duplicates src into dst and reports whether it succeeded.") | 
					
						
							|  |  |  | 		w("// To succeed, <src, dst> must be of types <*T, *T> or <*T, **T>,") | 
					
						
							|  |  |  | 		w("// where T is one of %s.", *flagTypes) | 
					
						
							|  |  |  | 		w("func Clone(dst, src interface{}) bool {") | 
					
						
							|  |  |  | 		w("	switch src := src.(type) {") | 
					
						
							|  |  |  | 		for _, typeName := range typeNames { | 
					
						
							|  |  |  | 			w("	case *%s:", typeName) | 
					
						
							|  |  |  | 			w("		switch dst := dst.(type) {") | 
					
						
							|  |  |  | 			w("		case *%s:", typeName) | 
					
						
							|  |  |  | 			w("			*dst = *src.Clone()") | 
					
						
							|  |  |  | 			w("			return true") | 
					
						
							|  |  |  | 			w("		case **%s:", typeName) | 
					
						
							|  |  |  | 			w("			*dst = src.Clone()") | 
					
						
							|  |  |  | 			w("			return true") | 
					
						
							|  |  |  | 			w("		}") | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		w("	}") | 
					
						
							|  |  |  | 		w("	return false") | 
					
						
							|  |  |  | 		w("}") | 
					
						
							| 
									
										
										
										
											2020-09-04 15:19:12 -07:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 	contents := new(bytes.Buffer) | 
					
						
							| 
									
										
										
										
											2021-10-14 12:25:55 -07:00
										 |  |  | 	var flagArgs []string | 
					
						
							|  |  |  | 	if *flagTypes != "" { | 
					
						
							|  |  |  | 		flagArgs = append(flagArgs, "-type="+*flagTypes) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if *flagOutput != "" { | 
					
						
							|  |  |  | 		flagArgs = append(flagArgs, "-output="+*flagOutput) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if *flagBuildTags != "" { | 
					
						
							|  |  |  | 		flagArgs = append(flagArgs, "-tags="+*flagBuildTags) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if *flagCloneFunc { | 
					
						
							|  |  |  | 		flagArgs = append(flagArgs, "-clonefunc") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	fmt.Fprintf(contents, header, strings.Join(flagArgs, " "), pkg.Name) | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 	fmt.Fprintf(contents, "import (\n") | 
					
						
							|  |  |  | 	for s := range imports { | 
					
						
							|  |  |  | 		fmt.Fprintf(contents, "\t%q\n", s) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	fmt.Fprintf(contents, ")\n\n") | 
					
						
							|  |  |  | 	contents.Write(buf.Bytes()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	output := *flagOutput | 
					
						
							|  |  |  | 	if output == "" { | 
					
						
							|  |  |  | 		flag.Usage() | 
					
						
							|  |  |  | 		os.Exit(2) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2021-09-16 15:41:57 -07:00
										 |  |  | 	if err := codegen.WriteFormatted(contents.Bytes(), output); err != nil { | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 		log.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
// header is the preamble written at the top of every generated file. main
// uses it as a fmt format string with two %s arguments: the space-joined
// cloner flags for the go:generate line, and the generated package's name.
//
// The go:generate directive is built by string concatenation, presumably so
// that this source line of cloner itself does not begin with the directive
// text, which `go generate` would otherwise treat as a directive.
const header = `// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by tailscale.com/cmd/cloner; DO NOT EDIT.
//` + `go:generate` + ` go run tailscale.com/cmd/cloner %s

package %s

`
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisPkg *types.Package) { | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 	pkgQual := func(pkg *types.Package) string { | 
					
						
							|  |  |  | 		if thisPkg == pkg { | 
					
						
							|  |  |  | 			return "" | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		imports[pkg.Path()] = struct{}{} | 
					
						
							|  |  |  | 		return pkg.Name() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	importedName := func(t types.Type) string { | 
					
						
							|  |  |  | 		return types.TypeString(t, pkgQual) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 	t, ok := typ.Underlying().(*types.Struct) | 
					
						
							|  |  |  | 	if !ok { | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	name := typ.Obj().Name() | 
					
						
							|  |  |  | 	fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name) | 
					
						
							|  |  |  | 	fmt.Fprintf(buf, "// The result aliases no memory with the original.\n") | 
					
						
							|  |  |  | 	fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", name, name) | 
					
						
							|  |  |  | 	writef := func(format string, args ...interface{}) { | 
					
						
							|  |  |  | 		fmt.Fprintf(buf, "\t"+format+"\n", args...) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	writef("if src == nil {") | 
					
						
							|  |  |  | 	writef("\treturn nil") | 
					
						
							|  |  |  | 	writef("}") | 
					
						
							|  |  |  | 	writef("dst := new(%s)", name) | 
					
						
							|  |  |  | 	writef("*dst = *src") | 
					
						
							|  |  |  | 	for i := 0; i < t.NumFields(); i++ { | 
					
						
							|  |  |  | 		fname := t.Field(i).Name() | 
					
						
							|  |  |  | 		ft := t.Field(i).Type() | 
					
						
							| 
									
										
										
										
											2021-09-17 11:29:17 -07:00
										 |  |  | 		if !codegen.ContainsPointers(ft) { | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 			continue | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 		if named, _ := ft.(*types.Named); named != nil && !hasBasicUnderlying(ft) { | 
					
						
							|  |  |  | 			writef("dst.%s = *src.%s.Clone()", fname, fname) | 
					
						
							|  |  |  | 			continue | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		switch ft := ft.Underlying().(type) { | 
					
						
							|  |  |  | 		case *types.Slice: | 
					
						
							| 
									
										
										
										
											2021-09-17 11:29:17 -07:00
										 |  |  | 			if codegen.ContainsPointers(ft.Elem()) { | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 				n := importedName(ft.Elem()) | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 				writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) | 
					
						
							|  |  |  | 				writef("for i := range dst.%s {", fname) | 
					
						
							|  |  |  | 				if _, isPtr := ft.Elem().(*types.Pointer); isPtr { | 
					
						
							|  |  |  | 					writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 				} else { | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 					writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 				} | 
					
						
							|  |  |  | 				writef("}") | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 			} else { | 
					
						
							|  |  |  | 				writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname) | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 			} | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 		case *types.Pointer: | 
					
						
							| 
									
										
										
										
											2021-09-17 11:29:17 -07:00
										 |  |  | 			if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) { | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 				writef("dst.%s = src.%s.Clone()", fname, fname) | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			n := importedName(ft.Elem()) | 
					
						
							|  |  |  | 			writef("if dst.%s != nil {", fname) | 
					
						
							|  |  |  | 			writef("\tdst.%s = new(%s)", fname, n) | 
					
						
							|  |  |  | 			writef("\t*dst.%s = *src.%s", fname, fname) | 
					
						
							| 
									
										
										
										
											2021-09-17 11:29:17 -07:00
										 |  |  | 			if codegen.ContainsPointers(ft.Elem()) { | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 				writef("\t" + `panic("TODO pointers in pointers")`) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			writef("}") | 
					
						
							|  |  |  | 		case *types.Map: | 
					
						
							|  |  |  | 			writef("if dst.%s != nil {", fname) | 
					
						
							|  |  |  | 			writef("\tdst.%s = map[%s]%s{}", fname, importedName(ft.Key()), importedName(ft.Elem())) | 
					
						
							|  |  |  | 			if sliceType, isSlice := ft.Elem().(*types.Slice); isSlice { | 
					
						
							|  |  |  | 				n := importedName(sliceType.Elem()) | 
					
						
							|  |  |  | 				writef("\tfor k := range src.%s {", fname) | 
					
						
							|  |  |  | 				// use zero-length slice instead of nil to ensure | 
					
						
							|  |  |  | 				// the key is always copied. | 
					
						
							|  |  |  | 				writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) | 
					
						
							|  |  |  | 				writef("\t}") | 
					
						
							| 
									
										
										
										
											2021-09-17 11:29:17 -07:00
										 |  |  | 			} else if codegen.ContainsPointers(ft.Elem()) { | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 				writef("\tfor k, v := range src.%s {", fname) | 
					
						
							|  |  |  | 				writef("\t\tdst.%s[k] = v.Clone()", fname) | 
					
						
							|  |  |  | 				writef("\t}") | 
					
						
							|  |  |  | 			} else { | 
					
						
							|  |  |  | 				writef("\tfor k, v := range src.%s {", fname) | 
					
						
							|  |  |  | 				writef("\t\tdst.%s[k] = v", fname) | 
					
						
							|  |  |  | 				writef("\t}") | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			writef("}") | 
					
						
							|  |  |  | 		default: | 
					
						
							| 
									
										
										
										
											2021-09-16 16:31:35 -07:00
										 |  |  | 			writef(`panic("TODO: %s (%T)")`, fname, ft) | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2021-09-16 16:00:34 -07:00
										 |  |  | 	writef("return dst") | 
					
						
							|  |  |  | 	fmt.Fprintf(buf, "}\n\n") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-17 16:43:57 -07:00
										 |  |  | 	buf.Write(codegen.AssertStructUnchanged(t, thisPkg, name, "Clone", imports)) | 
					
						
							| 
									
										
										
										
											2020-07-24 18:00:02 +10:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func hasBasicUnderlying(typ types.Type) bool { | 
					
						
							|  |  |  | 	switch typ.Underlying().(type) { | 
					
						
							|  |  |  | 	case *types.Slice, *types.Map: | 
					
						
							|  |  |  | 		return true | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return false | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } |