mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 13:05:46 +00:00
util/codegen: add AssertStructUnchanged
Refactored out from cmd/cloner. Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
parent
fb66ff7c78
commit
618376dbc0
@ -164,18 +164,6 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types
|
|||||||
|
|
||||||
switch t := typ.Underlying().(type) {
|
switch t := typ.Underlying().(type) {
|
||||||
case *types.Struct:
|
case *types.Struct:
|
||||||
// We generate two bits of code simultaneously while we walk the struct.
|
|
||||||
// One is the Clone method itself, which we write directly to buf.
|
|
||||||
// The other is a variable assignment that will fail if the struct
|
|
||||||
// changes without the Clone method getting regenerated.
|
|
||||||
// We write that to regenBuf, and then append it to buf at the end.
|
|
||||||
regenBuf := new(bytes.Buffer)
|
|
||||||
writeRegen := func(format string, args ...interface{}) {
|
|
||||||
fmt.Fprintf(regenBuf, format+"\n", args...)
|
|
||||||
}
|
|
||||||
writeRegen("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
|
|
||||||
writeRegen("var _%sNeedsRegeneration = %s(struct {", name, name)
|
|
||||||
|
|
||||||
name := typ.Obj().Name()
|
name := typ.Obj().Name()
|
||||||
fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name)
|
fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name)
|
||||||
fmt.Fprintf(buf, "// The result aliases no memory with the original.\n")
|
fmt.Fprintf(buf, "// The result aliases no memory with the original.\n")
|
||||||
@ -191,9 +179,6 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types
|
|||||||
for i := 0; i < t.NumFields(); i++ {
|
for i := 0; i < t.NumFields(); i++ {
|
||||||
fname := t.Field(i).Name()
|
fname := t.Field(i).Name()
|
||||||
ft := t.Field(i).Type()
|
ft := t.Field(i).Type()
|
||||||
|
|
||||||
writeRegen("\t%s %s", fname, importedName(ft))
|
|
||||||
|
|
||||||
if !containsPointers(ft) {
|
if !containsPointers(ft) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -258,9 +243,7 @@ func gen(buf *bytes.Buffer, imports map[string]struct{}, name string, typ *types
|
|||||||
writef("return dst")
|
writef("return dst")
|
||||||
fmt.Fprintf(buf, "}\n\n")
|
fmt.Fprintf(buf, "}\n\n")
|
||||||
|
|
||||||
writeRegen("}{})\n")
|
buf.Write(codegen.AssertStructUnchanged(t, name, "", thisPkg, imports))
|
||||||
|
|
||||||
buf.Write(regenBuf.Bytes())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,8 +6,10 @@
|
|||||||
package codegen
|
package codegen
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
|
"go/types"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -38,3 +40,42 @@ func WriteFormatted(code []byte, path string) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
|
||||||
|
// tname is the named type corresponding to t.
|
||||||
|
// ctx is a single-word context for this assertion, such as "Clone".
|
||||||
|
// thisPkg is the package containing t.
|
||||||
|
// If non-nil, AssertStructUnchanged will add elements to imports
|
||||||
|
// for each package path that the caller must import for the returned code to compile.
|
||||||
|
func AssertStructUnchanged(t *types.Struct, tname, ctx string, thisPkg *types.Package, imports map[string]struct{}) []byte {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
w := func(format string, args ...interface{}) {
|
||||||
|
fmt.Fprintf(buf, format+"\n", args...)
|
||||||
|
}
|
||||||
|
w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
|
||||||
|
w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
|
||||||
|
|
||||||
|
for i := 0; i < t.NumFields(); i++ {
|
||||||
|
fname := t.Field(i).Name()
|
||||||
|
ft := t.Field(i).Type()
|
||||||
|
qname, imppath := importedName(ft, thisPkg)
|
||||||
|
if imppath != "" && imports != nil {
|
||||||
|
imports[imppath] = struct{}{}
|
||||||
|
}
|
||||||
|
w("\t%s %s", fname, qname)
|
||||||
|
}
|
||||||
|
|
||||||
|
w("}{})\n")
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func importedName(t types.Type, thisPkg *types.Package) (qualifiedName, importPkg string) {
|
||||||
|
qual := func(pkg *types.Package) string {
|
||||||
|
if thisPkg == pkg {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
importPkg = pkg.Path()
|
||||||
|
return pkg.Name()
|
||||||
|
}
|
||||||
|
return types.TypeString(t, qual), importPkg
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user