From c1ecae13ab708cef90905085f87729974f6c339d Mon Sep 17 00:00:00 2001 From: Andrew Lytvynov Date: Thu, 27 Jul 2023 12:29:40 -0700 Subject: [PATCH] ipn/{ipnlocal,localapi}: actually renew certs before expiry (#8731) While our `shouldStartDomainRenewal` check is correct, `getCertPEM` would always bail if the existing cert is not expired. Add the same `shouldStartDomainRenewal` check to `getCertPEM` to make it proceed with renewal when existing certs are still valid but should be renewed. The extra check is expensive (ARI request towards LetsEncrypt), so cache the computed renewal time for each domain until its cert is renewed, to not degrade `tailscale serve` performance. Also, asynchronous renewal is great for `tailscale serve` but confusing for `tailscale cert`. Add an explicit flag to `GetCertPEM` to force a synchronous renewal for `tailscale cert`. Fixes #8725 Signed-off-by: Andrew Lytvynov --- ipn/ipnlocal/cert.go | 86 ++++++++++++++++++++++++--------------- ipn/ipnlocal/cert_js.go | 2 +- ipn/ipnlocal/cert_test.go | 9 ++-- ipn/ipnlocal/serve.go | 4 +- ipn/localapi/cert.go | 2 +- 5 files changed, 63 insertions(+), 40 deletions(-) diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index f5384276c..627cc7872 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -53,8 +53,8 @@ // populate the on-disk cache and the rest should use that. acmeMu sync.Mutex - renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time - lastRenewCheck = map[string]time.Time{} + renewMu sync.Mutex // lock order: acmeMu before renewMu + renewCertAt = map[string]time.Time{} ) // certDir returns (creating if needed) the directory in which cached @@ -80,9 +80,15 @@ func (b *LocalBackend) certDir() (string, error) { var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME") -// getCertPEM gets the KeyPair for domain, either from cache, via the ACME -// process, or from cache and kicking off an async ACME renewal. 
-func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { +// GetCertPEM gets the TLSCertKeyPair for domain, either from cache or via the +// ACME process. ACME process is used for new domain certs, existing expired +// certs or existing certs that should get renewed due to upcoming expiry. +// +// syncRenewal changes renewal behavior for existing certs that are still valid +// but need renewal. When syncRenewal is set, the method blocks until a new +// cert is issued. When syncRenewal is not set, existing cert is returned right +// away and renewal is kicked off in a background goroutine. +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } @@ -105,12 +111,15 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair) if err != nil { logf("error checking for certificate renewal: %v", err) - } else if shouldRenew { + } else if !shouldRenew { + return pair, nil + } + if !syncRenewal { logf("starting async renewal") // Start renewal in the background. go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) } - return pair, nil + // Synchronous renewal happens below. } pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) @@ -124,37 +133,43 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { renewMu.Lock() defer renewMu.Unlock() - if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute { - // We checked very recently. Don't bother reparsing & - // validating the x509 cert. 
- return false, nil + if renewAt, ok := renewCertAt[domain]; ok { + return now.After(renewAt), nil } - lastRenewCheck[domain] = now - renew, err := b.shouldStartDomainRenewalByARI(cs, now, pair) + renewTime, err := b.domainRenewalTimeByARI(cs, pair) if err != nil { // Log any ARI failure and fall back to checking for renewal by expiry. b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err) - } else { - return renew, nil + renewTime, err = b.domainRenewalTimeByExpiry(pair) + if err != nil { + return false, err + } } - return b.shouldStartDomainRenewalByExpiry(now, pair) + renewCertAt[domain] = renewTime + return now.After(renewTime), nil } -func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLSCertKeyPair) (bool, error) { +func (b *LocalBackend) domainRenewed(domain string) { + renewMu.Lock() + defer renewMu.Unlock() + delete(renewCertAt, domain) +} + +func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) { block, _ := pem.Decode(pair.CertPEM) if block == nil { - return false, fmt.Errorf("parsing certificate PEM") + return time.Time{}, fmt.Errorf("parsing certificate PEM") } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { - return false, fmt.Errorf("parsing certificate: %w", err) + return time.Time{}, fmt.Errorf("parsing certificate: %w", err) } certLifetime := cert.NotAfter.Sub(cert.NotBefore) if certLifetime < 0 { - return false, fmt.Errorf("negative certificate lifetime %v", certLifetime) + return time.Time{}, fmt.Errorf("negative certificate lifetime %v", certLifetime) } // Per https://github.com/tailscale/tailscale/issues/8204, check @@ -163,36 +178,32 @@ func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLS // Encrypt. 
renewalDuration := certLifetime * 2 / 3 renewAt := cert.NotBefore.Add(renewalDuration) - - if now.After(renewAt) { - return true, nil - } - return false, nil + return renewAt, nil } -func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time, pair *TLSCertKeyPair) (bool, error) { +func (b *LocalBackend) domainRenewalTimeByARI(cs certStore, pair *TLSCertKeyPair) (time.Time, error) { var blocks []*pem.Block rest := pair.CertPEM for len(rest) > 0 { var block *pem.Block block, rest = pem.Decode(rest) if block == nil { - return false, fmt.Errorf("parsing certificate PEM") + return time.Time{}, fmt.Errorf("parsing certificate PEM") } blocks = append(blocks, block) } if len(blocks) < 2 { - return false, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks)) + return time.Time{}, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks)) } ac, err := acmeClient(cs) if err != nil { - return false, err + return time.Time{}, err } ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second) defer cancel() ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes) if err != nil { - return false, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err) + return time.Time{}, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err) } if acmeDebug() { b.logf("acme: ARI response: %+v", ri) @@ -203,7 +214,7 @@ func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time // https://datatracker.ietf.org/doc/draft-ietf-acme-ari/ start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start))))) - return now.After(renewTime), nil + return renewTime, nil } // certStore provides a way to perist and retrieve TLS certificates. 
@@ -371,8 +382,18 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger acmeMu.Lock() defer acmeMu.Unlock() + // In case this method was triggered multiple times in parallel (when + // serving incoming requests), check whether one of the other goroutines + // already renewed the cert before us. if p, err := getCertPEMCached(cs, domain, now); err == nil { - return p, nil + // shouldStartDomainRenewal caches its result so it's OK to call this + // frequently. + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p) + if err != nil { + logf("error checking for certificate renewal: %v", err) + } else if !shouldRenew { + return p, nil + } } else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) { return nil, err } @@ -509,6 +530,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil { return nil, err } + b.domainRenewed(domain) return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil } diff --git a/ipn/ipnlocal/cert_js.go b/ipn/ipnlocal/cert_js.go index a5fdfc4ba..24defb47b 100644 --- a/ipn/ipnlocal/cert_js.go +++ b/ipn/ipnlocal/cert_js.go @@ -12,6 +12,6 @@ type TLSCertKeyPair struct { CertPEM, KeyPEM []byte } -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { return nil, errors.New("not implemented for js/wasm") } diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index 52ba13453..66d942032 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -112,7 +112,7 @@ func TestShouldStartDomainRenewal(t *testing.T) { reset := func() { renewMu.Lock() defer renewMu.Unlock() - maps.Clear(lastRenewCheck) + maps.Clear(renewCertAt) } mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair { @@ -178,7 
+178,7 @@ func TestShouldStartDomainRenewal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { reset() - ret, err := b.shouldStartDomainRenewalByExpiry(now, mustMakePair(&x509.Certificate{ + ret, err := b.domainRenewalTimeByExpiry(mustMakePair(&x509.Certificate{ SerialNumber: big.NewInt(2019), Subject: subject, NotBefore: tt.notBefore, @@ -192,8 +192,9 @@ func TestShouldStartDomainRenewal(t *testing.T) { t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr) } } else { - if ret != tt.want { - t.Errorf("got ret=%v, want %v", ret, tt.want) + renew := now.After(ret) + if renew != tt.want { + t.Errorf("got renew=%v (ret=%v), want renew %v", renew, ret, tt.want) } } }) diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index aa2c1a605..99330309b 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -372,7 +372,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, sni) + pair, err := b.GetCertPEM(ctx, sni, false) if err != nil { return nil, err } @@ -675,7 +675,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, hi.ServerName) + pair, err := b.GetCertPEM(ctx, hi.ServerName, false) if err != nil { return nil, err } diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index 447c3bc3c..e1704cb49 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal handler config wired wrong", 500) return } - pair, err := h.b.GetCertPEM(r.Context(), domain) + pair, err := h.b.GetCertPEM(r.Context(), domain, true) if err != nil { // TODO(bradfitz): 500 is a little lazy here. 
The errors returned from // GetCertPEM (and everywhere) should carry info info to get whether