ipn/ipnlocal: poc level of effort to mock the ACME server

Change-Id: I76bf6aa20fe83eb6d1052603bb26a72de3fbddca
Signed-off-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
Tom Proctor 2025-03-12 16:40:04 +00:00
parent bcff106b4b
commit b4212ae7e4
2 changed files with 79 additions and 18 deletions

View File

@ -137,7 +137,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
if minValidity == 0 {
logf("starting async renewal")
// Start renewal in the background, return current valid cert.
b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) })
b.goTracker.Go(func() { b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) })
return pair, nil
}
// If the caller requested a specific validity duration, fall through
@ -149,7 +149,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
return nil, fmt.Errorf("retrieving cached TLS certificate failed with %w, and cert store is configured in read-only mode, not attempting to issue new certificate", err)
}
pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity)
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity)
if err != nil {
logf("getCertPEM: %v", err)
return nil, err
@ -476,8 +476,7 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey
}
// getCertPem checks if a cert needs to be renewed and if so, renews it.
// It can be overridden in tests.
var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
acmeMu.Lock()
defer acmeMu.Unlock()

View File

@ -14,7 +14,10 @@ import (
"crypto/x509/pkix"
"embed"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
@ -23,8 +26,9 @@ import (
"github.com/google/go-cmp/cmp"
"tailscale.com/envknob"
"tailscale.com/ipn/store/mem"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/util/must"
)
@ -221,11 +225,13 @@ func TestDebugACMEDirectoryURL(t *testing.T) {
func TestGetCertPEMWithValidity(t *testing.T) {
const testDomain = "example.com"
b := &LocalBackend{
store: &mem.Store{},
varRoot: t.TempDir(),
ctx: context.Background(),
logf: t.Logf,
b := newTestLocalBackend(t)
b.varRoot = t.TempDir()
b.netMap = &netmap.NetworkMap{
DNS: tailcfg.DNSConfig{
CertDomains: []string{testDomain},
},
SelfNode: (&tailcfg.Node{}).View(),
}
certDir, err := b.certDir()
if err != nil {
@ -332,13 +338,8 @@ func TestGetCertPEMWithValidity(t *testing.T) {
}
})()
// Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async
// renewal or in the main goroutine if issuance is required to obtain valid TLS credentials.
getCertPemWasCalled := false
getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
getCertPemWasCalled = true
return nil, nil
}
srv, certIssued := newFakeACMEServer(t)
t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", srv.URL+"/directory")
prevGoRoutines := b.goTracker.StartedGoroutines()
_, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0)
if (err != nil) != tt.wantErr {
@ -359,10 +360,71 @@ func TestGetCertPEMWithValidity(t *testing.T) {
t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal)
}
// Verify that (non-async) issuance was started if expected.
gotIssuance := getCertPemWasCalled && !gotAsyncRenewal
gotIssuance := *certIssued && !gotAsyncRenewal
if tt.wantIssuance != gotIssuance {
t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance)
}
})
}
}
// newFakeACMEServer does the minimum required work to allow our ACME client to
// be happy that it has successfully issued a cert.
func newFakeACMEServer(t *testing.T) (*httptest.Server, *bool) {
// Set to true if a cert is issued. May be issued in a goroutine for async
// renewal or in the main goroutine if issuance is required to obtain valid TLS credentials.
var certIssued bool
url := func(host, path string) string {
return fmt.Sprintf("http://%s%s", host, path)
}
validResponse := func(w http.ResponseWriter, body map[string]any) {
w.Header().Set("Replay-Nonce", "fake-nonce")
body["status"] = "valid"
writeJSON(w, body)
}
mux := http.NewServeMux()
mux.HandleFunc("GET /directory", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "fake-nonce")
writeJSON(w, map[string]any{
"newAccount": url(r.Host, "/acme/new-account"),
"newOrder": url(r.Host, "/acme/new-order"),
})
})
mux.HandleFunc("POST /acme/new-account", func(w http.ResponseWriter, r *http.Request) {
validResponse(w, map[string]any{})
})
mux.HandleFunc("POST /acme/new-order", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", url(r.Host, "/acme/order/1"))
w.WriteHeader(http.StatusCreated)
// We don't present any challenges, so the client skips straight to
// waiting for the order without having to mock the DNS challenge.
validResponse(w, map[string]any{})
})
mux.HandleFunc("POST /acme/order/1", func(w http.ResponseWriter, r *http.Request) {
validResponse(w, map[string]any{
"finalize": url(r.Host, "/acme/finalize"),
})
})
mux.HandleFunc("POST /acme/finalize", func(w http.ResponseWriter, r *http.Request) {
validResponse(w, map[string]any{
"certificate": url(r.Host, "/acme/cert"),
})
certIssued = true
})
mux.HandleFunc("POST /acme/cert", func(w http.ResponseWriter, r *http.Request) {
certPEM, err := certTestFS.ReadFile("testdata/example.com.pem")
if err != nil {
t.Fatal(err)
}
w.Write(certPEM)
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("unimplemented: %s %s", r.Method, r.URL.Path)
})
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return srv, &certIssued
}