diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index 1405fdab4..b311fe5b4 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -137,7 +137,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string if minValidity == 0 { logf("starting async renewal") // Start renewal in the background, return current valid cert. - b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) }) + b.goTracker.Go(func() { b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) }) return pair, nil } // If the caller requested a specific validity duration, fall through @@ -149,7 +149,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string return nil, fmt.Errorf("retrieving cached TLS certificate failed with %w, and cert store is configured in read-only mode, not attempting to issue new certificate", err) } - pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity) + pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -476,8 +476,7 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey } // getCertPem checks if a cert needs to be renewed and if so, renews it. -// It can be overridden in tests. -var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { +func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index 8eec805c9..ce730aac8 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -14,7 +14,10 @@ import ( "crypto/x509/pkix" "embed" "encoding/pem" + "fmt" "math/big" + "net/http" + "net/http/httptest" "os" "path/filepath" "testing" @@ -23,8 +26,9 @@ import ( "github.com/google/go-cmp/cmp" "tailscale.com/envknob" "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" "tailscale.com/tstest" - "tailscale.com/types/logger" + "tailscale.com/types/netmap" "tailscale.com/util/must" ) @@ -221,11 +225,13 @@ func TestDebugACMEDirectoryURL(t *testing.T) { func TestGetCertPEMWithValidity(t *testing.T) { const testDomain = "example.com" - b := &LocalBackend{ - store: &mem.Store{}, - varRoot: t.TempDir(), - ctx: context.Background(), - logf: t.Logf, + b := newTestLocalBackend(t) + b.varRoot = t.TempDir() + b.netMap = &netmap.NetworkMap{ + DNS: tailcfg.DNSConfig{ + CertDomains: []string{testDomain}, + }, + SelfNode: (&tailcfg.Node{}).View(), } certDir, err := b.certDir() if err != nil { @@ -332,13 +338,8 @@ func TestGetCertPEMWithValidity(t *testing.T) { } })() - // Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async - // renewal or in the main goroutine if issuance is required to obtain valid TLS credentials. - getCertPemWasCalled := false - getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { - getCertPemWasCalled = true - return nil, nil - } + srv, certIssued := newFakeACMEServer(t) + t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", srv.URL+"/directory") prevGoRoutines := b.goTracker.StartedGoroutines() _, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0) if (err != nil) != tt.wantErr { @@ -359,10 +360,71 @@ func TestGetCertPEMWithValidity(t *testing.T) { t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal) } // Verify that (non-async) issuance was started if expected. - gotIssuance := getCertPemWasCalled && !gotAsyncRenewal + gotIssuance := *certIssued && !gotAsyncRenewal if tt.wantIssuance != gotIssuance { t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance) } }) } } + +// newFakeACMEServer does the minimum required work to allow our ACME client to +// be happy that it has successfully issued a cert. +func newFakeACMEServer(t *testing.T) (*httptest.Server, *bool) { + // Set to true if a cert is issued. May be issued in a goroutine for async + // renewal or in the main goroutine if issuance is required to obtain valid TLS credentials. + var certIssued bool + + url := func(host, path string) string { + return fmt.Sprintf("http://%s%s", host, path) + } + validResponse := func(w http.ResponseWriter, body map[string]any) { + w.Header().Set("Replay-Nonce", "fake-nonce") + body["status"] = "valid" + writeJSON(w, body) + } + + mux := http.NewServeMux() + mux.HandleFunc("GET /directory", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", "fake-nonce") + writeJSON(w, map[string]any{ + "newAccount": url(r.Host, "/acme/new-account"), + "newOrder": url(r.Host, "/acme/new-order"), + }) + }) + mux.HandleFunc("POST /acme/new-account", func(w http.ResponseWriter, r *http.Request) { + validResponse(w, map[string]any{}) + }) + mux.HandleFunc("POST /acme/new-order", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", url(r.Host, "/acme/order/1")) + w.WriteHeader(http.StatusCreated) + // We don't present any challenges, so the client skips straight to + // waiting for the order without having to mock the DNS challenge. + validResponse(w, map[string]any{}) + }) + mux.HandleFunc("POST /acme/order/1", func(w http.ResponseWriter, r *http.Request) { + validResponse(w, map[string]any{ + "finalize": url(r.Host, "/acme/finalize"), + }) + }) + mux.HandleFunc("POST /acme/finalize", func(w http.ResponseWriter, r *http.Request) { + validResponse(w, map[string]any{ + "certificate": url(r.Host, "/acme/cert"), + }) + certIssued = true + }) + mux.HandleFunc("POST /acme/cert", func(w http.ResponseWriter, r *http.Request) { + certPEM, err := certTestFS.ReadFile("testdata/example.com.pem") + if err != nil { + t.Fatal(err) + } + w.Write(certPEM) + }) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("unimplemented: %s %s", r.Method, r.URL.Path) + }) + + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + return srv, &certIssued +}