mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-29 15:23:45 +00:00
ipn/ipnlocal: poc level of effort to mock the ACME server
Change-Id: I76bf6aa20fe83eb6d1052603bb26a72de3fbddca Signed-off-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
parent
bcff106b4b
commit
b4212ae7e4
@ -137,7 +137,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
|
||||
if minValidity == 0 {
|
||||
logf("starting async renewal")
|
||||
// Start renewal in the background, return current valid cert.
|
||||
b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) })
|
||||
b.goTracker.Go(func() { b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) })
|
||||
return pair, nil
|
||||
}
|
||||
// If the caller requested a specific validity duration, fall through
|
||||
@ -149,7 +149,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
|
||||
return nil, fmt.Errorf("retrieving cached TLS certificate failed with %w, and cert store is configured in read-only mode, not attempting to issue new certificate", err)
|
||||
}
|
||||
|
||||
pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity)
|
||||
pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity)
|
||||
if err != nil {
|
||||
logf("getCertPEM: %v", err)
|
||||
return nil, err
|
||||
@ -476,8 +476,7 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey
|
||||
}
|
||||
|
||||
// getCertPem checks if a cert needs to be renewed and if so, renews it.
|
||||
// It can be overridden in tests.
|
||||
var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
|
||||
func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
|
||||
acmeMu.Lock()
|
||||
defer acmeMu.Unlock()
|
||||
|
||||
|
@ -14,7 +14,10 @@ import (
|
||||
"crypto/x509/pkix"
|
||||
"embed"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@ -23,8 +26,9 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/ipn/store/mem"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
@ -221,11 +225,13 @@ func TestDebugACMEDirectoryURL(t *testing.T) {
|
||||
|
||||
func TestGetCertPEMWithValidity(t *testing.T) {
|
||||
const testDomain = "example.com"
|
||||
b := &LocalBackend{
|
||||
store: &mem.Store{},
|
||||
varRoot: t.TempDir(),
|
||||
ctx: context.Background(),
|
||||
logf: t.Logf,
|
||||
b := newTestLocalBackend(t)
|
||||
b.varRoot = t.TempDir()
|
||||
b.netMap = &netmap.NetworkMap{
|
||||
DNS: tailcfg.DNSConfig{
|
||||
CertDomains: []string{testDomain},
|
||||
},
|
||||
SelfNode: (&tailcfg.Node{}).View(),
|
||||
}
|
||||
certDir, err := b.certDir()
|
||||
if err != nil {
|
||||
@ -332,13 +338,8 @@ func TestGetCertPEMWithValidity(t *testing.T) {
|
||||
}
|
||||
})()
|
||||
|
||||
// Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async
|
||||
// renewal or in the main goroutine if issuance is required to obtain valid TLS credentials.
|
||||
getCertPemWasCalled := false
|
||||
getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
|
||||
getCertPemWasCalled = true
|
||||
return nil, nil
|
||||
}
|
||||
srv, certIssued := newFakeACMEServer(t)
|
||||
t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", srv.URL+"/directory")
|
||||
prevGoRoutines := b.goTracker.StartedGoroutines()
|
||||
_, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0)
|
||||
if (err != nil) != tt.wantErr {
|
||||
@ -359,10 +360,71 @@ func TestGetCertPEMWithValidity(t *testing.T) {
|
||||
t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal)
|
||||
}
|
||||
// Verify that (non-async) issuance was started if expected.
|
||||
gotIssuance := getCertPemWasCalled && !gotAsyncRenewal
|
||||
gotIssuance := *certIssued && !gotAsyncRenewal
|
||||
if tt.wantIssuance != gotIssuance {
|
||||
t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newFakeACMEServer does the minimum required work to allow our ACME client to
|
||||
// be happy that it has successfully issued a cert.
|
||||
func newFakeACMEServer(t *testing.T) (*httptest.Server, *bool) {
|
||||
// Set to true if a cert is issued. May be issued in a goroutine for async
|
||||
// renewal or in the main goroutine if issuance is required to obtain valid TLS credentials.
|
||||
var certIssued bool
|
||||
|
||||
url := func(host, path string) string {
|
||||
return fmt.Sprintf("http://%s%s", host, path)
|
||||
}
|
||||
validResponse := func(w http.ResponseWriter, body map[string]any) {
|
||||
w.Header().Set("Replay-Nonce", "fake-nonce")
|
||||
body["status"] = "valid"
|
||||
writeJSON(w, body)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /directory", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Replay-Nonce", "fake-nonce")
|
||||
writeJSON(w, map[string]any{
|
||||
"newAccount": url(r.Host, "/acme/new-account"),
|
||||
"newOrder": url(r.Host, "/acme/new-order"),
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("POST /acme/new-account", func(w http.ResponseWriter, r *http.Request) {
|
||||
validResponse(w, map[string]any{})
|
||||
})
|
||||
mux.HandleFunc("POST /acme/new-order", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", url(r.Host, "/acme/order/1"))
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
// We don't present any challenges, so the client skips straight to
|
||||
// waiting for the order without having to mock the DNS challenge.
|
||||
validResponse(w, map[string]any{})
|
||||
})
|
||||
mux.HandleFunc("POST /acme/order/1", func(w http.ResponseWriter, r *http.Request) {
|
||||
validResponse(w, map[string]any{
|
||||
"finalize": url(r.Host, "/acme/finalize"),
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("POST /acme/finalize", func(w http.ResponseWriter, r *http.Request) {
|
||||
validResponse(w, map[string]any{
|
||||
"certificate": url(r.Host, "/acme/cert"),
|
||||
})
|
||||
certIssued = true
|
||||
})
|
||||
mux.HandleFunc("POST /acme/cert", func(w http.ResponseWriter, r *http.Request) {
|
||||
certPEM, err := certTestFS.ReadFile("testdata/example.com.pem")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w.Write(certPEM)
|
||||
})
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("unimplemented: %s %s", r.Method, r.URL.Path)
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
t.Cleanup(srv.Close)
|
||||
return srv, &certIssued
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user