control/controlclient: refactor some internals

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2021-04-07 21:06:31 -07:00
parent 71432c6449
commit 597c19ff4e

View File

@ -274,7 +274,7 @@ func (c *Direct) TryLogout(ctx context.Context) error {
func (c *Direct) TryLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags LoginFlags) (url string, err error) { func (c *Direct) TryLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags LoginFlags) (url string, err error) {
c.logf("direct.TryLogin(token=%v, flags=%v)", t != nil, flags) c.logf("direct.TryLogin(token=%v, flags=%v)", t != nil, flags)
return c.doLoginOrRegen(ctx, t, flags, false, "") return c.doLoginOrRegen(ctx, loginOpt{Token: t, Flags: flags})
} }
// WaitLoginURL sits in a long poll waiting for the user to authenticate at url. // WaitLoginURL sits in a long poll waiting for the user to authenticate at url.
@ -282,22 +282,29 @@ func (c *Direct) TryLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Log
// On success, newURL and err will both be nil. // On success, newURL and err will both be nil.
func (c *Direct) WaitLoginURL(ctx context.Context, url string) (newURL string, err error) { func (c *Direct) WaitLoginURL(ctx context.Context, url string) (newURL string, err error) {
c.logf("direct.WaitLoginURL") c.logf("direct.WaitLoginURL")
return c.doLoginOrRegen(ctx, nil, LoginDefault, false, url) return c.doLoginOrRegen(ctx, loginOpt{URL: url})
} }
func (c *Direct) doLoginOrRegen(ctx context.Context, t *tailcfg.Oauth2Token, flags LoginFlags, regen bool, url string) (newURL string, err error) { func (c *Direct) doLoginOrRegen(ctx context.Context, opt loginOpt) (newURL string, err error) {
mustregen, url, err := c.doLogin(ctx, t, flags, regen, url) mustRegen, url, err := c.doLogin(ctx, opt)
if err != nil { if err != nil {
return url, err return url, err
} }
if mustregen { if mustRegen {
_, url, err = c.doLogin(ctx, t, flags, true, url) opt.Regen = true
_, url, err = c.doLogin(ctx, opt)
} }
return url, err return url, err
} }
func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags LoginFlags, regen bool, url string) (mustregen bool, newurl string, err error) { type loginOpt struct {
Token *tailcfg.Oauth2Token
Flags LoginFlags
Regen bool
URL string
}
func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, newURL string, err error) {
c.mu.Lock() c.mu.Lock()
persist := c.persist persist := c.persist
tryingNewKey := c.tryingNewKey tryingNewKey := c.tryingNewKey
@ -316,22 +323,23 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
return false, "", errors.New("getMachinePrivKey returned zero key") return false, "", errors.New("getMachinePrivKey returned zero key")
} }
regen := opt.Regen
if expired { if expired {
c.logf("Old key expired -> regen=true") c.logf("Old key expired -> regen=true")
systemd.Status("key expired; run 'tailscale up' to authenticate") systemd.Status("key expired; run 'tailscale up' to authenticate")
regen = true regen = true
} }
if (flags & LoginInteractive) != 0 { if (opt.Flags & LoginInteractive) != 0 {
c.logf("LoginInteractive -> regen=true") c.logf("LoginInteractive -> regen=true")
regen = true regen = true
} }
c.logf("doLogin(regen=%v, hasUrl=%v)", regen, url != "") c.logf("doLogin(regen=%v, hasUrl=%v)", regen, opt.URL != "")
if serverKey.IsZero() { if serverKey.IsZero() {
var err error var err error
serverKey, err = loadServerKey(ctx, c.httpc, c.serverURL) serverKey, err = loadServerKey(ctx, c.httpc, c.serverURL)
if err != nil { if err != nil {
return regen, url, err return regen, opt.URL, err
} }
c.mu.Lock() c.mu.Lock()
@ -340,14 +348,14 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
} }
var oldNodeKey wgkey.Key var oldNodeKey wgkey.Key
if url != "" { if opt.URL != "" {
} else if regen || persist.PrivateNodeKey.IsZero() { } else if regen || persist.PrivateNodeKey.IsZero() {
c.logf("Generating a new nodekey.") c.logf("Generating a new nodekey.")
persist.OldPrivateNodeKey = persist.PrivateNodeKey persist.OldPrivateNodeKey = persist.PrivateNodeKey
key, err := wgkey.NewPrivate() key, err := wgkey.NewPrivate()
if err != nil { if err != nil {
c.logf("login keygen: %v", err) c.logf("login keygen: %v", err)
return regen, url, err return regen, opt.URL, err
} }
tryingNewKey = key tryingNewKey = key
} else { } else {
@ -363,7 +371,7 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
} }
if backendLogID == "" { if backendLogID == "" {
err = errors.New("hostinfo: BackendLogID missing") err = errors.New("hostinfo: BackendLogID missing")
return regen, url, err return regen, opt.URL, err
} }
now := time.Now().Round(time.Second) now := time.Now().Round(time.Second)
request := tailcfg.RegisterRequest{ request := tailcfg.RegisterRequest{
@ -371,13 +379,13 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
OldNodeKey: tailcfg.NodeKey(oldNodeKey), OldNodeKey: tailcfg.NodeKey(oldNodeKey),
NodeKey: tailcfg.NodeKey(tryingNewKey.Public()), NodeKey: tailcfg.NodeKey(tryingNewKey.Public()),
Hostinfo: hostinfo, Hostinfo: hostinfo,
Followup: url, Followup: opt.URL,
Timestamp: &now, Timestamp: &now,
} }
c.logf("RegisterReq: onode=%v node=%v fup=%v", c.logf("RegisterReq: onode=%v node=%v fup=%v",
request.OldNodeKey.ShortString(), request.OldNodeKey.ShortString(),
request.NodeKey.ShortString(), url != "") request.NodeKey.ShortString(), opt.URL != "")
request.Auth.Oauth2Token = t request.Auth.Oauth2Token = opt.Token
request.Auth.Provider = persist.Provider request.Auth.Provider = persist.Provider
request.Auth.LoginName = persist.LoginName request.Auth.LoginName = persist.LoginName
request.Auth.AuthKey = authKey request.Auth.AuthKey = authKey
@ -397,31 +405,31 @@ func (c *Direct) doLogin(ctx context.Context, t *tailcfg.Oauth2Token, flags Logi
} }
bodyData, err := encode(request, &serverKey, &machinePrivKey) bodyData, err := encode(request, &serverKey, &machinePrivKey)
if err != nil { if err != nil {
return regen, url, err return regen, opt.URL, err
} }
body := bytes.NewReader(bodyData) body := bytes.NewReader(bodyData)
u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().HexString()) u := fmt.Sprintf("%s/machine/%s", c.serverURL, machinePrivKey.Public().HexString())
req, err := http.NewRequest("POST", u, body) req, err := http.NewRequest("POST", u, body)
if err != nil { if err != nil {
return regen, url, err return regen, opt.URL, err
} }
req = req.WithContext(ctx) req = req.WithContext(ctx)
res, err := c.httpc.Do(req) res, err := c.httpc.Do(req)
if err != nil { if err != nil {
return regen, url, fmt.Errorf("register request: %v", err) return regen, opt.URL, fmt.Errorf("register request: %v", err)
} }
if res.StatusCode != 200 { if res.StatusCode != 200 {
msg, _ := ioutil.ReadAll(res.Body) msg, _ := ioutil.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
return regen, url, fmt.Errorf("register request: http %d: %.200s", return regen, opt.URL, fmt.Errorf("register request: http %d: %.200s",
res.StatusCode, strings.TrimSpace(string(msg))) res.StatusCode, strings.TrimSpace(string(msg)))
} }
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
if err := decode(res, &resp, &serverKey, &machinePrivKey); err != nil { if err := decode(res, &resp, &serverKey, &machinePrivKey); err != nil {
c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err) c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err)
return regen, url, fmt.Errorf("register request: %v", err) return regen, opt.URL, fmt.Errorf("register request: %v", err)
} }
// Log without PII: // Log without PII:
c.logf("RegisterReq: got response; nodeKeyExpired=%v, machineAuthorized=%v; authURL=%v", c.logf("RegisterReq: got response; nodeKeyExpired=%v, machineAuthorized=%v; authURL=%v",