ipn/ipnlocal: poc level of effort to mock the ACME server

Change-Id: I76bf6aa20fe83eb6d1052603bb26a72de3fbddca
Signed-off-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
Tom Proctor 2025-03-12 16:40:04 +00:00
parent bcff106b4b
commit b4212ae7e4
2 changed files with 79 additions and 18 deletions

View File

@ -137,7 +137,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
if minValidity == 0 { if minValidity == 0 {
logf("starting async renewal") logf("starting async renewal")
// Start renewal in the background, return current valid cert. // Start renewal in the background, return current valid cert.
b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) }) b.goTracker.Go(func() { b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) })
return pair, nil return pair, nil
} }
// If the caller requested a specific validity duration, fall through // If the caller requested a specific validity duration, fall through
@ -149,7 +149,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
return nil, fmt.Errorf("retrieving cached TLS certificate failed with %w, and cert store is configured in read-only mode, not attempting to issue new certificate", err) return nil, fmt.Errorf("retrieving cached TLS certificate failed with %w, and cert store is configured in read-only mode, not attempting to issue new certificate", err)
} }
pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity) pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity)
if err != nil { if err != nil {
logf("getCertPEM: %v", err) logf("getCertPEM: %v", err)
return nil, err return nil, err
@ -476,8 +476,7 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey
} }
// getCertPem checks if a cert needs to be renewed and if so, renews it. // getCertPem checks if a cert needs to be renewed and if so, renews it.
// It can be overridden in tests. func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
acmeMu.Lock() acmeMu.Lock()
defer acmeMu.Unlock() defer acmeMu.Unlock()

View File

@ -14,7 +14,10 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"embed" "embed"
"encoding/pem" "encoding/pem"
"fmt"
"math/big" "math/big"
"net/http"
"net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -23,8 +26,9 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/ipn/store/mem" "tailscale.com/ipn/store/mem"
"tailscale.com/tailcfg"
"tailscale.com/tstest" "tailscale.com/tstest"
"tailscale.com/types/logger" "tailscale.com/types/netmap"
"tailscale.com/util/must" "tailscale.com/util/must"
) )
@ -221,11 +225,13 @@ func TestDebugACMEDirectoryURL(t *testing.T) {
func TestGetCertPEMWithValidity(t *testing.T) { func TestGetCertPEMWithValidity(t *testing.T) {
const testDomain = "example.com" const testDomain = "example.com"
b := &LocalBackend{ b := newTestLocalBackend(t)
store: &mem.Store{}, b.varRoot = t.TempDir()
varRoot: t.TempDir(), b.netMap = &netmap.NetworkMap{
ctx: context.Background(), DNS: tailcfg.DNSConfig{
logf: t.Logf, CertDomains: []string{testDomain},
},
SelfNode: (&tailcfg.Node{}).View(),
} }
certDir, err := b.certDir() certDir, err := b.certDir()
if err != nil { if err != nil {
@ -332,13 +338,8 @@ func TestGetCertPEMWithValidity(t *testing.T) {
} }
})() })()
// Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async srv, certIssued := newFakeACMEServer(t)
// renewal or in the main goroutine if issuance is required to obtain valid TLS credentials. t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", srv.URL+"/directory")
getCertPemWasCalled := false
getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
getCertPemWasCalled = true
return nil, nil
}
prevGoRoutines := b.goTracker.StartedGoroutines() prevGoRoutines := b.goTracker.StartedGoroutines()
_, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0) _, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -359,10 +360,71 @@ func TestGetCertPEMWithValidity(t *testing.T) {
t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal) t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal)
} }
// Verify that (non-async) issuance was started if expected. // Verify that (non-async) issuance was started if expected.
gotIssuance := getCertPemWasCalled && !gotAsyncRenewal gotIssuance := *certIssued && !gotAsyncRenewal
if tt.wantIssuance != gotIssuance { if tt.wantIssuance != gotIssuance {
t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance) t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance)
} }
}) })
} }
} }
// newFakeACMEServer does the minimum required work to allow our ACME client to
// be happy that it has successfully issued a cert.
func newFakeACMEServer(t *testing.T) (*httptest.Server, *bool) {
// Set to true if a cert is issued. May be issued in a goroutine for async
// renewal or in the main goroutine if issuance is required to obtain valid TLS credentials.
var certIssued bool
url := func(host, path string) string {
return fmt.Sprintf("http://%s%s", host, path)
}
validResponse := func(w http.ResponseWriter, body map[string]any) {
w.Header().Set("Replay-Nonce", "fake-nonce")
body["status"] = "valid"
writeJSON(w, body)
}
mux := http.NewServeMux()
mux.HandleFunc("GET /directory", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "fake-nonce")
writeJSON(w, map[string]any{
"newAccount": url(r.Host, "/acme/new-account"),
"newOrder": url(r.Host, "/acme/new-order"),
})
})
mux.HandleFunc("POST /acme/new-account", func(w http.ResponseWriter, r *http.Request) {
validResponse(w, map[string]any{})
})
mux.HandleFunc("POST /acme/new-order", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", url(r.Host, "/acme/order/1"))
w.WriteHeader(http.StatusCreated)
// We don't present any challenges, so the client skips straight to
// waiting for the order without having to mock the DNS challenge.
validResponse(w, map[string]any{})
})
mux.HandleFunc("POST /acme/order/1", func(w http.ResponseWriter, r *http.Request) {
validResponse(w, map[string]any{
"finalize": url(r.Host, "/acme/finalize"),
})
})
mux.HandleFunc("POST /acme/finalize", func(w http.ResponseWriter, r *http.Request) {
validResponse(w, map[string]any{
"certificate": url(r.Host, "/acme/cert"),
})
certIssued = true
})
mux.HandleFunc("POST /acme/cert", func(w http.ResponseWriter, r *http.Request) {
certPEM, err := certTestFS.ReadFile("testdata/example.com.pem")
if err != nil {
t.Fatal(err)
}
w.Write(certPEM)
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("unimplemented: %s %s", r.Method, r.URL.Path)
})
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return srv, &certIssued
}