types/key: add MachinePrivate and MachinePublic.

Plumb throughout the codebase as a replacement for the mixed use of
tailcfg.MachineKey and wgkey.Private/Public.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2021-09-01 01:52:27 -07:00
committed by Dave Anderson
parent 4ce091cbd8
commit 4fdb88efe1
24 changed files with 605 additions and 234 deletions

View File

@@ -7,7 +7,6 @@ package controlclient
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"encoding/json"
"errors"
@@ -28,7 +27,7 @@ import (
"sync/atomic"
"time"
"golang.org/x/crypto/nacl/box"
"go4.org/mem"
"inet.af/netaddr"
"tailscale.com/control/controlknobs"
"tailscale.com/health"
@@ -42,6 +41,7 @@ import (
"tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/types/opt"
@@ -62,14 +62,14 @@ type Direct struct {
logf logger.Logf
linkMon *monitor.Mon // or nil
discoPubKey tailcfg.DiscoKey
getMachinePrivKey func() (wgkey.Private, error)
getMachinePrivKey func() (key.MachinePrivate, error)
debugFlags []string
keepSharerAndUserSplit bool
skipIPForwardingCheck bool
pinger Pinger
mu sync.Mutex // mutex guards the following fields
serverKey wgkey.Key
serverKey key.MachinePublic
persist persist.Persist
authKey string
tryingNewKey wgkey.Private
@@ -83,12 +83,12 @@ type Direct struct {
}
type Options struct {
Persist persist.Persist // initial persistent data
GetMachinePrivateKey func() (wgkey.Private, error) // returns the machine key to use
ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration
TimeNow func() time.Time // time.Now implementation used by Client
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
Persist persist.Persist // initial persistent data
GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use
ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration
TimeNow func() time.Time // time.Now implementation used by Client
Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
DiscoPublicKey tailcfg.DiscoKey
NewDecompressor func() (Decompressor, error)
KeepAlive bool
@@ -320,7 +320,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
if err != nil {
return regen, opt.URL, err
}
c.logf("control server key %s from %s", serverKey.HexString(), c.serverURL)
c.logf("control server key %s from %s", serverKey.ShortString(), c.serverURL)
c.mu.Lock()
c.serverKey = serverKey
@@ -398,13 +398,13 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
c.logf("RegisterRequest: %s", j)
}
bodyData, err := encode(request, &serverKey, &machinePrivKey)
bodyData, err := encode(request, serverKey, machinePrivKey)
if err != nil {
return regen, opt.URL, err
}
body := bytes.NewReader(bodyData)
u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().HexString())
u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().UntypedHexString())
req, err := http.NewRequest("POST", u, body)
if err != nil {
return regen, opt.URL, err
@@ -422,7 +422,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
res.StatusCode, strings.TrimSpace(string(msg)))
}
resp := tailcfg.RegisterResponse{}
if err := decode(res, &resp, &serverKey, &machinePrivKey); err != nil {
if err := decode(res, &resp, serverKey, machinePrivKey); err != nil {
c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err)
return regen, opt.URL, fmt.Errorf("register request: %v", err)
}
@@ -636,7 +636,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
request.ReadOnly = true
}
bodyData, err := encode(request, &serverKey, &machinePrivKey)
bodyData, err := encode(request, serverKey, machinePrivKey)
if err != nil {
vlogf("netmap: encode: %v", err)
return err
@@ -645,9 +645,9 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
ctx, cancel := context.WithCancel(ctx)
defer cancel()
machinePubKey := tailcfg.MachineKey(machinePrivKey.Public())
machinePubKey := machinePrivKey.Public()
t0 := time.Now()
u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString())
u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.UntypedHexString())
req, err := http.NewRequestWithContext(ctx, "POST", u, bytes.NewReader(bodyData))
if err != nil {
@@ -734,7 +734,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond))
var resp tailcfg.MapResponse
if err := c.decodeMsg(msg, &resp, &machinePrivKey); err != nil {
if err := c.decodeMsg(msg, &resp, machinePrivKey); err != nil {
vlogf("netmap: decode error: %v")
return err
}
@@ -830,7 +830,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
return nil
}
func decode(res *http.Response, v interface{}, serverKey *wgkey.Key, mkey *wgkey.Private) error {
func decode(res *http.Response, v interface{}, serverKey key.MachinePublic, mkey key.MachinePrivate) error {
defer res.Body.Close()
msg, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20))
if err != nil {
@@ -849,14 +849,14 @@ var (
var jsonEscapedZero = []byte(`\u0000`)
func (c *Direct) decodeMsg(msg []byte, v interface{}, machinePrivKey *wgkey.Private) error {
func (c *Direct) decodeMsg(msg []byte, v interface{}, machinePrivKey key.MachinePrivate) error {
c.mu.Lock()
serverKey := c.serverKey
c.mu.Unlock()
decrypted, err := decryptMsg(msg, &serverKey, machinePrivKey)
if err != nil {
return err
decrypted, ok := machinePrivKey.OpenFrom(serverKey, msg)
if !ok {
return errors.New("cannot decrypt response")
}
var b []byte
if c.newDecompressor == nil {
@@ -888,10 +888,10 @@ func (c *Direct) decodeMsg(msg []byte, v interface{}, machinePrivKey *wgkey.Priv
}
func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, machinePrivKey *wgkey.Private) error {
decrypted, err := decryptMsg(msg, serverKey, machinePrivKey)
if err != nil {
return err
func decodeMsg(msg []byte, v interface{}, serverKey key.MachinePublic, machinePrivKey key.MachinePrivate) error {
decrypted, ok := machinePrivKey.OpenFrom(serverKey, msg)
if !ok {
return errors.New("cannot decrypt response")
}
if bytes.Contains(decrypted, jsonEscapedZero) {
log.Printf("[unexpected] zero byte in controlclient decodeMsg into %T: %q", v, decrypted)
@@ -902,23 +902,7 @@ func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, machinePrivKey *
return nil
}
func decryptMsg(msg []byte, serverKey *wgkey.Key, mkey *wgkey.Private) ([]byte, error) {
var nonce [24]byte
if len(msg) < len(nonce)+1 {
return nil, fmt.Errorf("response missing nonce, len=%d", len(msg))
}
copy(nonce[:], msg)
msg = msg[len(nonce):]
pub, pri := (*[32]byte)(serverKey), (*[32]byte)(mkey)
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
if !ok {
return nil, fmt.Errorf("cannot decrypt response (len %d + nonce %d = %d)", len(msg), len(nonce), len(msg)+len(nonce))
}
return decrypted, nil
}
func encode(v interface{}, serverKey *wgkey.Key, mkey *wgkey.Private) ([]byte, error) {
func encode(v interface{}, serverKey key.MachinePublic, mkey key.MachinePrivate) ([]byte, error) {
b, err := json.Marshal(v)
if err != nil {
return nil, err
@@ -928,38 +912,32 @@ func encode(v interface{}, serverKey *wgkey.Key, mkey *wgkey.Private) ([]byte, e
log.Printf("MapRequest: %s", b)
}
}
var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
panic(err)
}
pub, pri := (*[32]byte)(serverKey), (*[32]byte)(mkey)
msg := box.Seal(nonce[:], b, &nonce, pub, pri)
return msg, nil
return mkey.SealTo(serverKey, b), nil
}
func loadServerKey(ctx context.Context, httpc *http.Client, serverURL string) (wgkey.Key, error) {
func loadServerKey(ctx context.Context, httpc *http.Client, serverURL string) (key.MachinePublic, error) {
req, err := http.NewRequest("GET", serverURL+"/key", nil)
if err != nil {
return wgkey.Key{}, fmt.Errorf("create control key request: %v", err)
return key.MachinePublic{}, fmt.Errorf("create control key request: %v", err)
}
req = req.WithContext(ctx)
res, err := httpc.Do(req)
if err != nil {
return wgkey.Key{}, fmt.Errorf("fetch control key: %v", err)
return key.MachinePublic{}, fmt.Errorf("fetch control key: %v", err)
}
defer res.Body.Close()
b, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<16))
if err != nil {
return wgkey.Key{}, fmt.Errorf("fetch control key response: %v", err)
return key.MachinePublic{}, fmt.Errorf("fetch control key response: %v", err)
}
if res.StatusCode != 200 {
return wgkey.Key{}, fmt.Errorf("fetch control key: %d: %s", res.StatusCode, string(b))
return key.MachinePublic{}, fmt.Errorf("fetch control key: %d: %s", res.StatusCode, string(b))
}
key, err := wgkey.ParseHex(string(b))
k, err := key.ParseMachinePublicUntyped(mem.B(b))
if err != nil {
return wgkey.Key{}, fmt.Errorf("fetch control key: %v", err)
return key.MachinePublic{}, fmt.Errorf("fetch control key: %v", err)
}
return key, nil
return k, nil
}
// Debug contains temporary internal-only debug knobs.
@@ -1207,13 +1185,13 @@ func (c *Direct) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) error {
return errors.New("getMachinePrivKey returned zero key")
}
bodyData, err := encode(req, &serverKey, &machinePrivKey)
bodyData, err := encode(req, serverKey, machinePrivKey)
if err != nil {
return err
}
body := bytes.NewReader(bodyData)
u := fmt.Sprintf("%s/machine/%s/set-dns", c.serverURL, machinePrivKey.Public().HexString())
u := fmt.Sprintf("%s/machine/%s/set-dns", c.serverURL, machinePrivKey.Public().UntypedHexString())
hreq, err := http.NewRequestWithContext(ctx, "POST", u, body)
if err != nil {
return err
@@ -1228,7 +1206,7 @@ func (c *Direct) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) error {
return fmt.Errorf("set-dns response: %v, %.200s", res.Status, strings.TrimSpace(string(msg)))
}
var setDNSRes struct{} // no fields yet
if err := decode(res, &setDNSRes, &serverKey, &machinePrivKey); err != nil {
if err := decode(res, &setDNSRes, serverKey, machinePrivKey); err != nil {
c.logf("error decoding SetDNSResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err)
return fmt.Errorf("set-dns-response: %v", err)
}