diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index f5384276c..627cc7872 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -53,8 +53,8 @@ // populate the on-disk cache and the rest should use that. acmeMu sync.Mutex - renewMu sync.Mutex // lock order: don't hold acmeMu and renewMu at the same time - lastRenewCheck = map[string]time.Time{} + renewMu sync.Mutex // lock order: acmeMu before renewMu + renewCertAt = map[string]time.Time{} ) // certDir returns (creating if needed) the directory in which cached @@ -80,9 +80,15 @@ func (b *LocalBackend) certDir() (string, error) { var acmeDebug = envknob.RegisterBool("TS_DEBUG_ACME") -// getCertPEM gets the KeyPair for domain, either from cache, via the ACME -// process, or from cache and kicking off an async ACME renewal. -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { +// GetCertPEM gets the TLSCertKeyPair for domain, either from cache or via the +// ACME process. ACME process is used for new domain certs, existing expired +// certs or existing certs that should get renewed due to upcoming expiry. +// +// syncRenewal changes renewal behavior for existing certs that are still valid +// but need renewal. When syncRenewal is set, the method blocks until a new +// cert is issued. When syncRenewal is not set, existing cert is returned right +// away and renewal is kicked off in a background goroutine. +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } @@ -105,12 +111,15 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair) if err != nil { logf("error checking for certificate renewal: %v", err) - } else if shouldRenew { + } else if !shouldRenew { + return pair, nil + } + if !syncRenewal { logf("starting async renewal") // Start renewal in the background. go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now) } - return pair, nil + // Synchronous renewal happens below. } pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now) @@ -124,37 +133,43 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK func (b *LocalBackend) shouldStartDomainRenewal(cs certStore, domain string, now time.Time, pair *TLSCertKeyPair) (bool, error) { renewMu.Lock() defer renewMu.Unlock() - if last, ok := lastRenewCheck[domain]; ok && now.Sub(last) < time.Minute { - // We checked very recently. Don't bother reparsing & - // validating the x509 cert. - return false, nil + if renewAt, ok := renewCertAt[domain]; ok { + return now.After(renewAt), nil } - lastRenewCheck[domain] = now - renew, err := b.shouldStartDomainRenewalByARI(cs, now, pair) + renewTime, err := b.domainRenewalTimeByARI(cs, pair) if err != nil { // Log any ARI failure and fall back to checking for renewal by expiry. b.logf("acme: ARI check failed: %v; falling back to expiry-based check", err) - } else { - return renew, nil + renewTime, err = b.domainRenewalTimeByExpiry(pair) + if err != nil { + return false, err + } } - return b.shouldStartDomainRenewalByExpiry(now, pair) + renewCertAt[domain] = renewTime + return now.After(renewTime), nil } -func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLSCertKeyPair) (bool, error) { +func (b *LocalBackend) domainRenewed(domain string) { + renewMu.Lock() + defer renewMu.Unlock() + delete(renewCertAt, domain) +} + +func (b *LocalBackend) domainRenewalTimeByExpiry(pair *TLSCertKeyPair) (time.Time, error) { block, _ := pem.Decode(pair.CertPEM) if block == nil { - return false, fmt.Errorf("parsing certificate PEM") + return time.Time{}, fmt.Errorf("parsing certificate PEM") } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { - return false, fmt.Errorf("parsing certificate: %w", err) + return time.Time{}, fmt.Errorf("parsing certificate: %w", err) } certLifetime := cert.NotAfter.Sub(cert.NotBefore) if certLifetime < 0 { - return false, fmt.Errorf("negative certificate lifetime %v", certLifetime) + return time.Time{}, fmt.Errorf("negative certificate lifetime %v", certLifetime) } // Per https://github.com/tailscale/tailscale/issues/8204, check @@ -163,36 +178,32 @@ func (b *LocalBackend) shouldStartDomainRenewalByExpiry(now time.Time, pair *TLS // Encrypt. renewalDuration := certLifetime * 2 / 3 renewAt := cert.NotBefore.Add(renewalDuration) - - if now.After(renewAt) { - return true, nil - } - return false, nil + return renewAt, nil } -func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time, pair *TLSCertKeyPair) (bool, error) { +func (b *LocalBackend) domainRenewalTimeByARI(cs certStore, pair *TLSCertKeyPair) (time.Time, error) { var blocks []*pem.Block rest := pair.CertPEM for len(rest) > 0 { var block *pem.Block block, rest = pem.Decode(rest) if block == nil { - return false, fmt.Errorf("parsing certificate PEM") + return time.Time{}, fmt.Errorf("parsing certificate PEM") } blocks = append(blocks, block) } if len(blocks) < 2 { - return false, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks)) + return time.Time{}, fmt.Errorf("could not parse certificate chain from certStore, got %d PEM block(s)", len(blocks)) } ac, err := acmeClient(cs) if err != nil { - return false, err + return time.Time{}, err } ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second) defer cancel() ri, err := ac.FetchRenewalInfo(ctx, blocks[0].Bytes, blocks[1].Bytes) if err != nil { - return false, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err) + return time.Time{}, fmt.Errorf("failed to fetch renewal info from ACME server: %w", err) } if acmeDebug() { b.logf("acme: ARI response: %+v", ri) @@ -203,7 +214,7 @@ func (b *LocalBackend) shouldStartDomainRenewalByARI(cs certStore, now time.Time // https://datatracker.ietf.org/doc/draft-ietf-acme-ari/ start, end := ri.SuggestedWindow.Start, ri.SuggestedWindow.End renewTime := start.Add(time.Duration(insecurerand.Int63n(int64(end.Sub(start))))) - return now.After(renewTime), nil + return renewTime, nil } // certStore provides a way to perist and retrieve TLS certificates. @@ -371,8 +382,18 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger acmeMu.Lock() defer acmeMu.Unlock() + // In case this method was triggered multiple times in parallel (when + // serving incoming requests), check whether one of the other goroutines + // already renewed the cert before us. if p, err := getCertPEMCached(cs, domain, now); err == nil { - return p, nil + // shouldStartDomainRenewal caches its result so it's OK to call this + // frequently. + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p) + if err != nil { + logf("error checking for certificate renewal: %v", err) + } else if !shouldRenew { + return p, nil + } } else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) { return nil, err } @@ -509,6 +530,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil { return nil, err } + b.domainRenewed(domain) return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil } diff --git a/ipn/ipnlocal/cert_js.go b/ipn/ipnlocal/cert_js.go index a5fdfc4ba..24defb47b 100644 --- a/ipn/ipnlocal/cert_js.go +++ b/ipn/ipnlocal/cert_js.go @@ -12,6 +12,6 @@ type TLSCertKeyPair struct { CertPEM, KeyPEM []byte } -func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { +func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string, syncRenewal bool) (*TLSCertKeyPair, error) { return nil, errors.New("not implemented for js/wasm") } diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index 52ba13453..66d942032 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -112,7 +112,7 @@ func TestShouldStartDomainRenewal(t *testing.T) { reset := func() { renewMu.Lock() defer renewMu.Unlock() - maps.Clear(lastRenewCheck) + maps.Clear(renewCertAt) } mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair { @@ -178,7 +178,7 @@ func TestShouldStartDomainRenewal(t *testing.T) { t.Run(tt.name, func(t *testing.T) { reset() - ret, err := b.shouldStartDomainRenewalByExpiry(now, mustMakePair(&x509.Certificate{ + ret, err := b.domainRenewalTimeByExpiry(mustMakePair(&x509.Certificate{ SerialNumber: big.NewInt(2019), Subject: subject, NotBefore: tt.notBefore, @@ -192,8 +192,9 @@ func TestShouldStartDomainRenewal(t *testing.T) { t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr) } } else { - if ret != tt.want { - t.Errorf("got ret=%v, want %v", ret, tt.want) + renew := now.After(ret) + if renew != tt.want { + t.Errorf("got renew=%v (ret=%v), want renew %v", renew, ret, tt.want) } } }) diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index aa2c1a605..99330309b 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -372,7 +372,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort) GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, sni) + pair, err := b.GetCertPEM(ctx, sni, false) if err != nil { return nil, err } @@ -675,7 +675,7 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - pair, err := b.GetCertPEM(ctx, hi.ServerName) + pair, err := b.GetCertPEM(ctx, hi.ServerName, false) if err != nil { return nil, err } diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index 447c3bc3c..e1704cb49 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -23,7 +23,7 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { http.Error(w, "internal handler config wired wrong", 500) return } - pair, err := h.b.GetCertPEM(r.Context(), domain) + pair, err := h.b.GetCertPEM(r.Context(), domain, true) if err != nil { // TODO(bradfitz): 500 is a little lazy here. The errors returned from // GetCertPEM (and everywhere) should carry info info to get whether