// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Cloner is a tool to automate the creation of a Clone method.
//
// The result of the Clone method shares no mutable memory with the
// original.
//
// This tool makes lots of implicit assumptions about the types you feed it.
// In particular, it can only write relatively "shallow" Clone methods.
// That is, if a type contains another named struct type, cloner assumes that
// named type will also have a Clone method.
package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/types"
	"log"
	"os"
	"strings"

	"golang.org/x/tools/go/packages"
	"tailscale.com/util/codegen"
)

// Command-line flags. -type and -output are required: main prints the usage
// message and exits with status 2 when either one is empty.
var (
	flagTypes     = flag.String("type", "", "comma-separated list of types; required")
	flagOutput    = flag.String("output", "", "output file; required")
	flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
	flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func")
)

// main parses the command-line flags, loads the package in the current
// directory, generates a Clone method for every type named by -type (plus a
// top-level Clone func when -clonefunc is set), and writes the formatted
// result to the -output file.
//
// It exits with status 2 after printing usage when a required flag is
// missing, and calls log.Fatal on any load, lookup, or write error.
func main() {
	log.SetFlags(0)
	log.SetPrefix("cloner: ")
	flag.Parse()
	// Both -type and -output are required. Reject a bad invocation up front,
	// before loading and type-checking the package.
	if *flagTypes == "" || *flagOutput == "" {
		flag.Usage()
		os.Exit(2)
	}
	typeNames := strings.Split(*flagTypes, ",")

	cfg := &packages.Config{
		Mode:  packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
		Tests: false,
	}
	if *flagBuildTags != "" {
		cfg.BuildFlags = []string{"-tags=" + *flagBuildTags}
	}
	pkgs, err := packages.Load(cfg, ".")
	if err != nil {
		log.Fatal(err)
	}
	if len(pkgs) != 1 {
		log.Fatalf("wrong number of packages: %d", len(pkgs))
	}
	pkg := pkgs[0]

	// buf accumulates the generated declarations; imports collects the import
	// paths that gen discovers it needs while writing them.
	buf := new(bytes.Buffer)
	imports := make(map[string]struct{})
	namedTypes := codegen.NamedTypes(pkg)
	for _, typeName := range typeNames {
		typ, ok := namedTypes[typeName]
		if !ok {
			log.Fatalf("could not find type %s", typeName)
		}
		gen(buf, imports, typ, pkg.Types)
	}

	w := func(format string, args ...interface{}) {
		fmt.Fprintf(buf, format+"\n", args...)
	}
	if *flagCloneFunc {
		w("// Clone duplicates src into dst and reports whether it succeeded.")
		w("// To succeed, <src, dst> must be of types <*T, *T> or <*T, **T>,")
		w("// where T is one of %s.", *flagTypes)
		w("func Clone(dst, src interface{}) bool {")
		w("\tswitch src := src.(type) {")
		for _, typeName := range typeNames {
			w("\tcase *%s:", typeName)
			w("\t\tswitch dst := dst.(type) {")
			w("\t\tcase *%s:", typeName)
			w("\t\t\t*dst = *src.Clone()")
			w("\t\t\treturn true")
			w("\t\tcase **%s:", typeName)
			w("\t\t\t*dst = src.Clone()")
			w("\t\t\treturn true")
			w("\t\t}")
		}
		w("\t}")
		w("\treturn false")
		w("}")
	}

	// Reconstruct the command line for the go:generate comment in the header.
	// -type and -output are known to be non-empty at this point.
	flagArgs := []string{"-type=" + *flagTypes, "-output=" + *flagOutput}
	if *flagBuildTags != "" {
		flagArgs = append(flagArgs, "-tags="+*flagBuildTags)
	}
	if *flagCloneFunc {
		flagArgs = append(flagArgs, "-clonefunc")
	}

	// The imports are only known after generation, so the final file is
	// assembled as header + import block + generated declarations.
	contents := new(bytes.Buffer)
	fmt.Fprintf(contents, header, strings.Join(flagArgs, " "), pkg.Name)
	fmt.Fprintf(contents, "import (\n")
	// Map iteration order is unspecified.
	// NOTE(review): presumably codegen.WriteFormatted normalizes the import
	// order when formatting — confirm.
	for s := range imports {
		fmt.Fprintf(contents, "\t%q\n", s)
	}
	fmt.Fprintf(contents, ")\n\n")
	contents.Write(buf.Bytes())

	if err := codegen.WriteFormatted(contents.Bytes(), *flagOutput); err != nil {
		log.Fatal(err)
	}
}

// header is the fmt template for the top of the generated file. It takes two
// string arguments, in order: the flags to record on the go:generate line and
// the name of the package the generated code belongs to.
//
// The "go:generate" text is spliced in by string concatenation, so the line
// in this source file does not itself start with the directive prefix —
// presumably to keep `go generate` from treating it as a directive here.
const header = `// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by tailscale.com/cmd/cloner; DO NOT EDIT.
//` + `go:generate` + ` go run tailscale.com/cmd/cloner %s

package %s

`

