ipn/localapi: refactor some cert code in prep for a move

I want to move the guts (after the HTTP layer) of the certificate
fetching into the ipnlocal package, out of localapi.

As prep, refactor a bit:

* add a method to do the fetch-from-cert-or-as-needed-with-refresh,
  rather than doing it in the HTTP hander

* convert two methods to funcs, taking the one extra field (LocalBackend)
  then needed from their method receiver. One of the methods needed
  nothing from its receiver.

This will make a future change easier to reason about.

Change-Id: I2a7811e5d7246139927bb86e7db8009bf09b3be3
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-11-07 20:49:46 -08:00 committed by Brad Fitzpatrick
parent 847a8cf917
commit 9be8d15979

View File

@ -34,6 +34,7 @@
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/strs" "tailscale.com/util/strs"
@ -79,13 +80,6 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "cert access denied", http.StatusForbidden) http.Error(w, "cert access denied", http.StatusForbidden)
return return
} }
dir, err := h.certDir()
if err != nil {
h.logf("certDir: %v", err)
http.Error(w, "failed to get cert dir", 500)
return
}
domain, ok := strs.CutPrefix(r.URL.Path, "/localapi/v0/cert/") domain, ok := strs.CutPrefix(r.URL.Path, "/localapi/v0/cert/")
if !ok { if !ok {
http.Error(w, "internal handler config wired wrong", 500) http.Error(w, "internal handler config wired wrong", 500)
@ -95,8 +89,24 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
http.Error(w, "invalid domain", 400) http.Error(w, "invalid domain", 400)
return return
} }
now := time.Now() pair, err := h.getCertPEM(r.Context(), domain)
if err != nil {
http.Error(w, fmt.Sprint(err), 500)
return
}
serveKeyPair(w, r, pair)
}
// getCertPEM gets the KeyPair for domain, either from cache, via the ACME
// process, or from cache and kicking off an async ACME renewal.
func (h *Handler) getCertPEM(ctx context.Context, domain string) (*keyPair, error) {
logf := logger.WithPrefix(h.logf, fmt.Sprintf("cert(%q): ", domain)) logf := logger.WithPrefix(h.logf, fmt.Sprintf("cert(%q): ", domain))
dir, err := h.certDir()
if err != nil {
logf("failed to get certDir: %v", err)
return nil, err
}
now := time.Now()
traceACME := func(v any) { traceACME := func(v any) {
if !acmeDebug() { if !acmeDebug() {
return return
@ -105,24 +115,22 @@ func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) {
log.Printf("acme %T: %s", v, j) log.Printf("acme %T: %s", v, j)
} }
if pair, ok := h.getCertPEMCached(dir, domain, now); ok { if pair, ok := getCertPEMCached(dir, domain, now); ok {
future := now.AddDate(0, 0, 14) future := now.AddDate(0, 0, 14)
if h.shouldStartDomainRenewal(dir, domain, future) { if h.shouldStartDomainRenewal(dir, domain, future) {
logf("starting async renewal") logf("starting async renewal")
// Start renewal in the background. // Start renewal in the background.
go h.getCertPEM(context.Background(), logf, traceACME, dir, domain, future) go getCertPEM(context.Background(), h.b, logf, traceACME, dir, domain, future)
} }
serveKeyPair(w, r, pair) return pair, nil
return
} }
pair, err := h.getCertPEM(r.Context(), logf, traceACME, dir, domain, now) pair, err := getCertPEM(ctx, h.b, logf, traceACME, dir, domain, now)
if err != nil { if err != nil {
logf("getCertPEM: %v", err) logf("getCertPEM: %v", err)
http.Error(w, fmt.Sprint(err), 500) return nil, err
return
} }
serveKeyPair(w, r, pair) return pair, nil
} }
func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time) bool { func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time) bool {
@ -135,7 +143,7 @@ func (h *Handler) shouldStartDomainRenewal(dir, domain string, future time.Time)
return false return false
} }
lastRenewCheck[domain] = now lastRenewCheck[domain] = now
_, ok := h.getCertPEMCached(dir, domain, future) _, ok := getCertPEMCached(dir, domain, future)
return !ok return !ok
} }
@ -154,10 +162,12 @@ func serveKeyPair(w http.ResponseWriter, r *http.Request, p *keyPair) {
} }
} }
// keyPair is a TLS public and private key, and whether they were obtained
// from cache or freshly obtained.
type keyPair struct { type keyPair struct {
certPEM []byte certPEM []byte // public key, in PEM form
keyPEM []byte keyPEM []byte // private key, in PEM form
cached bool cached bool // whether result came from cache
} }
func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") } func keyFile(dir, domain string) string { return filepath.Join(dir, domain+".key") }
@ -166,7 +176,7 @@ func certFile(dir, domain string) string { return filepath.Join(dir, domain+".cr
// getCertPEMCached returns a non-nil keyPair and true if a cached // getCertPEMCached returns a non-nil keyPair and true if a cached
// keypair for domain exists on disk in dir that is valid at the // keypair for domain exists on disk in dir that is valid at the
// provided now time. // provided now time.
func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) { func getCertPEMCached(dir, domain string, now time.Time) (p *keyPair, ok bool) {
if !validLookingCertDomain(domain) { if !validLookingCertDomain(domain) {
// Before we read files from disk using it, validate it's halfway // Before we read files from disk using it, validate it's halfway
// reasonable looking. // reasonable looking.
@ -181,11 +191,11 @@ func (h *Handler) getCertPEMCached(dir, domain string, now time.Time) (p *keyPai
return nil, false return nil, false
} }
func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*keyPair, error) { func getCertPEM(ctx context.Context, lb *ipnlocal.LocalBackend, logf logger.Logf, traceACME func(any), dir, domain string, now time.Time) (*keyPair, error) {
acmeMu.Lock() acmeMu.Lock()
defer acmeMu.Unlock() defer acmeMu.Unlock()
if p, ok := h.getCertPEMCached(dir, domain, now); ok { if p, ok := getCertPEMCached(dir, domain, now); ok {
return p, nil return p, nil
} }
@ -223,7 +233,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
} }
// Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for. // Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for.
st := h.b.StatusWithoutPeers() st := lb.StatusWithoutPeers()
if err := checkCertDomain(st, domain); err != nil { if err := checkCertDomain(st, domain); err != nil {
return nil, err return nil, err
} }
@ -260,7 +270,7 @@ func (h *Handler) getCertPEM(ctx context.Context, logf logger.Logf, traceACME fu
} }
if !ok { if !ok {
logf("starting SetDNS call...") logf("starting SetDNS call...")
err = h.b.SetDNS(ctx, key, rec) err = lb.SetDNS(ctx, key, rec)
if err != nil { if err != nil {
return nil, fmt.Errorf("SetDNS %q => %q: %w", key, rec, err) return nil, fmt.Errorf("SetDNS %q => %q: %w", key, rec, err)
} }