From 09fbae01a93486812313b35af78f6aeb69537f23 Mon Sep 17 00:00:00 2001 From: wardn Date: Sat, 15 Feb 2020 22:23:58 -0800 Subject: [PATCH] tailcfg: don't panic on node equal check Signed-off-by: wardn --- tailcfg/tailcfg.go | 28 ++++--- tailcfg/tailcfg_test.go | 166 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 10 deletions(-) create mode 100644 tailcfg/tailcfg_test.go diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 6ce555ad0..ad127d155 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -6,9 +6,9 @@ import ( "bytes" - "encoding/json" "errors" "fmt" + "reflect" "strings" "time" @@ -345,15 +345,23 @@ func (id GroupID) String() string { return fmt.Sprintf("groupid:%x", int64( func (id RoleID) String() string { return fmt.Sprintf("roleid:%x", int64(id)) } func (id CapabilityID) String() string { return fmt.Sprintf("capid:%x", int64(id)) } +// Equal reports whether n and n2 are equal. func (n *Node) Equal(n2 *Node) bool { - // TODO(crawshaw): this is crude, but is an easy way to avoid bugs. - b, err := json.Marshal(n) - if err != nil { - panic(err) + if n == nil && n2 == nil { + return true } - b2, err := json.Marshal(n2) - if err != nil { - panic(err) - } - return bytes.Equal(b, b2) + return n != nil && n2 != nil && + n.ID == n2.ID && + n.Name == n2.Name && + n.User == n2.User && + n.Key == n2.Key && + n.KeyExpiry.Equal(n2.KeyExpiry) && + n.Machine == n2.Machine && + reflect.DeepEqual(n.Addresses, n2.Addresses) && + reflect.DeepEqual(n.AllowedIPs, n2.AllowedIPs) && + reflect.DeepEqual(n.Endpoints, n2.Endpoints) && + reflect.DeepEqual(n.Hostinfo, n2.Hostinfo) && + n.Created.Equal(n2.Created) && + reflect.DeepEqual(n.LastSeen, n2.LastSeen) && + n.MachineAuthorized == n2.MachineAuthorized } diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go new file mode 100644 index 000000000..08adf4f6c --- /dev/null +++ b/tailcfg/tailcfg_test.go @@ -0,0 +1,166 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tailcfg + +import ( + "reflect" + "testing" + "time" + + "github.com/tailscale/wireguard-go/wgcfg" +) + +func fieldsOf(t reflect.Type) (fields []string) { + for i := 0; i < t.NumField(); i++ { + fields = append(fields, t.Field(i).Name) + } + return +} + +func TestNodeEqual(t *testing.T) { + nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "Addresses", "AllowedIPs", "Endpoints", "Hostinfo", "Created", "LastSeen", "MachineAuthorized"} + if have := fieldsOf(reflect.TypeOf(Node{})); !reflect.DeepEqual(have, nodeHandles) { + t.Errorf("Node.Equal check might be out of sync\nfields: %q\nhandled: %q\n", + have, nodeHandles) + } + + newPublicKey := func(t *testing.T) wgcfg.Key { + t.Helper() + k, err := wgcfg.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + return k.Public() + } + n1 := newPublicKey(t) + now := time.Now() + + tests := []struct { + a, b *Node + want bool + }{ + { + &Node{}, + nil, + false, + }, + { + nil, + &Node{}, + false, + }, + { + &Node{}, + &Node{}, + true, + }, + { + &Node{User: 0}, + &Node{User: 1}, + false, + }, + { + &Node{User: 1}, + &Node{User: 1}, + true, + }, + { + &Node{Key: NodeKey(n1)}, + &Node{Key: NodeKey(newPublicKey(t))}, + false, + }, + { + &Node{Key: NodeKey(n1)}, + &Node{Key: NodeKey(n1)}, + true, + }, + { + &Node{KeyExpiry: now}, + &Node{KeyExpiry: now.Add(60 * time.Second)}, + false, + }, + { + &Node{KeyExpiry: now}, + &Node{KeyExpiry: now}, + true, + }, + { + &Node{Machine: MachineKey(n1)}, + &Node{Machine: MachineKey(newPublicKey(t))}, + false, + }, + { + &Node{Machine: MachineKey(n1)}, + &Node{Machine: MachineKey(n1)}, + true, + }, + { + &Node{Addresses: []wgcfg.CIDR{}}, + &Node{Addresses: nil}, + false, + }, + { + &Node{Addresses: []wgcfg.CIDR{}}, + &Node{Addresses: []wgcfg.CIDR{}}, + true, + }, + { + &Node{AllowedIPs: []wgcfg.CIDR{}}, + &Node{AllowedIPs: nil}, + false, + }, + { + &Node{Addresses: []wgcfg.CIDR{}}, + &Node{Addresses: []wgcfg.CIDR{}}, + true, + }, + { + &Node{Endpoints: []string{}}, + &Node{Endpoints: nil}, + false, + }, + { + &Node{Endpoints: []string{}}, + &Node{Endpoints: []string{}}, + true, + }, + { + &Node{Hostinfo: Hostinfo{Hostname: "alice"}}, + &Node{Hostinfo: Hostinfo{Hostname: "bob"}}, + false, + }, + { + &Node{Hostinfo: Hostinfo{}}, + &Node{Hostinfo: Hostinfo{}}, + true, + }, + { + &Node{Created: now}, + &Node{Created: now.Add(60 * time.Second)}, + false, + }, + { + &Node{Created: now}, + &Node{Created: now}, + true, + }, + { + &Node{LastSeen: &now}, + &Node{LastSeen: nil}, + false, + }, + { + &Node{LastSeen: &now}, + &Node{LastSeen: &now}, + true, + }, + } + for i, tt := range tests { + got := tt.a.Equal(tt.b) + if got != tt.want { + t.Errorf("%d. Equal = %v; want %v", i, got, tt.want) + } + } +}