ipn/{ipnlocal,localapi}: actually renew certs before expiry (#8731)

While our `shouldStartDomainRenewal` check is correct, `getCertPEM`
would always bail if the existing cert is not expired. Add the same
`shouldStartDomainRenewal` check to `getCertPEM` to make it proceed with
renewal when existing certs are still valid but should be renewed.

The extra check is expensive (ARI request towards LetsEncrypt), so cache
the last check result for 1hr to not degrade `tailscale serve`
performance.

Also, asynchronous renewal is great for `tailscale serve` but confusing
for `tailscale cert`. Add an explicit flag to `GetCertPEM` to force a
synchronous renewal for `tailscale cert`.

Fixes #8725

Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
This commit is contained in:
Andrew Lytvynov 2023-07-27 12:29:40 -07:00 committed by GitHub
parent aa37be70cf
commit c1ecae13ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 40 deletions

View File

@ -53,8 +53,8 @@
// populate the on-disk cache and the rest should use that. // populate the on-disk cache and the rest should use that.
acmeMu sync.Mutex acmeMu sync.Mutex
renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time renewMu sync.Mutex // lock order: acmeMu before renewMu
lastRenewCheck = map[string]time.Time{} renewCertAt = map[string]time.Time{}
) )
// certDir returns (creating if needed) the directory in which cached // certDir returns (creating if needed) the directory in which cached
@ -80,9 +80,15 @@ func (b *LocalBackend) certDir() (string, error) {
var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME") var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME")
// getCertPEM gets the KeyPair for domain, either from cache, via the ACME // GetCertPEM gets the TLSCertKeyPair for domain, either from cache or via the
// process, or from cache and kicking off an async ACME renewal. // ACME process. ACME process is used for new domain certs, existing expired
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { // certs or existing certs that should get renewed due to upcoming expiry.
//
// syncRenewal changes renewal behavior for existing certs that are still valid
// but need renewal. When syncRenewal is set, the method blocks until a new
// cert is issued. When syncRenewal is not set, existing cert is returned right
// away and renewal is kicked off in a background goroutine.
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
if !validLookingCertDomain(domain) { if !validLookingCertDomain(domain) {
return nil, errors.New("invalid domain") return nil, errors.New("invalid domain")
} }
@ -105,12 +111,15 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair) shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair)
if err != nil { if err != nil {
logf("error checking for certificate renewal: %v", err) logf("error checking for certificate renewal: %v", err)
} else if shouldRenew { } else if !shouldRenew {
return pair, nil
}
if !syncRenewal {
logf("starting async renewal") logf("starting async renewal")
// Start renewal in the background. // Start renewal in the background.
go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now)
} }
return pair, nil // Synchronous renewal happens below.
} }
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now)
@ -124,37 +133,43 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) {
renewMu.Lock() renewMu.Lock()
defer renewMu.Unlock() defer renewMu.Unlock()
if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute { if renewAt, ok := renewCertAt[domain]; ok {
// We checked very recently. Don't bother reparsing & return now.After(renewAt), nil
// validating the x509 cert.
return false, nil
} }
lastRenewCheck[domain] = now
renew, err := b.shouldStartDomainRenewalByARI(cs, now, pair) renewTime, err := b.domainRenewalTimeByARI(cs, pair)
if err != nil { if err != nil {
// Log any ARI failure and fall back to checking for renewal by expiry. // Log any ARI failure and fall back to checking for renewal by expiry.
b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err) b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err)
} else { renewTime, err = b.domainRenewalTimeByExpiry(pair)
return renew, nil if err != nil {
return false, err
}
} }
return b.shouldStartDomainRenewalByExpiry(now, pair) renewCertAt[domain] = renewTime
return now.After(renewTime), nil
} }
func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLSCertKeyPair) (bool, error) { func (b *LocalBackend) domainRenewed(domain string) {
renewMu.Lock()
defer renewMu.Unlock()
delete(renewCertAt, domain)
}
func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) {
block, _ := pem.Decode(pair.CertPEM) block, _ := pem.Decode(pair.CertPEM)
if block == nil { if block == nil {
return false, fmt.Errorf("parsing certificate PEM") return time.Time{}, fmt.Errorf("parsing certificate PEM")
} }
cert, err := x509.ParseCertificate(block.Bytes) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
return false, fmt.Errorf("parsing certificate: %w", err) return time.Time{}, fmt.Errorf("parsing certificate: %w", err)
} }
certLifetime := cert.NotAfter.Sub(cert.NotBefore) certLifetime := cert.NotAfter.Sub(cert.NotBefore)
if certLifetime < 0 { if certLifetime < 0 {
return false, fmt.Errorf("negative certificate lifetime %v", certLifetime) return time.Time{}, fmt.Errorf("negative certificate lifetime %v", certLifetime)
} }
// Per https://github.com/tailscale/tailscale/issues/8204, check // Per https://github.com/tailscale/tailscale/issues/8204, check
@ -163,36 +178,32 @@ func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLS
// Encrypt. // Encrypt.
renewalDuration := certLifetime * 2 / 3 renewalDuration := certLifetime * 2 / 3
renewAt := cert.NotBefore.Add(renewalDuration) renewAt := cert.NotBefore.Add(renewalDuration)
return renewAt, nil
if now.After(renewAt) {
return true, nil
}
return false, nil
} }
func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time, pair *TLSCertKeyPair) (bool, error) { func (b *LocalBackend) domainRenewalTimeByARI(cs certStore, pair *TLSCertKeyPair) (time.Time, error) {
var blocks []*pem.Block var blocks []*pem.Block
rest := pair.CertPEM rest := pair.CertPEM
for len(rest) > 0 { for len(rest) > 0 {
var block *pem.Block var block *pem.Block
block, rest = pem.Decode(rest) block, rest = pem.Decode(rest)
if block == nil { if block == nil {
return false, fmt.Errorf("parsing certificate PEM") return time.Time{}, fmt.Errorf("parsing certificate PEM")
} }
blocks = append(blocks, block) blocks = append(blocks, block)
} }
if len(blocks) < 2 { if len(blocks) < 2 {
return false, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks)) return time.Time{}, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks))
} }
ac, err := acmeClient(cs) ac, err := acmeClient(cs)
if err != nil { if err != nil {
return false, err return time.Time{}, err
} }
ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second)
defer cancel() defer cancel()
ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes) ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err) return time.Time{}, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err)
} }
if acmeDebug() { if acmeDebug() {
b.logf("acme: ARI response: %+v", ri) b.logf("acme: ARI response: %+v", ri)
@ -203,7 +214,7 @@ func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time
// https://datatracker.ietf.org/doc/draft-ietf-acme-ari/ // https://datatracker.ietf.org/doc/draft-ietf-acme-ari/
start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End
renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start))))) renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start)))))
return now.After(renewTime), nil return renewTime, nil
} }
// certStore provides a way to perist and retrieve TLS certificates. // certStore provides a way to perist and retrieve TLS certificates.
@ -371,8 +382,18 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
acmeMu.Lock() acmeMu.Lock()
defer acmeMu.Unlock() defer acmeMu.Unlock()
// In case this method was triggered multiple times in parallel (when
// serving incoming requests), check whether one of the other goroutines
// already renewed the cert before us.
if p, err := getCertPEMCached(cs, domain, now); err == nil { if p, err := getCertPEMCached(cs, domain, now); err == nil {
return p, nil // shouldStartDomainRenewal caches its result so it's OK to call this
// frequently.
shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p)
if err != nil {
logf("error checking for certificate renewal: %v", err)
} else if !shouldRenew {
return p, nil
}
} else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) { } else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) {
return nil, err return nil, err
} }
@ -509,6 +530,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger
if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil { if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil {
return nil, err return nil, err
} }
b.domainRenewed(domain)
return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil
} }

