mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-21 06:01:42 +00:00
control/controlclient: support lazy machine key generation
It's not done in the caller yet, but the controlclient does it now. Updates #1573 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
8d57bce5ef
commit
a998fe7c3d
@ -63,7 +63,7 @@ type Direct struct {
|
|||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
linkMon *monitor.Mon // or nil
|
linkMon *monitor.Mon // or nil
|
||||||
discoPubKey tailcfg.DiscoKey
|
discoPubKey tailcfg.DiscoKey
|
||||||
machinePrivKey wgkey.Private
|
getMachinePrivKey func() (wgkey.Private, error)
|
||||||
debugFlags []string
|
debugFlags []string
|
||||||
keepSharerAndUserSplit bool
|
keepSharerAndUserSplit bool
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ type Direct struct {
|
|||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
Persist persist.Persist // initial persistent data
|
Persist persist.Persist // initial persistent data
|
||||||
MachinePrivateKey wgkey.Private // the machine key to use
|
GetMachinePrivateKey func() (wgkey.Private, error) // returns the machine key to use
|
||||||
ServerURL string // URL of the tailcontrol server
|
ServerURL string // URL of the tailcontrol server
|
||||||
AuthKey string // optional node auth key for auto registration
|
AuthKey string // optional node auth key for auto registration
|
||||||
TimeNow func() time.Time // time.Now implementation used by Client
|
TimeNow func() time.Time // time.Now implementation used by Client
|
||||||
@ -110,8 +110,8 @@ func NewDirect(opts Options) (*Direct, error) {
|
|||||||
if opts.ServerURL == "" {
|
if opts.ServerURL == "" {
|
||||||
return nil, errors.New("controlclient.New: no server URL specified")
|
return nil, errors.New("controlclient.New: no server URL specified")
|
||||||
}
|
}
|
||||||
if opts.MachinePrivateKey.IsZero() {
|
if opts.GetMachinePrivateKey == nil {
|
||||||
return nil, errors.New("controlclient.New: no MachinePrivateKey specified")
|
return nil, errors.New("controlclient.New: no GetMachinePrivateKey specified")
|
||||||
}
|
}
|
||||||
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
|
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
|
||||||
serverURL, err := url.Parse(opts.ServerURL)
|
serverURL, err := url.Parse(opts.ServerURL)
|
||||||
@ -147,7 +147,7 @@ func NewDirect(opts Options) (*Direct, error) {
|
|||||||
|
|
||||||
c := &Direct{
|
c := &Direct{
|
||||||
httpc: httpc,
|
httpc: httpc,
|
||||||
machinePrivKey: opts.MachinePrivateKey,
|
getMachinePrivKey: opts.GetMachinePrivateKey,
|
||||||
serverURL: opts.ServerURL,
|
serverURL: opts.ServerURL,
|
||||||
timeNow: opts.TimeNow,
|
timeNow: opts.TimeNow,
|
||||||
logf: opts.Logf,
|
logf: opts.Logf,
|
||||||
@ -301,8 +301,12 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
|
|||||||
expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
|
expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if c.machinePrivKey.IsZero() {
|
machinePrivKey, err := c.getMachinePrivKey()
|
||||||
return false, "", errors.New("controlclient.Direct requires a machine private key")
|
if err != nil {
|
||||||
|
return false, "", fmt.Errorf("getMachinePrivKey: %w", err)
|
||||||
|
}
|
||||||
|
if machinePrivKey.IsZero() {
|
||||||
|
return false, "", errors.New("getMachinePrivKey returned zero key")
|
||||||
}
|
}
|
||||||
|
|
||||||
if expired {
|
if expired {
|
||||||
@ -370,7 +374,7 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
|
|||||||
request.Auth.Provider = persist.Provider
|
request.Auth.Provider = persist.Provider
|
||||||
request.Auth.LoginName = persist.LoginName
|
request.Auth.LoginName = persist.LoginName
|
||||||
request.Auth.AuthKey = authKey
|
request.Auth.AuthKey = authKey
|
||||||
err = signRegisterRequest(&request, c.serverURL, c.serverKey, c.machinePrivKey.Public())
|
err = signRegisterRequest(&request, c.serverURL, c.serverKey, machinePrivKey.Public())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If signing failed, clear all related fields
|
// If signing failed, clear all related fields
|
||||||
request.SignatureType = tailcfg.SignatureNone
|
request.SignatureType = tailcfg.SignatureNone
|
||||||
@ -384,13 +388,13 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
|
|||||||
c.logf("RegisterReq sign error: %v", err)
|
c.logf("RegisterReq sign error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bodyData, err := encode(request, &serverKey, &c.machinePrivKey)
|
bodyData, err := encode(request, &serverKey, &machinePrivKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return regen, url, err
|
return regen, url, err
|
||||||
}
|
}
|
||||||
body := bytes.NewReader(bodyData)
|
body := bytes.NewReader(bodyData)
|
||||||
|
|
||||||
u := fmt.Sprintf("%s/machine/%s", c.serverURL, c.machinePrivKey.Public().HexString())
|
u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().HexString())
|
||||||
req, err := http.NewRequest("POST", u, body)
|
req, err := http.NewRequest("POST", u, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return regen, url, err
|
return regen, url, err
|
||||||
@ -408,8 +412,8 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
|
|||||||
res.StatusCode, strings.TrimSpace(string(msg)))
|
res.StatusCode, strings.TrimSpace(string(msg)))
|
||||||
}
|
}
|
||||||
resp := tailcfg.RegisterResponse{}
|
resp := tailcfg.RegisterResponse{}
|
||||||
if err := decode(res, &resp, &serverKey, &c.machinePrivKey); err != nil {
|
if err := decode(res, &resp, &serverKey, &machinePrivKey); err != nil {
|
||||||
c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, c.machinePrivKey.Public(), err)
|
c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err)
|
||||||
return regen, url, fmt.Errorf("register request: %v", err)
|
return regen, url, fmt.Errorf("register request: %v", err)
|
||||||
}
|
}
|
||||||
// Log without PII:
|
// Log without PII:
|
||||||
@ -536,6 +540,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
|
|||||||
everEndpoints := c.everEndpoints
|
everEndpoints := c.everEndpoints
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
machinePrivKey, err := c.getMachinePrivKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getMachinePrivKey: %w", err)
|
||||||
|
}
|
||||||
|
if machinePrivKey.IsZero() {
|
||||||
|
return errors.New("getMachinePrivKey returned zero key")
|
||||||
|
}
|
||||||
|
|
||||||
if persist.PrivateNodeKey.IsZero() {
|
if persist.PrivateNodeKey.IsZero() {
|
||||||
return errors.New("privateNodeKey is zero")
|
return errors.New("privateNodeKey is zero")
|
||||||
}
|
}
|
||||||
@ -593,7 +605,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
|
|||||||
request.ReadOnly = true
|
request.ReadOnly = true
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyData, err := encode(request, &serverKey, &c.machinePrivKey)
|
bodyData, err := encode(request, &serverKey, &machinePrivKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
vlogf("netmap: encode: %v", err)
|
vlogf("netmap: encode: %v", err)
|
||||||
return err
|
return err
|
||||||
@ -602,7 +614,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
|
|||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
machinePubKey := tailcfg.MachineKey(c.machinePrivKey.Public())
|
machinePubKey := tailcfg.MachineKey(machinePrivKey.Public())
|
||||||
t0 := time.Now()
|
t0 := time.Now()
|
||||||
u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString())
|
u := fmt.Sprintf("%s/machine/%s/map", serverURL, machinePubKey.HexString())
|
||||||
|
|
||||||
@ -695,7 +707,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
|
|||||||
vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond))
|
vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond))
|
||||||
|
|
||||||
var resp tailcfg.MapResponse
|
var resp tailcfg.MapResponse
|
||||||
if err := c.decodeMsg(msg, &resp); err != nil {
|
if err := c.decodeMsg(msg, &resp, &machinePrivKey); err != nil {
|
||||||
vlogf("netmap: decode error: %v")
|
vlogf("netmap: decode error: %v")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -878,12 +890,12 @@ var debugMap, _ = strconv.ParseBool(os.Getenv("TS_DEBUG_MAP"))
|
|||||||
|
|
||||||
var jsonEscapedZero = []byte(`\u0000`)
|
var jsonEscapedZero = []byte(`\u0000`)
|
||||||
|
|
||||||
func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
|
func (c *Direct) decodeMsg(msg []byte, v interface{}, machinePrivKey *wgkey.Private) error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
serverKey := c.serverKey
|
serverKey := c.serverKey
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
decrypted, err := decryptMsg(msg, &serverKey, &c.machinePrivKey)
|
decrypted, err := decryptMsg(msg, &serverKey, machinePrivKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -917,8 +929,8 @@ func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, mkey *wgkey.Private) error {
|
func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, machinePrivKey *wgkey.Private) error {
|
||||||
decrypted, err := decryptMsg(msg, serverKey, mkey)
|
decrypted, err := decryptMsg(msg, serverKey, machinePrivKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -103,7 +103,13 @@ func TestNewDirect(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
opts := Options{ServerURL: "https://example.com", MachinePrivateKey: key, Hostinfo: hi}
|
opts := Options{
|
||||||
|
ServerURL: "https://example.com",
|
||||||
|
Hostinfo: hi,
|
||||||
|
GetMachinePrivateKey: func() (wgkey.Private, error) {
|
||||||
|
return key, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
c, err := NewDirect(opts)
|
c, err := NewDirect(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -623,7 +623,12 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
|
|||||||
persistv = &persist.Persist{}
|
persistv = &persist.Persist{}
|
||||||
}
|
}
|
||||||
cli, err := controlclient.New(controlclient.Options{
|
cli, err := controlclient.New(controlclient.Options{
|
||||||
MachinePrivateKey: machinePrivKey,
|
GetMachinePrivateKey: func() (wgkey.Private, error) {
|
||||||
|
// TODO(bradfitz): finish pushing this laziness further; see
|
||||||
|
// https://github.com/tailscale/tailscale/issues/1573
|
||||||
|
// For now this is only lazy-ified in controlclient.
|
||||||
|
return machinePrivKey, nil
|
||||||
|
},
|
||||||
Logf: logger.WithPrefix(b.logf, "control: "),
|
Logf: logger.WithPrefix(b.logf, "control: "),
|
||||||
Persist: *persistv,
|
Persist: *persistv,
|
||||||
ServerURL: b.serverURL,
|
ServerURL: b.serverURL,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user