control/controlclient: support lazy machine key generation

It's not done in the caller yet, but the controlclient does it now.

Updates #1573

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2021-03-31 08:51:22 -07:00
parent 8d57bce5ef
commit a998fe7c3d
3 changed files with 67 additions and 44 deletions

View File

@ -63,7 +63,7 @@ type Direct struct {
logf logger.Logf logf logger.Logf
linkMon *monitor.Mon // or nil linkMon *monitor.Mon // or nil
discoPubKey tailcfg.DiscoKey discoPubKey tailcfg.DiscoKey
machinePrivKey wgkey.Private getMachinePrivKey func() (wgkey.Private, error)
debugFlags []string debugFlags []string
keepSharerAndUserSplit bool keepSharerAndUserSplit bool
@ -82,7 +82,7 @@ type Direct struct {
type Options struct { type Options struct {
Persist persist.Persist // initial persistent data Persist persist.Persist // initial persistent data
MachinePrivateKey wgkey.Private // the machine key to use GetMachinePrivateKey func() (wgkey.Private, error) // returns the machine key to use
ServerURL string // URL of the tailcontrol server ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration AuthKey string // optional node auth key for auto registration
TimeNow func() time.Time // time.Now implementation used by Client TimeNow func() time.Time // time.Now implementation used by Client
@ -110,8 +110,8 @@ func NewDirect(opts Options) (*Direct, error) {
if opts.ServerURL == "" { if opts.ServerURL == "" {
return nil, errors.New("controlclient.New: no server URL specified") return nil, errors.New("controlclient.New: no server URL specified")
} }
if opts.MachinePrivateKey.IsZero() { if opts.GetMachinePrivateKey == nil {
return nil, errors.New("controlclient.New: no MachinePrivateKey specified") return nil, errors.New("controlclient.New: no GetMachinePrivateKey specified")
} }
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/") opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
serverURL, err := url.Parse(opts.ServerURL) serverURL, err := url.Parse(opts.ServerURL)
@ -147,7 +147,7 @@ func NewDirect(opts Options) (*Direct, error) {
c := &Direct{ c := &Direct{
httpc: httpc, httpc: httpc,
machinePrivKey: opts.MachinePrivateKey, getMachinePrivKey: opts.GetMachinePrivateKey,
serverURL: opts.ServerURL, serverURL: opts.ServerURL,
timeNow: opts.TimeNow, timeNow: opts.TimeNow,
logf: opts.Logf, logf: opts.Logf,
@ -301,8 +301,12 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
c.mu.Unlock() c.mu.Unlock()
if c.machinePrivKey.IsZero() { machinePrivKey, err := c.getMachinePrivKey()
return false, "", errors.New("controlclient.Direct requires a machine private key") if err != nil {
return false, "", fmt.Errorf("getMachinePrivKey: %w", err)
}
if machinePrivKey.IsZero() {
return false, "", errors.New("getMachinePrivKey returned zero key")
} }
if expired { if expired {
@ -370,7 +374,7 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
request.Auth.Provider = persist.Provider request.Auth.Provider = persist.Provider
request.Auth.LoginName = persist.LoginName request.Auth.LoginName = persist.LoginName
request.Auth.AuthKey = authKey request.Auth.AuthKey = authKey
err = signRegisterRequest(&request, c.serverURL, c.serverKey, c.machinePrivKey.Public()) err = signRegisterRequest(&request, c.serverURL, c.serverKey, machinePrivKey.Public())
if err != nil { if err != nil {
// If signing failed, clear all related fields // If signing failed, clear all related fields
request.SignatureType = tailcfg.SignatureNone request.SignatureType = tailcfg.SignatureNone
@ -384,13 +388,13 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
c.logf("RegisterReq sign error: %v", err) c.logf("RegisterReq sign error: %v", err)
} }
} }
bodyData, err := encode(request, &serverKey, &c.machinePrivKey) bodyData, err := encode(request, &serverKey, &machinePrivKey)
if err != nil { if err != nil {
return regen, url, err return regen, url, err
} }
body := bytes.NewReader(bodyData) body := bytes.NewReader(bodyData)
u := fmt.Sprintf("%s/machine/%s", c.serverURL, c.machinePrivKey.Public().HexString()) u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().HexString())
req, err := http.NewRequest("POST", u, body) req, err := http.NewRequest("POST", u, body)
if err != nil { if err != nil {
return regen, url, err return regen, url, err
@ -408,8 +412,8 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
res.StatusCode, strings.TrimSpace(string(msg))) res.StatusCode, strings.TrimSpace(string(msg)))
} }
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
if err := decode(res, &resp, &serverKey, &c.machinePrivKey); err != nil { if err := decode(res, &resp, &serverKey, &machinePrivKey); err != nil {
c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, c.machinePrivKey.Public(), err) c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err)
return regen, url, fmt.Errorf("register request: %v", err) return regen, url, fmt.Errorf("register request: %v", err)
} }
// Log without PII: // Log without PII:
@ -536,6 +540,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
everEndpoints := c.everEndpoints everEndpoints := c.everEndpoints
c.mu.Unlock() c.mu.Unlock()
machinePrivKey, err := c.getMachinePrivKey()
if err != nil {
return fmt.Errorf("getMachinePrivKey: %w", err)
}
if machinePrivKey.IsZero() {
return errors.New("getMachinePrivKey returned zero key")
}
if persist.PrivateNodeKey.IsZero() { if persist.PrivateNodeKey.IsZero() {
return errors.New("privateNodeKey is zero") return errors.New("privateNodeKey is zero")
} }
@ -593,7 +605,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
request.ReadOnly = true request.ReadOnly = true
} }
bodyData, err := encode(request, &serverKey, &c.machinePrivKey) bodyData, err := encode(request, &serverKey, &machinePrivKey)
if err != nil { if err != nil {
vlogf("netmap: encode: %v", err) vlogf("netmap: encode: %v", err)
return err return err
@ -602,7 +614,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
machinePubKey := tailcfg.MachineKey(c.machinePrivKey.Public()) machinePubKey := tailcfg.MachineKey(machinePrivKey.Public())
t0 := time.Now() t0 := time.Now()
u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString()) u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString())
@ -695,7 +707,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond)) vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond))
var resp tailcfg.MapResponse var resp tailcfg.MapResponse
if err := c.decodeMsg(msg, &resp); err != nil { if err := c.decodeMsg(msg, &resp, &machinePrivKey); err != nil {
vlogf("netmap: decode error: %v") vlogf("netmap: decode error: %v")
return err return err
} }
@ -878,12 +890,12 @@ var debugMap, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_MAP"))
var jsonEscapedZero = []byte(`\u0000`) var jsonEscapedZero = []byte(`\u0000`)
func (c *Direct) decodeMsg(msg []byte, v interface{}) error { func (c *Direct) decodeMsg(msg []byte, v interface{}, machinePrivKey *wgkey.Private) error {
c.mu.Lock() c.mu.Lock()
serverKey := c.serverKey serverKey := c.serverKey
c.mu.Unlock() c.mu.Unlock()
decrypted, err := decryptMsg(msg, &serverKey, &c.machinePrivKey) decrypted, err := decryptMsg(msg, &serverKey, machinePrivKey)
if err != nil { if err != nil {
return err return err
} }
@ -917,8 +929,8 @@ func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
} }
func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, mkey *wgkey.Private) error { func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, machinePrivKey *wgkey.Private) error {
decrypted, err := decryptMsg(msg, serverKey, mkey) decrypted, err := decryptMsg(msg, serverKey, machinePrivKey)
if err != nil { if err != nil {
return err return err
} }

View File

@ -103,7 +103,13 @@ func TestNewDirect(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
opts := Options{ServerURL: "https://example.com", MachinePrivateKey: key, Hostinfo: hi} opts := Options{
ServerURL: "https://example.com",
Hostinfo: hi,
GetMachinePrivateKey: func() (wgkey.Private, error) {
return key, nil
},
}
c, err := NewDirect(opts) c, err := NewDirect(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -623,7 +623,12 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
persistv = &persist.Persist{} persistv = &persist.Persist{}
} }
cli, err := controlclient.New(controlclient.Options{ cli, err := controlclient.New(controlclient.Options{
MachinePrivateKey: machinePrivKey, GetMachinePrivateKey: func() (wgkey.Private, error) {
// TODO(bradfitz): finish pushing this laziness further; see
// https://github.com/tailscale/tailscale/issues/1573
// For now this is only lazy-ified in controlclient.
return machinePrivKey, nil
},
Logf: logger.WithPrefix(b.logf, "control: "), Logf: logger.WithPrefix(b.logf, "control: "),
Persist: *persistv, Persist: *persistv,
ServerURL: b.serverURL, ServerURL: b.serverURL,