View File

@ -12,6 +12,6 @@ type TLSCertKeyPair struct {
CertPEM, KeyPEM []byte CertPEM, KeyPEM []byte
} }
func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) {
return nil, errors.New("not implemented for js/wasm") return nil, errors.New("not implemented for js/wasm")
} }

View File

@ -112,7 +112,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
reset := func() { reset := func() {
renewMu.Lock() renewMu.Lock()
defer renewMu.Unlock() defer renewMu.Unlock()
maps.Clear(lastRenewCheck) maps.Clear(renewCertAt)
} }
mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair { mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair {
@ -178,7 +178,7 @@ func TestShouldStartDomainRenewal(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
reset() reset()
ret, err := b.shouldStartDomainRenewalByExpiry(now, mustMakePair(&x509.Certificate{ ret, err := b.domainRenewalTimeByExpiry(mustMakePair(&x509.Certificate{
SerialNumber: big.NewInt(2019), SerialNumber: big.NewInt(2019),
Subject: subject, Subject: subject,
NotBefore: tt.notBefore, NotBefore: tt.notBefore,
@ -192,8 +192,9 @@ func TestShouldStartDomainRenewal(t *testing.T) {
t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr) t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr)
} }
} else { } else {
if ret != tt.want { renew := now.After(ret)
t.Errorf("got ret=%v, want %v", ret, tt.want) if renew != tt.want {
t.Errorf("got renew=%v (ret=%v), want renew %v", renew, ret, tt.want)
} }
} }
}) })

View File

@ -372,7 +372,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort)
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
pair, err := b.GetCertPEM(ctx, sni) pair, err := b.GetCertPEM(ctx, sni, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -675,7 +675,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
pair, err := b.GetCertPEM(ctx, hi.ServerName) pair, err := b.GetCertPEM(ctx, hi.ServerName, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "internal handler config wired wrong", 500) http.Error(w, "internal handler config wired wrong", 500)
return return
} }
pair, err := h.b.GetCertPEM(r.Context(), domain) pair, err := h.b.GetCertPEM(r.Context(), domain, true)
if err != nil { if err != nil {
// TODO(bradfitz): 500 is a little lazy here. The errors returned from // TODO(bradfitz): 500 is a little lazy here. The errors returned from
// GetCertPEM (and everywhere) should carry info info to get whether // GetCertPEM (and everywhere) should carry info info to get whether