cmd/equaler: add command to generate Equal() methods

The generator is still crude and do not cover most types, but it
covers all the ones needed by the tailcfg package. It's a start.

Fixes #8077.

Signed-off-by: salman <salman@tailscale.com>
This commit is contained in:
salman 2023-05-05 16:47:28 +01:00
parent 5783adcc6f
commit 8e6f564f7e
7 changed files with 436 additions and 126 deletions

177
cmd/equaler/equaler.go Normal file
View File

@ -0,0 +1,177 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Equaler is a tool to automate the creation of an Equals method.
//
// This tool assumes that if a type you give it contains another named struct
// type, that type will also have an Equal method, and that all fields are
// comparable unless explicitly excluded.
package main
import (
"bytes"
"flag"
"fmt"
"go/token"
"go/types"
"log"
"os"
"strings"
"golang.org/x/exp/slices"
"tailscale.com/util/codegen"
)
var (
flagTypes = flag.String("type", "", "comma-separated list of types; required")
flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
)
func main() {
log.SetFlags(0)
log.SetPrefix("equaler: ")
flag.Parse()
if len(*flagTypes) == 0 {
flag.Usage()
os.Exit(2)
}
typeNames := strings.Split(*flagTypes, ",")
pkg, namedTypes, err := codegen.LoadTypes(*flagBuildTags, ".")
if err != nil {
log.Fatal(err)
}
it := codegen.NewImportTracker(pkg.Types)
buf := new(bytes.Buffer)
for _, typeName := range typeNames {
typ, ok := namedTypes[typeName]
if !ok {
log.Fatalf("could not find type %s", typeName)
}
gen(buf, it, typ, typeNames)
}
cloneOutput := pkg.Name + "_equal.go"
if err := codegen.WritePackageFile("tailscale.com/cmd/equaler", pkg, cloneOutput, it, buf); err != nil {
log.Fatal(err)
}
}
func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, typeNames []string) {
t, ok := typ.Underlying().(*types.Struct)
if !ok {
return
}
name := typ.Obj().Name()
fmt.Fprintf(buf, "// Equal reports whether a and b are equal.\n")
fmt.Fprintf(buf, "func (a *%s) Equal(b *%s) bool {\n", name, name)
writef := func(format string, args ...any) {
fmt.Fprintf(buf, "\t"+format+"\n", args...)
}
writef("if a == b {")
writef("\treturn true")
writef("}")
writef("return a != nil && b != nil &&")
for i := 0; i < t.NumFields(); i++ {
fname := t.Field(i).Name()
ft := t.Field(i).Type()
// Fields which are explicitly ignored are skipped.
if codegen.HasNoEqual(t.Tag(i)) {
writef("\t// Skipping %s because of codegen:noequal", fname)
continue
}
// Fields which are named types that have an Equal() method, get that method used
if named, _ := ft.(*types.Named); named != nil {
if implementsEqual(ft) || slices.Contains(typeNames, named.Obj().Name()) {
writef("\ta.%s.Equal(b.%s) &&", fname, fname)
continue
}
}
// Fields which are just values are directly compared, unless they have an Equal() method.
if !codegen.ContainsPointers(ft) {
writef("\ta.%s == b.%s &&", fname, fname)
continue
}
switch ft := ft.Underlying().(type) {
case *types.Pointer:
if named, _ := ft.Elem().(*types.Named); named != nil {
if slices.Contains(typeNames, named.Obj().Name()) || implementsEqual(ft) {
writef("\t((a.%s == nil) == (b.%s == nil)) && (a.%s == nil || a.%s.Equal(b.%s)) &&", fname, fname, fname, fname, fname)
continue
}
if implementsEqual(ft.Elem()) {
writef("\t((a.%s == nil) == (b.%s == nil)) && (a.%s == nil || a.%s.Equal(*b.%s)) &&", fname, fname, fname, fname, fname)
continue
}
}
if !codegen.ContainsPointers(ft.Elem()) {
writef("\t((a.%s == nil) == (b.%s == nil)) && (a.%s == nil || *a.%s == *b.%s) &&", fname, fname, fname, fname, fname)
continue
}
log.Fatalf("unimplemented: %s (%T)", fname, ft)
case *types.Slice:
// Empty slices and nil slices are different.
writef("\t((a.%s == nil) == (b.%s == nil)) &&", fname, fname)
if named, _ := ft.Elem().(*types.Named); named != nil {
if implementsEqual(ft.Elem()) {
it.Import("golang.org/x/exp/slices")
writef("\tslices.EqualFunc(a.%s, b.%s, func(aa %s, bb %s) bool {return aa.Equal(bb)}) &&", fname, fname, named.Obj().Name(), named.Obj().Name())
continue
}
if slices.Contains(typeNames, named.Obj().Name()) || implementsEqual(types.NewPointer(ft.Elem())) {
it.Import("golang.org/x/exp/slices")
writef("\tslices.EqualFunc(a.%s, b.%s, func(aa %s, bb %s) bool {return aa.Equal(&bb)}) &&", fname, fname, named.Obj().Name(), named.Obj().Name())
continue
}
}
if !codegen.ContainsPointers(ft.Elem()) {
it.Import("golang.org/x/exp/slices")
writef("\tslices.Equal(a.%s, b.%s) &&", fname, fname)
continue
}
log.Fatalf("unimplemented: %s (%T)", fname, ft)
case *types.Map:
if !codegen.ContainsPointers(ft.Elem()) {
it.Import("golang.org/x/exp/maps")
writef("\tmaps.Equal(a.%s, b.%s) &&", fname, fname)
continue
}
log.Fatalf("unimplemented: %s (%T)", fname, ft)
default:
log.Fatalf("unimplemented: %s (%T)", fname, ft)
}
}
writef("\ttrue")
fmt.Fprintf(buf, "}\n\n")
buf.Write(codegen.AssertStructUnchanged(t, name, "Equal", it))
}
// hasBasicUnderlying reports true when typ.Underlying() is a slice or a map.
func hasBasicUnderlying(typ types.Type) bool {
switch typ.Underlying().(type) {
case *types.Slice, *types.Map:
return true
default:
return false
}
}
// implementsEqual reports whether typ has an Equal(typ) bool method.
func implementsEqual(typ types.Type) bool {
return types.Implements(typ, types.NewInterfaceType(
[]*types.Func{types.NewFunc(
token.NoPos, nil, "Equal", types.NewSignatureType(
types.NewVar(token.NoPos, nil, "a", typ),
nil, nil,
types.NewTuple(types.NewVar(token.NoPos, nil, "b", typ)),
types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Typ[types.Bool])), false))},
[]types.Type{},
))
}

