diff --git a/envknob/envknob.go b/envknob/envknob.go index e74bfea71..2662da2b4 100644 --- a/envknob/envknob.go +++ b/envknob/envknob.go @@ -417,6 +417,23 @@ func App() string { return "" } +// IsCertShareReadOnlyMode returns true if this replica should never attempt to +// issue or renew TLS credentials for any of the HTTPS endpoints that it is +// serving. It should only return certs found in its cert store. Currently, +// this is used by the Kubernetes Operator's HA Ingress via VIPServices, where +// multiple Ingress proxy instances serve the same HTTPS endpoint with a shared +// TLS credentials. The TLS credentials should only be issued by one of the +// replicas. +// For HTTPS Ingress the operator and containerboot ensure +// that read-only replicas will not be serving the HTTPS endpoints before there +// is a shared cert available. +func IsCertShareReadOnlyMode() bool { + m := String("TS_CERT_SHARE_MODE") + return m == modeRO +} + +const modeRO = "ro" + // CrashOnUnexpected reports whether the Tailscale client should panic // on unexpected conditions. If TS_DEBUG_CRASH_ON_UNEXPECTED is set, that's // used. Otherwise the default value is true for unstable builds. diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index 4c026a9e7..111dc5a2d 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -119,6 +119,9 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string } if pair, err := getCertPEMCached(cs, domain, now); err == nil { + if envknob.IsCertShareReadOnlyMode() { + return pair, nil + } // If we got here, we have a valid unexpired cert. // Check whether we should start an async renewal. shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity) @@ -134,7 +137,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string if minValidity == 0 { logf("starting async renewal") // Start renewal in the background, return current valid cert. - go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) + b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) }) return pair, nil } // If the caller requested a specific validity duration, fall through @@ -142,7 +145,11 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string logf("starting sync renewal") } - pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity) + if envknob.IsCertShareReadOnlyMode() { + return nil, fmt.Errorf("retrieving cached TLS certificate failed and cert store is configured in read-only mode, not attempting to issue a new certificate: %w", err) + } + + pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -358,7 +365,29 @@ type certStateStore struct { testRoots *x509.CertPool } +// TLSCertKeyReader is an interface implemented by state stores where it makes +// sense to read the TLS cert and key in a single operation that can be +// distinguished from generic state value reads. Currently this is only implemented +// by the kubestore.Store, which, in some cases, need to read cert and key from a +// non-cached TLS Secret. +type TLSCertKeyReader interface { + ReadTLSCertAndKey(domain string) ([]byte, []byte, error) +} + func (s certStateStore) Read(domain string, now time.Time) (*TLSCertKeyPair, error) { + // If we're using a store that supports atomic reads, use that + if kr, ok := s.StateStore.(TLSCertKeyReader); ok { + cert, key, err := kr.ReadTLSCertAndKey(domain) + if err != nil { + return nil, err + } + if !validCertPEM(domain, key, cert, s.testRoots, now) { + return nil, errCertExpired + } + return &TLSCertKeyPair{CertPEM: cert, KeyPEM: key, Cached: true}, nil + } + + // Otherwise fall back to separate reads certPEM, err := s.ReadState(ipn.StateKey(domain + ".crt")) if err != nil { return nil, err @@ -446,7 +475,9 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey return cs.Read(domain, now) } -func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { +// getCertPem checks if a cert needs to be renewed and if so, renews it. +// It can be overridden in tests. +var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index c77570e87..e2398f670 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -6,6 +6,7 @@ package ipnlocal import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -14,11 +15,17 @@ import ( "embed" "encoding/pem" "math/big" + "os" + "path/filepath" "testing" "time" "github.com/google/go-cmp/cmp" + "tailscale.com/envknob" "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/types/logger" + "tailscale.com/util/must" ) func TestValidLookingCertDomain(t *testing.T) { @@ -221,3 +228,151 @@ func TestDebugACMEDirectoryURL(t *testing.T) { }) } } + +func TestGetCertPEMWithValidity(t *testing.T) { + const testDomain = "example.com" + b := &LocalBackend{ + store: &mem.Store{}, + varRoot: t.TempDir(), + ctx: context.Background(), + logf: t.Logf, + } + certDir, err := b.certDir() + if err != nil { + t.Fatalf("certDir error: %v", err) + } + if _, err := b.getCertStore(); err != nil { + t.Fatalf("getCertStore error: %v", err) + } + testRoot, err := certTestFS.ReadFile("testdata/rootCA.pem") + if err != nil { + t.Fatal(err) + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(testRoot) { + t.Fatal("Unable to add test CA to the cert pool") + } + testX509Roots = roots + defer func() { testX509Roots = nil }() + tests := []struct { + name string + now time.Time + // storeCerts is true if the test cert and key should be written to store. + storeCerts bool + readOnlyMode bool // TS_READ_ONLY_CERTS env var + wantAsyncRenewal bool // async issuance should be started + wantIssuance bool // sync issuance should be started + wantErr bool + }{ + { + name: "valid_no_renewal", + now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC), + storeCerts: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: false, + }, + { + name: "issuance_needed", + now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC), + storeCerts: false, + wantAsyncRenewal: false, + wantIssuance: true, + wantErr: false, + }, + { + name: "renewal_needed", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: true, + wantAsyncRenewal: true, + wantIssuance: false, + wantErr: false, + }, + { + name: "renewal_needed_read_only_mode", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: true, + readOnlyMode: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: false, + }, + { + name: "no_certs_read_only_mode", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: false, + readOnlyMode: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if tt.readOnlyMode { + envknob.Setenv("TS_CERT_SHARE_MODE", "ro") + } + + os.RemoveAll(certDir) + if tt.storeCerts { + os.MkdirAll(certDir, 0755) + if err := os.WriteFile(filepath.Join(certDir, "example.com.crt"), + must.Get(os.ReadFile("testdata/example.com.pem")), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(certDir, "example.com.key"), + must.Get(os.ReadFile("testdata/example.com-key.pem")), 0644); err != nil { + t.Fatal(err) + } + } + + b.clock = tstest.NewClock(tstest.ClockOpts{Start: tt.now}) + + allDone := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { + b.mu.Lock() + defer b.mu.Unlock() + if b.goTracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + // Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async + // renewal or in the main goroutine if issuance is required to obtain valid TLS credentials. + getCertPemWasCalled := false + getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { + getCertPemWasCalled = true + return nil, nil + } + prevGoRoutines := b.goTracker.StartedGoroutines() + _, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0) + if (err != nil) != tt.wantErr { + t.Errorf("b.GetCertPemWithValidity got err %v, wants error: '%v'", err, tt.wantErr) + } + // GetCertPEMWithValidity calls getCertPEM in a goroutine if async renewal is needed. That's the + // only goroutine it starts, so this can be used to test if async renewal was started. + gotAsyncRenewal := b.goTracker.StartedGoroutines()-prevGoRoutines != 0 + if gotAsyncRenewal { + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutines to finish") + case <-allDone: + } + } + // Verify that async renewal was triggered if expected. + if tt.wantAsyncRenewal != gotAsyncRenewal { + t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal) + } + // Verify that (non-async) issuance was started if expected. + gotIssuance := getCertPemWasCalled && !gotAsyncRenewal + if tt.wantIssuance != gotIssuance { + t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance) + } + }) + } +}