diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 8becba10b..d3643f719 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -47,6 +47,7 @@ "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/util/clientmetric" + "tailscale.com/util/multierr" "tailscale.com/util/systemd" "tailscale.com/wgengine/monitor" ) @@ -68,8 +69,10 @@ type Direct struct { skipIPForwardingCheck bool pinger Pinger - mu sync.Mutex // mutex guards the following fields - serverKey key.MachinePublic + mu sync.Mutex // mutex guards the following fields + serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key + serverNoiseKey key.MachinePublic + persist persist.Persist authKey string tryingNewKey key.NodePrivate @@ -326,16 +329,17 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new c.logf("doLogin(regen=%v, hasUrl=%v)", regen, opt.URL != "") if serverKey.IsZero() { - var err error - serverKey, err = loadServerKey(ctx, c.httpc, c.serverURL) + keys, err := loadServerPubKeys(ctx, c.httpc, c.serverURL) if err != nil { return regen, opt.URL, err } c.logf("control server key %s from %s", serverKey.ShortString(), c.serverURL) c.mu.Lock() - c.serverKey = serverKey + c.serverKey = keys.LegacyPublicKey + c.serverNoiseKey = keys.PublicKey c.mu.Unlock() + serverKey = keys.LegacyPublicKey } var oldNodeKey key.NodePublic @@ -950,29 +954,39 @@ func encode(v interface{}, serverKey key.MachinePublic, mkey key.MachinePrivate) return mkey.SealTo(serverKey, b), nil } -func loadServerKey(ctx context.Context, httpc *http.Client, serverURL string) (key.MachinePublic, error) { - req, err := http.NewRequest("GET", serverURL+"/key", nil) +func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string) (*tailcfg.OverTLSPublicKeyResponse, error) { + keyURL := fmt.Sprintf("%v/key?v=%d", serverURL, tailcfg.CurrentCapabilityVersion) + req, err := http.NewRequestWithContext(ctx, "GET", keyURL, nil) if err != nil { - return key.MachinePublic{}, fmt.Errorf("create control key request: %v", err) + return nil, fmt.Errorf("create control key request: %v", err) } - req = req.WithContext(ctx) res, err := httpc.Do(req) if err != nil { - return key.MachinePublic{}, fmt.Errorf("fetch control key: %v", err) + return nil, fmt.Errorf("fetch control key: %v", err) } defer res.Body.Close() - b, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<16)) + b, err := ioutil.ReadAll(io.LimitReader(res.Body, 64<<10)) if err != nil { - return key.MachinePublic{}, fmt.Errorf("fetch control key response: %v", err) + return nil, fmt.Errorf("fetch control key response: %v", err) } if res.StatusCode != 200 { - return key.MachinePublic{}, fmt.Errorf("fetch control key: %d: %s", res.StatusCode, string(b)) + return nil, fmt.Errorf("fetch control key: %d", res.StatusCode) } + var out tailcfg.OverTLSPublicKeyResponse + jsonErr := json.Unmarshal(b, &out) + if jsonErr == nil { + return &out, nil + } + + // Some old control servers might not be updated to send the new format. + // Accept the old pre-JSON format too. + out = tailcfg.OverTLSPublicKeyResponse{} k, err := key.ParseMachinePublicUntyped(mem.B(b)) if err != nil { - return key.MachinePublic{}, fmt.Errorf("fetch control key: %v", err) + return nil, multierr.New(jsonErr, err) } - return k, nil + out.LegacyPublicKey = k + return &out, nil } // Debug contains temporary internal-only debug knobs. diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 4f481c5d0..5a7dd14cb 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -53,11 +53,15 @@ type Server struct { initMuxOnce sync.Once mux *http.ServeMux - mu sync.Mutex - inServeMap int - cond *sync.Cond // lazily initialized by condLocked - pubKey key.MachinePublic - privKey key.ControlPrivate // not strictly needed vs. MachinePrivate, but handy to test type interactions. + mu sync.Mutex + inServeMap int + cond *sync.Cond // lazily initialized by condLocked + pubKey key.MachinePublic + privKey key.ControlPrivate // not strictly needed vs. MachinePrivate, but handy to test type interactions. + + noisePubKey key.MachinePublic + noisePrivKey key.ControlPrivate // not strictly needed vs. MachinePrivate, but handy to test type interactions. + nodes map[key.NodePublic]*tailcfg.Node users map[key.NodePublic]*tailcfg.User logins map[key.NodePublic]*tailcfg.Login @@ -211,30 +215,42 @@ func (s *Server) serveUnhandled(w http.ResponseWriter, r *http.Request) { go panic(fmt.Sprintf("testcontrol.Server received unhandled request: %s", got.Bytes())) } -func (s *Server) publicKey() key.MachinePublic { - pub, _ := s.keyPair() - return pub +func (s *Server) publicKeys() (noiseKey, pubKey key.MachinePublic) { + s.mu.Lock() + defer s.mu.Unlock() + s.ensureKeyPairLocked() + return s.noisePubKey, s.pubKey } func (s *Server) privateKey() key.ControlPrivate { - _, priv := s.keyPair() - return priv -} - -func (s *Server) keyPair() (pub key.MachinePublic, priv key.ControlPrivate) { s.mu.Lock() defer s.mu.Unlock() - if s.pubKey.IsZero() { - s.privKey = key.NewControl() - s.pubKey = s.privKey.Public() + s.ensureKeyPairLocked() + return s.privKey +} + +func (s *Server) ensureKeyPairLocked() { + if !s.pubKey.IsZero() { + return } - return s.pubKey, s.privKey + s.noisePrivKey = key.NewControl() + s.noisePubKey = s.noisePrivKey.Public() + s.privKey = key.NewControl() + s.pubKey = s.privKey.Public() } func (s *Server) serveKey(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(200) - io.WriteString(w, s.publicKey().UntypedHexString()) + noiseKey, legacyKey := s.publicKeys() + if r.FormValue("v") == "" { + w.Header().Set("Content-Type", "text/plain") + io.WriteString(w, legacyKey.UntypedHexString()) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(&tailcfg.OverTLSPublicKeyResponse{ + LegacyPublicKey: legacyKey, + PublicKey: noiseKey, + }) } func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) { @@ -245,6 +261,7 @@ func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) { mkeyStr = mkeyStr[:i] } + // TODO(maisem/bradfitz): support noise protocol here. mkey, err := key.ParseMachinePublicUntyped(mem.S(mkeyStr)) if err != nil { http.Error(w, "bad machine key hex", 400)