diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 3bce4d8e1..e6b363e1b 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -13,7 +13,9 @@ "time" "github.com/tailscale/wireguard-go/wgcfg" + "go4.org/mem" "golang.org/x/oauth2" + "tailscale.com/types/key" "tailscale.com/types/opt" "tailscale.com/types/structs" ) @@ -38,6 +40,10 @@ // NodeKey is the curve25519 public key for a node. type NodeKey [32]byte +// DiscoKey is the curve25519 public key for path discovery key. +// It's never written to disk or reused between network start-ups. +type DiscoKey [32]byte + type Group struct { ID GroupID Name string @@ -127,6 +133,7 @@ type Node struct { Key NodeKey KeyExpiry time.Time Machine MachineKey + DiscoKey DiscoKey Addresses []wgcfg.CIDR // IP addresses of this Node directly AllowedIPs []wgcfg.CIDR // range of IP addresses to route to this node Endpoints []string `json:",omitempty"` // IP+port (public via STUN, and local LANs) @@ -519,59 +526,43 @@ type Debug struct { LogHeapURL string `json:",omitempty"` } -func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) } +func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) } +func (k MachineKey) MarshalText() ([]byte, error) { return keyMarshalText("mkey:", k), nil } +func (k *MachineKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "mkey:", text) } -func (k MachineKey) MarshalText() ([]byte, error) { - buf := new(bytes.Buffer) - fmt.Fprintf(buf, "mkey:%x", k[:]) - return buf.Bytes(), nil +func keyMarshalText(prefix string, k [32]byte) []byte { + buf := bytes.NewBuffer(make([]byte, 0, len(prefix)+64)) + fmt.Fprintf(buf, "%s%x", prefix, k[:]) + return buf.Bytes() } -func (k *MachineKey) UnmarshalText(text []byte) error { - s := string(text) - if !strings.HasPrefix(s, "mkey:") { - return errors.New(`MachineKey.UnmarshalText: missing prefix`) +func keyUnmarshalText(dst []byte, prefix string, text []byte) error { + if len(text) < len(prefix) || string(text[:len(prefix)]) != prefix { + return fmt.Errorf("UnmarshalText: missing %q prefix", prefix) } - s = strings.TrimPrefix(s, `mkey:`) - key, err := wgcfg.ParseHexKey(s) + pub, err := key.NewPublicFromHexMem(mem.B(text[len(prefix):])) if err != nil { - return fmt.Errorf("MachineKey.UnmarhsalText: %v", err) + return fmt.Errorf("UnmarshalText: after %q: %v", prefix, err) } - copy(k[:], key[:]) + copy(dst[:], pub[:]) return nil } -func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) } +func (k NodeKey) ShortString() string { return (key.Public(k)).ShortString() } -func (k NodeKey) ShortString() string { - pk := wgcfg.Key(k) - return pk.ShortString() -} +func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) } +func (k NodeKey) MarshalText() ([]byte, error) { return keyMarshalText("nodekey:", k), nil } +func (k *NodeKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "nodekey:", text) } -func (k NodeKey) MarshalText() ([]byte, error) { - buf := new(bytes.Buffer) - fmt.Fprintf(buf, "nodekey:%x", k[:]) - return buf.Bytes(), nil -} +// IsZero reports whether k is the zero value. +func (k NodeKey) IsZero() bool { return k == NodeKey{} } -func (k *NodeKey) UnmarshalText(text []byte) error { - s := string(text) - if !strings.HasPrefix(s, "nodekey:") { - return errors.New(`Nodekey.UnmarshalText: missing prefix`) - } - s = strings.TrimPrefix(s, "nodekey:") - key, err := wgcfg.ParseHexKey(s) - if err != nil { - return fmt.Errorf("tailcfg.Ukey.UnmarhsalText: %v", err) - } - copy(k[:], key[:]) - return nil -} +func (k DiscoKey) String() string { return fmt.Sprintf("discokey:%x", k[:]) } +func (k DiscoKey) MarshalText() ([]byte, error) { return keyMarshalText("discokey:", k), nil } +func (k *DiscoKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "discokey:", text) } -// IsZero reports whether k is the NodeKey zero value. -func (k NodeKey) IsZero() bool { - return k == NodeKey{} -} +// IsZero reports whether k is the zero value. +func (k DiscoKey) IsZero() bool { return k == DiscoKey{} } func (id ID) String() string { return fmt.Sprintf("id:%x", int64(id)) } func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) } @@ -593,6 +584,7 @@ func (n *Node) Equal(n2 *Node) bool { n.Key == n2.Key && n.KeyExpiry.Equal(n2.KeyExpiry) && n.Machine == n2.Machine && + n.DiscoKey == n2.DiscoKey && reflect.DeepEqual(n.Addresses, n2.Addresses) && reflect.DeepEqual(n.AllowedIPs, n2.AllowedIPs) && reflect.DeepEqual(n.Endpoints, n2.Endpoints) && diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index 2801c94f1..63af39505 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -5,7 +5,9 @@ package tailcfg import ( + "encoding" "reflect" + "strings" "testing" "time" @@ -176,7 +178,7 @@ func TestHostinfoEqual(t *testing.T) { } func TestNodeEqual(t *testing.T) { - nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", "Created", "LastSeen", "KeepAlive", "MachineAuthorized"} + nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "DiscoKey", "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", "Created", "LastSeen", "KeepAlive", "MachineAuthorized"} if have := fieldsOf(reflect.TypeOf(Node{})); !reflect.DeepEqual(have, nodeHandles) { t.Errorf("Node.Equal check might be out of sync\nfields: %q\nhandled: %q\n", have, nodeHandles) @@ -336,3 +338,51 @@ func TestNetInfoFields(t *testing.T) { have, handled) } } + +func TestMachineKeyMarshal(t *testing.T) { + var k1, k2 MachineKey + for i := range k1 { + k1[i] = byte(i) + } + testKey(t, "mkey:", k1, &k2) +} + +func TestNodeKeyMarshal(t *testing.T) { + var k1, k2 NodeKey + for i := range k1 { + k1[i] = byte(i) + } + testKey(t, "nodekey:", k1, &k2) +} + +func TestDiscoKeyMarshal(t *testing.T) { + var k1, k2 DiscoKey + for i := range k1 { + k1[i] = byte(i) + } + testKey(t, "discokey:", k1, &k2) +} + +type keyIn interface { + String() string + MarshalText() ([]byte, error) +} + +func testKey(t *testing.T, prefix string, in keyIn, out encoding.TextUnmarshaler) { + got, err := in.MarshalText() + if err != nil { + t.Fatal(err) + } + if err := out.UnmarshalText(got); err != nil { + t.Fatal(err) + } + if s := in.String(); string(got) != s { + t.Errorf("MarshalText = %q != String %q", got, s) + } + if !strings.HasPrefix(string(got), prefix) { + t.Errorf("%q didn't start with prefix %q", got, prefix) + } + if reflect.ValueOf(out).Elem().Interface() != in { + t.Errorf("mismatch after unmarshal") + } +}