diff --git a/cmd/containerboot/certs.go b/cmd/containerboot/certs.go new file mode 100644 index 000000000..7af0424a9 --- /dev/null +++ b/cmd/containerboot/certs.go @@ -0,0 +1,147 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "fmt" + "log" + "net" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/util/goroutines" + "tailscale.com/util/mak" +) + +// certManager is responsible for issuing certificates for known domains and for +// maintaining a loop that re-attempts issuance daily. +// Currently cert manager logic is only run on ingress ProxyGroup replicas that are responsible for managing certs for +// HA Ingress HTTPS endpoints ('write' replicas). +type certManager struct { + lc localClient + tracker goroutines.Tracker // tracks running goroutines + mu sync.Mutex // guards the following + // certLoops contains a map of DNS names, for which we currently need to + // manage certs to cancel functions that allow stopping a goroutine when + // we no longer need to manage certs for the DNS name. + certLoops map[string]context.CancelFunc +} + +// ensureCertLoops ensures that, for all currently managed Service HTTPS +// endpoints, there is a cert loop responsible for issuing and ensuring the +// renewal of the TLS certs. +// ServeConfig must not be nil. +func (cm *certManager) ensureCertLoops(ctx context.Context, sc *ipn.ServeConfig) error { + if sc == nil { + return fmt.Errorf("[unexpected] ensureCertLoops called with nil ServeConfig") + } + currentDomains := make(map[string]bool) + const httpsPort = "443" + for _, service := range sc.Services { + for hostPort := range service.Web { + domain, port, err := net.SplitHostPort(string(hostPort)) + if err != nil { + return fmt.Errorf("[unexpected] unable to parse HostPort %s", hostPort) + } + if port != httpsPort { // HA Ingress' HTTP endpoint + continue + } + currentDomains[domain] = true + } + } + cm.mu.Lock() + defer cm.mu.Unlock() + for domain := range currentDomains { + if _, exists := cm.certLoops[domain]; !exists { + cancelCtx, cancel := context.WithCancel(ctx) + mak.Set(&cm.certLoops, domain, cancel) + cm.tracker.Go(func() { cm.runCertLoop(cancelCtx, domain) }) + } + } + + // Stop goroutines for domain names that are no longer in the config. + for domain, cancel := range cm.certLoops { + if !currentDomains[domain] { + cancel() + delete(cm.certLoops, domain) + } + } + return nil +} + +// runCertLoop: +// - calls localAPI certificate endpoint to ensure that certs are issued for the +// given domain name +// - calls localAPI certificate endpoint daily to ensure that certs are renewed +// - if certificate issuance failed retries after an exponential backoff period +// starting at 1 minute and capped at 24 hours. Reset the backoff once issuance succeeds. +// Note that renewal check also happens when the node receives an HTTPS request and it is possible that certs get +// renewed at that point. Renewal here is needed to prevent the shared certs from expiry in edge cases where the 'write' +// replica does not get any HTTPS requests. +// https://letsencrypt.org/docs/integration-guide/#retrying-failures +func (cm *certManager) runCertLoop(ctx context.Context, domain string) { + const ( + normalInterval = 24 * time.Hour // regular renewal check + initialRetry = 1 * time.Minute // initial backoff after a failure + maxRetryInterval = 24 * time.Hour // max backoff period + ) + timer := time.NewTimer(0) // fire off timer immediately + defer timer.Stop() + retryCount := 0 + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + // We call the certificate endpoint, but don't do anything + // with the returned certs here. + // The call to the certificate endpoint will ensure that + // certs are issued/renewed as needed and stored in the + // relevant state store. For example, for HA Ingress + // 'write' replica, the cert and key will be stored in a + // Kubernetes Secret named after the domain for which we + // are issuing. + // Note that renewals triggered by the call to the + // certificates endpoint here and by renewal check + // triggered during a call to node's HTTPS endpoint + // share the same state/renewal lock mechanism, so we + // should not run into redundant issuances during + // concurrent renewal checks. + // TODO(irbekrm): maybe it is worth adding a new + // issuance endpoint that explicitly only triggers + // issuance and stores certs in the relevant store, but + // does not return certs to the caller? + _, _, err := cm.lc.CertPair(ctx, domain) + if err != nil { + log.Printf("error refreshing certificate for %s: %v", domain, err) + } + var nextInterval time.Duration + // TODO(irbekrm): distinguish between LE rate limit + // errors and other error types like transient network + // errors. + if err == nil { + retryCount = 0 + nextInterval = normalInterval + } else { + retryCount++ + // Calculate backoff: initialRetry * 2^(retryCount-1) + // For retryCount=1: 1min * 2^0 = 1min + // For retryCount=2: 1min * 2^1 = 2min + // For retryCount=3: 1min * 2^2 = 4min + backoff := initialRetry * time.Duration(1<<(retryCount-1)) + if backoff > maxRetryInterval { + backoff = maxRetryInterval + } + nextInterval = backoff + log.Printf("Error refreshing certificate for %s (retry %d): %v. Will retry in %v\n", + domain, retryCount, err, nextInterval) + } + timer.Reset(nextInterval) + } + } +} diff --git a/cmd/containerboot/certs_test.go b/cmd/containerboot/certs_test.go new file mode 100644 index 000000000..577311ea3 --- /dev/null +++ b/cmd/containerboot/certs_test.go @@ -0,0 +1,229 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "testing" + "time" + + "tailscale.com/ipn" + "tailscale.com/tailcfg" +) + +// TestEnsureCertLoops tests that the certManager correctly starts and stops +// update loops for certs when the serve config changes. It tracks goroutine +// count and uses that as a validator that the expected number of cert loops are +// running. +func TestEnsureCertLoops(t *testing.T) { + tests := []struct { + name string + initialConfig *ipn.ServeConfig + updatedConfig *ipn.ServeConfig + initialGoroutines int64 // after initial serve config is applied + updatedGoroutines int64 // after updated serve config is applied + wantErr bool + }{ + { + name: "empty_serve_config", + initialConfig: &ipn.ServeConfig{}, + initialGoroutines: 0, + }, + { + name: "nil_serve_config", + initialConfig: nil, + initialGoroutines: 0, + wantErr: true, + }, + { + name: "empty_to_one_service", + initialConfig: &ipn.ServeConfig{}, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 0, + updatedGoroutines: 1, + }, + { + name: "single_service", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 1, + }, + { + name: "multiple_services", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 2, // one loop per domain across all services + }, + { + name: "ignore_non_https_ports", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + "my-app.tailnetxyz.ts.net:80": {}, + }, + }, + }, + }, + initialGoroutines: 1, // only one loop for the 443 endpoint + }, + { + name: "remove_domain", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 2, // initially two loops (one per service) + updatedGoroutines: 1, // one loop after removing service2 + }, + { + name: "add_domain", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 1, + updatedGoroutines: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cm := &certManager{ + lc: &fakeLocalClient{}, + certLoops: make(map[string]context.CancelFunc), + } + + allDone := make(chan bool, 1) + defer cm.tracker.AddDoneCallback(func() { + cm.mu.Lock() + defer cm.mu.Unlock() + if cm.tracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + err := cm.ensureCertLoops(ctx, tt.initialConfig) + if (err != nil) != tt.wantErr { + t.Fatalf("ensureCertLoops() error = %v", err) + } + + if got := cm.tracker.RunningGoroutines(); got != tt.initialGoroutines { + t.Errorf("after initial config: got %d running goroutines, want %d", got, tt.initialGoroutines) + } + + if tt.updatedConfig != nil { + if err := cm.ensureCertLoops(ctx, tt.updatedConfig); err != nil { + t.Fatalf("ensureCertLoops() error on update = %v", err) + } + + // Although starting goroutines and cancelling + // the context happens in the main goroutine, it + // the actual goroutine exit when a context is + // cancelled does not- so wait for a bit for the + // running goroutine count to reach the expected + // number. + deadline := time.After(5 * time.Second) + for { + if got := cm.tracker.RunningGoroutines(); got == tt.updatedGoroutines { + break + } + select { + case <-deadline: + t.Fatalf("timed out waiting for goroutine count to reach %d, currently at %d", + tt.updatedGoroutines, cm.tracker.RunningGoroutines()) + case <-time.After(10 * time.Millisecond): + continue + } + } + } + + if tt.updatedGoroutines == 0 { + return // no goroutines to wait for + } + // cancel context to make goroutines exit + cancel() + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutine to finish") + case <-allDone: + } + }) + } +} diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index cf4bd8620..5f8052bb9 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -646,7 +646,7 @@ runLoop: if cfg.ServeConfigPath != "" { triggerWatchServeConfigChanges.Do(func() { - go watchServeConfigChanges(ctx, cfg.ServeConfigPath, certDomainChanged, certDomain, client, kc) + go watchServeConfigChanges(ctx, certDomainChanged, certDomain, client, kc, cfg) }) } diff --git a/cmd/containerboot/serve.go b/cmd/containerboot/serve.go index 4ea5a9c46..37fd49777 100644 --- a/cmd/containerboot/serve.go +++ b/cmd/containerboot/serve.go @@ -28,10 +28,11 @@ import ( // applies it to lc. It exits when ctx is canceled. cdChanged is a channel that // is written to when the certDomain changes, causing the serve config to be // re-read and applied. -func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan bool, certDomainAtomic *atomic.Pointer[string], lc *local.Client, kc *kubeClient) { +func watchServeConfigChanges(ctx context.Context, cdChanged <-chan bool, certDomainAtomic *atomic.Pointer[string], lc *local.Client, kc *kubeClient, cfg *settings) { if certDomainAtomic == nil { panic("certDomainAtomic must not be nil") } + var tickChan <-chan time.Time var eventChan <-chan fsnotify.Event if w, err := fsnotify.NewWatcher(); err != nil { @@ -43,7 +44,7 @@ func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan tickChan = ticker.C } else { defer w.Close() - if err := w.Add(filepath.Dir(path)); err != nil { + if err := w.Add(filepath.Dir(cfg.ServeConfigPath)); err != nil { log.Fatalf("serve proxy: failed to add fsnotify watch: %v", err) } eventChan = w.Events @@ -51,6 +52,12 @@ func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan var certDomain string var prevServeConfig *ipn.ServeConfig + var cm certManager + if cfg.CertShareMode == "rw" { + cm = certManager{ + lc: lc, + } + } for { select { case <-ctx.Done(): @@ -63,12 +70,12 @@ func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan // k8s handles these mounts. So just re-read the file and apply it // if it's changed. } - sc, err := readServeConfig(path, certDomain) + sc, err := readServeConfig(cfg.ServeConfigPath, certDomain) if err != nil { log.Fatalf("serve proxy: failed to read serve config: %v", err) } if sc == nil { - log.Printf("serve proxy: no serve config at %q, skipping", path) + log.Printf("serve proxy: no serve config at %q, skipping", cfg.ServeConfigPath) continue } if prevServeConfig != nil && reflect.DeepEqual(sc, prevServeConfig) { @@ -83,6 +90,12 @@ func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan } } prevServeConfig = sc + if cfg.CertShareMode != "rw" { + continue + } + if err := cm.ensureCertLoops(ctx, sc); err != nil { + log.Fatalf("serve proxy: error ensuring cert loops: %v", err) + } } } @@ -96,6 +109,7 @@ func certDomainFromNetmap(nm *netmap.NetworkMap) string { // localClient is a subset of [local.Client] that can be mocked for testing. type localClient interface { SetServeConfig(context.Context, *ipn.ServeConfig) error + CertPair(context.Context, string) ([]byte, []byte, error) } func updateServeConfig(ctx context.Context, sc *ipn.ServeConfig, certDomain string, lc localClient) error { diff --git a/cmd/containerboot/serve_test.go b/cmd/containerboot/serve_test.go index eb92a8dc8..fc18f254d 100644 --- a/cmd/containerboot/serve_test.go +++ b/cmd/containerboot/serve_test.go @@ -206,6 +206,10 @@ func (m *fakeLocalClient) SetServeConfig(ctx context.Context, cfg *ipn.ServeConf return nil } +func (m *fakeLocalClient) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { + return nil, nil, nil +} + func TestHasHTTPSEndpoint(t *testing.T) { tests := []struct { name string diff --git a/cmd/containerboot/settings.go b/cmd/containerboot/settings.go index 0da18e52c..c62db5340 100644 --- a/cmd/containerboot/settings.go +++ b/cmd/containerboot/settings.go @@ -74,6 +74,12 @@ type settings struct { HealthCheckEnabled bool DebugAddrPort string EgressProxiesCfgPath string + // CertShareMode is set for Kubernetes Pods running cert share mode. + // Possible values are empty (containerboot doesn't run any certs + // logic), 'ro' (for Pods that shold never attempt to issue/renew + // certs) and 'rw' for Pods that should manage the TLS certs shared + // amongst the replicas. + CertShareMode string } func configFromEnv() (*settings, error) { @@ -128,6 +134,17 @@ func configFromEnv() (*settings, error) { cfg.PodIPv6 = parsed.String() } } + // If cert share is enabled, set the replica as read or write. Only 0th + // replica should be able to write. + isInCertShareMode := defaultBool("TS_EXPERIMENTAL_CERT_SHARE", false) + if isInCertShareMode { + cfg.CertShareMode = "ro" + podName := os.Getenv("POD_NAME") + if strings.HasSuffix(podName, "-0") { + cfg.CertShareMode = "rw" + } + } + if err := cfg.validate(); err != nil { return nil, fmt.Errorf("invalid configuration: %v", err) } diff --git a/cmd/containerboot/tailscaled.go b/cmd/containerboot/tailscaled.go index 01ee96d3a..654b34757 100644 --- a/cmd/containerboot/tailscaled.go +++ b/cmd/containerboot/tailscaled.go @@ -33,6 +33,9 @@ func startTailscaled(ctx context.Context, cfg *settings) (*local.Client, *os.Pro cmd.SysProcAttr = &syscall.SysProcAttr{ Setpgid: true, } + if cfg.CertShareMode != "" { + cmd.Env = append(os.Environ(), "TS_CERT_SHARE_MODE="+cfg.CertShareMode) + } log.Printf("Starting tailscaled") if err := cmd.Start(); err != nil { return nil, nil, fmt.Errorf("starting tailscaled failed: %v", err)