View File

@ -5,8 +5,9 @@
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan --clonefunc
//go:generage go run tailscale.com/cmd/equaler --type Node,Hostinfo,NetInfo,Service
import (
"bytes"
"encoding/hex"
"errors"
"fmt"
@ -497,7 +498,7 @@ func (h *Hostinfo) CheckRequestTags() error {
// Service represents a service running on a node.
type Service struct {
_ structs.Incomparable
_ structs.Incomparable `codegen:"noequal"`
// Proto is the type of service. It's usually the constant TCP
// or UDP ("tcp" or "udp"), but it can also be one of the
@ -582,9 +583,6 @@ type Hostinfo struct {
Cloud string `json:",omitempty"`
Userspace opt.Bool `json:",omitempty"` // if the client is running in userspace (netstack) mode
UserspaceRouter opt.Bool `json:",omitempty"` // if the client's subnet router is running in userspace (netstack) mode
// NOTE: any new fields containing pointers in this type
// require changes to Hostinfo.Equal.
}
// TailscaleSSHEnabled reports whether or not this node is acting as a
@ -664,9 +662,7 @@ type NetInfo struct {
// This should only be updated rarely, or when there's a
// material change, as any change here also gets uploaded to
// the control plane.
DERPLatency map[string]float64 `json:",omitempty"`
// Update BasicallyEqual when adding fields.
DERPLatency map[string]float64 `json:",omitempty" codegen:"noequal"`
}
func (ni *NetInfo) String() string {
@ -704,40 +700,6 @@ func conciseOptBool(b opt.Bool, trueVal string) string {
return ""
}
// BasicallyEqual reports whether ni and ni2 are basically equal, ignoring
// changes in DERP ServerLatency & RegionLatency.
func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool {
if (ni == nil) != (ni2 == nil) {
return false
}
if ni == nil {
return true
}
return ni.MappingVariesByDestIP == ni2.MappingVariesByDestIP &&
ni.HairPinning == ni2.HairPinning &&
ni.WorkingIPv6 == ni2.WorkingIPv6 &&
ni.OSHasIPv6 == ni2.OSHasIPv6 &&
ni.WorkingUDP == ni2.WorkingUDP &&
ni.WorkingICMPv4 == ni2.WorkingICMPv4 &&
ni.HavePortMap == ni2.HavePortMap &&
ni.UPnP == ni2.UPnP &&
ni.PMP == ni2.PMP &&
ni.PCP == ni2.PCP &&
ni.PreferredDERP == ni2.PreferredDERP &&
ni.LinkType == ni2.LinkType
}
// Equal reports whether h and h2 are equal.
func (h *Hostinfo) Equal(h2 *Hostinfo) bool {
if h == nil && h2 == nil {
return true
}
if (h == nil) != (h2 == nil) {
return false
}
return reflect.DeepEqual(h, h2)
}
// HowUnequal returns a list of paths through Hostinfo where h and h2 differ.
// If they differ in nil-ness, the path is "nil", otherwise the path is like
// "ShieldsUp" or "NetInfo.nil" or "NetInfo.PCP".
@ -1689,82 +1651,6 @@ func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) }
func (id LoginID) String() string { return fmt.Sprintf("loginid:%x", int64(id)) }
func (id NodeID) String() string { return fmt.Sprintf("nodeid:%x", int64(id)) }
// Equal reports whether n and n2 are equal.
func (n *Node) Equal(n2 *Node) bool {
if n == nil && n2 == nil {
return true
}
return n != nil && n2 != nil &&
n.ID == n2.ID &&
n.StableID == n2.StableID &&
n.Name == n2.Name &&
n.User == n2.User &&
n.Sharer == n2.Sharer &&
n.UnsignedPeerAPIOnly == n2.UnsignedPeerAPIOnly &&
n.Key == n2.Key &&
n.KeyExpiry.Equal(n2.KeyExpiry) &&
bytes.Equal(n.KeySignature, n2.KeySignature) &&
n.Machine == n2.Machine &&
n.DiscoKey == n2.DiscoKey &&
eqPtr(n.Online, n2.Online) &&
eqCIDRs(n.Addresses, n2.Addresses) &&
eqCIDRs(n.AllowedIPs, n2.AllowedIPs) &&
eqCIDRs(n.PrimaryRoutes, n2.PrimaryRoutes) &&
eqStrings(n.Endpoints, n2.Endpoints) &&
n.DERP == n2.DERP &&
n.Cap == n2.Cap &&
n.Hostinfo.Equal(n2.Hostinfo) &&
n.Created.Equal(n2.Created) &&
eqTimePtr(n.LastSeen, n2.LastSeen) &&
n.MachineAuthorized == n2.MachineAuthorized &&
eqStrings(n.Capabilities, n2.Capabilities) &&
n.ComputedName == n2.ComputedName &&
n.computedHostIfDifferent == n2.computedHostIfDifferent &&
n.ComputedNameWithHost == n2.ComputedNameWithHost &&
eqStrings(n.Tags, n2.Tags) &&
n.Expired == n2.Expired &&
eqPtr(n.SelfNodeV4MasqAddrForThisPeer, n2.SelfNodeV4MasqAddrForThisPeer) &&
n.IsWireGuardOnly == n2.IsWireGuardOnly
}
func eqPtr[T comparable](a, b *T) bool {
if a == b { // covers nil
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
func eqStrings(a, b []string) bool {
if len(a) != len(b) || ((a == nil) != (b == nil)) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func eqCIDRs(a, b []netip.Prefix) bool {
if len(a) != len(b) || ((a == nil) != (b == nil)) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func eqTimePtr(a, b *time.Time) bool {
return ((a == nil) == (b == nil)) && (a == nil || a.Equal(*b))
}
// Oauth2Token is a copy of golang.org/x/oauth2.Token, to avoid the
// go.mod dependency on App Engine and grpc, which was causing problems.
// All we actually needed was this struct on the client side.

244
tailcfg/tailcfg_equal.go Normal file
View File

@ -0,0 +1,244 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Code generated by tailscale.com/cmd/equaler; DO NOT EDIT.
package tailcfg
import (
"net/netip"
"time"
"golang.org/x/exp/slices"
"tailscale.com/types/key"
"tailscale.com/types/opt"
"tailscale.com/types/structs"
"tailscale.com/types/tkatype"
)
// Equal reports whether a and b are equal.
func (a *Node) Equal(b *Node) bool {
if a == b {
return true
}
return a != nil && b != nil &&
a.ID == b.ID &&
a.StableID == b.StableID &&
a.Name == b.Name &&
a.User == b.User &&
a.Sharer == b.Sharer &&
a.Key == b.Key &&
a.KeyExpiry.Equal(b.KeyExpiry) &&
((a.KeySignature == nil) == (b.KeySignature == nil)) &&
slices.Equal(a.KeySignature, b.KeySignature) &&
a.Machine == b.Machine &&
a.DiscoKey == b.DiscoKey &&
((a.Addresses == nil) == (b.Addresses == nil)) &&
slices.Equal(a.Addresses, b.Addresses) &&
((a.AllowedIPs == nil) == (b.AllowedIPs == nil)) &&
slices.Equal(a.AllowedIPs, b.AllowedIPs) &&
((a.Endpoints == nil) == (b.Endpoints == nil)) &&
slices.Equal(a.Endpoints, b.Endpoints) &&
a.DERP == b.DERP &&
a.Hostinfo.Equal(b.Hostinfo) &&
a.Created.Equal(b.Created) &&
a.Cap == b.Cap &&
((a.Tags == nil) == (b.Tags == nil)) &&
slices.Equal(a.Tags, b.Tags) &&
((a.PrimaryRoutes == nil) == (b.PrimaryRoutes == nil)) &&
slices.Equal(a.PrimaryRoutes, b.PrimaryRoutes) &&
((a.LastSeen == nil) == (b.LastSeen == nil)) && (a.LastSeen == nil || a.LastSeen.Equal(*b.LastSeen)) &&
((a.Online == nil) == (b.Online == nil)) && (a.Online == nil || *a.Online == *b.Online) &&
a.KeepAlive == b.KeepAlive &&
a.MachineAuthorized == b.MachineAuthorized &&
((a.Capabilities == nil) == (b.Capabilities == nil)) &&
slices.Equal(a.Capabilities, b.Capabilities) &&
a.UnsignedPeerAPIOnly == b.UnsignedPeerAPIOnly &&
a.ComputedName == b.ComputedName &&
a.computedHostIfDifferent == b.computedHostIfDifferent &&
a.ComputedNameWithHost == b.ComputedNameWithHost &&
a.DataPlaneAuditLogID == b.DataPlaneAuditLogID &&
a.Expired == b.Expired &&
((a.SelfNodeV4MasqAddrForThisPeer == nil) == (b.SelfNodeV4MasqAddrForThisPeer == nil)) && (a.SelfNodeV4MasqAddrForThisPeer == nil || *a.SelfNodeV4MasqAddrForThisPeer == *b.SelfNodeV4MasqAddrForThisPeer) &&
a.IsWireGuardOnly == b.IsWireGuardOnly &&
true
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _NodeEqualNeedsRegeneration = Node(struct {
ID NodeID
StableID StableNodeID
Name string
User UserID
Sharer UserID
Key key.NodePublic
KeyExpiry time.Time
KeySignature tkatype.MarshaledSignature
Machine key.MachinePublic
DiscoKey key.DiscoPublic
Addresses []netip.Prefix
AllowedIPs []netip.Prefix
Endpoints []string
DERP string
Hostinfo HostinfoView
Created time.Time
Cap CapabilityVersion
Tags []string
PrimaryRoutes []netip.Prefix
LastSeen *time.Time
Online *bool
KeepAlive bool
MachineAuthorized bool
Capabilities []string
UnsignedPeerAPIOnly bool
ComputedName string
computedHostIfDifferent string
ComputedNameWithHost string
DataPlaneAuditLogID string
Expired bool
SelfNodeV4MasqAddrForThisPeer *netip.Addr
IsWireGuardOnly bool
}{})
// Equal reports whether a and b are equal.
func (a *Hostinfo) Equal(b *Hostinfo) bool {
if a == b {
return true
}
return a != nil && b != nil &&
a.IPNVersion == b.IPNVersion &&
a.FrontendLogID == b.FrontendLogID &&
a.BackendLogID == b.BackendLogID &&
a.OS == b.OS &&
a.OSVersion == b.OSVersion &&
a.Container == b.Container &&
a.Env == b.Env &&
a.Distro == b.Distro &&
a.DistroVersion == b.DistroVersion &&
a.DistroCodeName == b.DistroCodeName &&
a.App == b.App &&
a.Desktop == b.Desktop &&
a.Package == b.Package &&
a.DeviceModel == b.DeviceModel &&
a.PushDeviceToken == b.PushDeviceToken &&
a.Hostname == b.Hostname &&
a.ShieldsUp == b.ShieldsUp &&
a.ShareeNode == b.ShareeNode &&
a.NoLogsNoSupport == b.NoLogsNoSupport &&
a.WireIngress == b.WireIngress &&
a.AllowsUpdate == b.AllowsUpdate &&
a.Machine == b.Machine &&
a.GoArch == b.GoArch &&
a.GoArchVar == b.GoArchVar &&
a.GoVersion == b.GoVersion &&
((a.RoutableIPs == nil) == (b.RoutableIPs == nil)) &&
slices.Equal(a.RoutableIPs, b.RoutableIPs) &&
((a.RequestTags == nil) == (b.RequestTags == nil)) &&
slices.Equal(a.RequestTags, b.RequestTags) &&
((a.Services == nil) == (b.Services == nil)) &&
slices.EqualFunc(a.Services, b.Services, func(aa Service, bb Service) bool { return aa.Equal(&bb) }) &&
((a.NetInfo == nil) == (b.NetInfo == nil)) && (a.NetInfo == nil || a.NetInfo.Equal(b.NetInfo)) &&
((a.SSH_HostKeys == nil) == (b.SSH_HostKeys == nil)) &&
slices.Equal(a.SSH_HostKeys, b.SSH_HostKeys) &&
a.Cloud == b.Cloud &&
a.Userspace == b.Userspace &&
a.UserspaceRouter == b.UserspaceRouter &&
true
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _HostinfoEqualNeedsRegeneration = Hostinfo(struct {
IPNVersion string
FrontendLogID string
BackendLogID string
OS string
OSVersion string
Container opt.Bool
Env string
Distro string
DistroVersion string
DistroCodeName string
App string
Desktop opt.Bool
Package string
DeviceModel string
PushDeviceToken string
Hostname string
ShieldsUp bool
ShareeNode bool
NoLogsNoSupport bool
WireIngress bool
AllowsUpdate bool
Machine string
GoArch string
GoArchVar string
GoVersion string
RoutableIPs []netip.Prefix
RequestTags []string
Services []Service
NetInfo *NetInfo
SSH_HostKeys []string
Cloud string
Userspace opt.Bool
UserspaceRouter opt.Bool
}{})
// Equal reports whether a and b are equal.
func (a *NetInfo) Equal(b *NetInfo) bool {
if a == b {
return true
}
return a != nil && b != nil &&
a.MappingVariesByDestIP == b.MappingVariesByDestIP &&
a.HairPinning == b.HairPinning &&
a.WorkingIPv6 == b.WorkingIPv6 &&
a.OSHasIPv6 == b.OSHasIPv6 &&
a.WorkingUDP == b.WorkingUDP &&
a.WorkingICMPv4 == b.WorkingICMPv4 &&
a.HavePortMap == b.HavePortMap &&
a.UPnP == b.UPnP &&
a.PMP == b.PMP &&
a.PCP == b.PCP &&
a.PreferredDERP == b.PreferredDERP &&
a.LinkType == b.LinkType &&
// Skipping DERPLatency because of codegen:noequal
true
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _NetInfoEqualNeedsRegeneration = NetInfo(struct {
MappingVariesByDestIP opt.Bool
HairPinning opt.Bool
WorkingIPv6 opt.Bool
OSHasIPv6 opt.Bool
WorkingUDP opt.Bool
WorkingICMPv4 opt.Bool
HavePortMap bool
UPnP opt.Bool
PMP opt.Bool
PCP opt.Bool
PreferredDERP int
LinkType string
DERPLatency map[string]float64
}{})
// Equal reports whether a and b are equal.
func (a *Service) Equal(b *Service) bool {
if a == b {
return true
}
return a != nil && b != nil &&
// Skipping _ because of codegen:noequal
a.Proto == b.Proto &&
a.Port == b.Port &&
a.Description == b.Description &&
true
}
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _ServiceEqualNeedsRegeneration = Service(struct {
_ structs.Incomparable
Proto ServiceProto
Port uint16
Description string
}{})

View File

@ -572,7 +572,7 @@ func TestNetInfoFields(t *testing.T) {
"DERPLatency",
}
if have := fieldsOf(reflect.TypeOf(NetInfo{})); !reflect.DeepEqual(have, handled) {
t.Errorf("NetInfo.Clone/BasicallyEqually check might be out of sync\nfields: %q\nhandled: %q\n",
t.Errorf("NetInfo.Clone/Equal check might be out of sync\nfields: %q\nhandled: %q\n",
have, handled)
}
}

View File

@ -402,6 +402,7 @@ func (v NetInfoView) LinkType() string { return v.ж.LinkType }
func (v NetInfoView) DERPLatency() views.Map[string, float64] { return views.MapOf(v.ж.DERPLatency) }
func (v NetInfoView) String() string { return v.ж.String() }
func (v NetInfoView) Equal(v2 NetInfoView) bool { return v.ж.Equal(v2.ж) }
// A compilation failure here means this code must be regenerated, with the command at the top of this file.
var _NetInfoViewNeedsRegeneration = NetInfo(struct {

View File

@ -16,6 +16,7 @@
"reflect"
"strings"
"golang.org/x/exp/slices"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
"tailscale.com/util/mak"
@ -47,12 +48,13 @@ func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]
// HasNoClone reports whether the provided tag has `codegen:noclone`.
func HasNoClone(structTag string) bool {
val := reflect.StructTag(structTag).Get("codegen")
for _, v := range strings.Split(val, ",") {
if v == "noclone" {
return true
}
}
return false
return slices.Contains(strings.Split(val, ","), "noclone")
}
// HasNoEqual reports whether the provided tag has `codegen:noequal`.
func HasNoEqual(structTag string) bool {
val := reflect.StructTag(structTag).Get("codegen")
return slices.Contains(strings.Split(val, ","), "noequal")
}
const copyrightHeader = `// Copyright (c) Tailscale Inc & AUTHORS

View File

@ -951,7 +951,7 @@ func (c *Conn) pickDERPFallback() int {
func (c *Conn) callNetInfoCallback(ni *tailcfg.NetInfo) {
c.mu.Lock()
defer c.mu.Unlock()
if ni.BasicallyEqual(c.netInfoLast) {
if ni.Equal(c.netInfoLast) {
return
}
c.callNetInfoCallbackLocked(ni)