// gen writes to buf a Clone method for the named type typ, followed by the
// output of codegen.AssertStructUnchanged for the same struct.
//
// Import paths of packages referenced by the generated code are added to
// imports. thisPkg is the package the generated code lives in; types declared
// in it are written unqualified. If typ's underlying type is not a struct,
// gen writes nothing.
//
// The generated method starts from a shallow copy (*dst = *src) and then
// overwrites every field for which codegen.ContainsPointers returns true:
//
//   - a field of a named type whose underlying type is not a slice or map
//     is copied through that type's Clone method;
//   - a slice is reallocated, cloning each element when the elements
//     themselves contain pointers;
//   - a pointer is re-pointed at a fresh copy of its target (or at the
//     target's Clone when the target is a named type containing pointers);
//   - a map is rebuilt entry by entry.
//
// Nil slices, pointers and maps are left nil in the clone. Field shapes that
// are not handled generate a panic("TODO ...") call in the output.
func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisPkg *types.Package) {
	// pkgQual is the types.Qualifier used when printing types: it records
	// every foreign package it is asked about so main can import it.
	pkgQual := func(pkg *types.Package) string {
		if thisPkg == pkg {
			return ""
		}
		imports[pkg.Path()] = struct{}{}
		return pkg.Name()
	}
	importedName := func(t types.Type) string {
		return types.TypeString(t, pkgQual)
	}

	t, ok := typ.Underlying().(*types.Struct)
	if !ok {
		return
	}

	name := typ.Obj().Name()
	fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name)
	fmt.Fprintf(buf, "// The result aliases no memory with the original.\n")
	fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", name, name)
	// writef emits one line of the method body, indented one level.
	writef := func(format string, args ...interface{}) {
		fmt.Fprintf(buf, "\t"+format+"\n", args...)
	}
	writef("if src == nil {")
	writef("\treturn nil")
	writef("}")
	writef("dst := new(%s)", name)
	writef("*dst = *src")
	for i := 0; i < t.NumFields(); i++ {
		fname := t.Field(i).Name()
		ft := t.Field(i).Type()
		if !codegen.ContainsPointers(ft) {
			// The shallow copy above is already sufficient for this field.
			continue
		}
		if named, _ := ft.(*types.Named); named != nil && !hasBasicUnderlying(ft) {
			// Named non-slice, non-map type: rely on its own Clone method.
			writef("dst.%s = *src.%s.Clone()", fname, fname)
			continue
		}
		switch ft := ft.Underlying().(type) {
		case *types.Slice:
			if codegen.ContainsPointers(ft.Elem()) {
				n := importedName(ft.Elem())
				// Only reallocate a non-nil slice. dst already holds a copy
				// of src's (nil) slice header otherwise, which keeps the
				// clone's nil-ness equal to the original's, matching the
				// append form below and the pointer and map cases.
				writef("if src.%s != nil {", fname)
				writef("\tdst.%s = make([]%s, len(src.%s))", fname, n, fname)
				writef("\tfor i := range dst.%s {", fname)
				if _, isPtr := ft.Elem().(*types.Pointer); isPtr {
					writef("\t\tdst.%s[i] = src.%s[i].Clone()", fname, fname)
				} else {
					writef("\t\tdst.%s[i] = *src.%s[i].Clone()", fname, fname)
				}
				writef("\t}")
				writef("}")
			} else {
				// The zero-capacity prefix forces append to allocate a new
				// backing array.
				writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname)
			}
		case *types.Pointer:
			if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) {
				writef("dst.%s = src.%s.Clone()", fname, fname)
				continue
			}
			n := importedName(ft.Elem())
			writef("if dst.%s != nil {", fname)
			writef("\tdst.%s = new(%s)", fname, n)
			writef("\t*dst.%s = *src.%s", fname, fname)
			if codegen.ContainsPointers(ft.Elem()) {
				writef("\t" + `panic("TODO pointers in pointers")`)
			}
			writef("}")
		case *types.Map:
			writef("if dst.%s != nil {", fname)
			writef("\tdst.%s = map[%s]%s{}", fname, importedName(ft.Key()), importedName(ft.Elem()))
			if sliceType, isSlice := ft.Elem().(*types.Slice); isSlice {
				n := importedName(sliceType.Elem())
				writef("\tfor k := range src.%s {", fname)
				// use zero-length slice instead of nil to ensure
				// the key is always copied.
				writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname)
				writef("\t}")
			} else if codegen.ContainsPointers(ft.Elem()) {
				writef("\tfor k, v := range src.%s {", fname)
				// Clone returns a pointer, so dereference it when the map
				// stores values rather than pointers (as the slice case
				// does); otherwise the generated assignment does not compile.
				if _, isPtr := ft.Elem().(*types.Pointer); isPtr {
					writef("\t\tdst.%s[k] = v.Clone()", fname)
				} else {
					writef("\t\tdst.%s[k] = *v.Clone()", fname)
				}
				writef("\t}")
			} else {
				writef("\tfor k, v := range src.%s {", fname)
				writef("\t\tdst.%s[k] = v", fname)
				writef("\t}")
			}
			writef("}")
		default:
			writef(`panic("TODO: %s (%T)")`, fname, ft)
		}
	}
	writef("return dst")
	fmt.Fprintf(buf, "}\n\n")

	buf.Write(codegen.AssertStructUnchanged(t, thisPkg, name, "Clone", imports))
}

// hasBasicUnderlying reports whether the underlying type of typ is a slice
// or a map.
func hasBasicUnderlying(typ types.Type) bool {
	under := typ.Underlying()
	if _, isSlice := under.(*types.Slice); isSlice {
		return true
	}
	_, isMap := under.(*types.Map)
	return isMap
}