diff --git a/cmd/proxy-test-server/proxy-test-server.go b/cmd/proxy-test-server/proxy-test-server.go new file mode 100644 index 000000000..9f8c94a38 --- /dev/null +++ b/cmd/proxy-test-server/proxy-test-server.go @@ -0,0 +1,81 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The proxy-test-server command is a simple HTTP proxy server for testing +// Tailscale's client proxy functionality. +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "net/http" + "os" + "strings" + + "golang.org/x/crypto/acme/autocert" + "tailscale.com/net/connectproxy" + "tailscale.com/tempfork/acme" +) + +var ( + listen = flag.String("listen", ":8080", "Address to listen on for HTTPS proxy requests") + hostname = flag.String("hostname", "localhost", "Hostname for the proxy server") + tailscaleOnly = flag.Bool("tailscale-only", true, "Restrict proxy to Tailscale targets only") + extraAllowedHosts = flag.String("allow-hosts", "", "Comma-separated list of allowed target hosts to additionally allow if --tailscale-only is true") +) + +func main() { + flag.Parse() + + am := &autocert.Manager{ + HostPolicy: autocert.HostWhitelist(*hostname), + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(os.ExpandEnv("$HOME/.cache/autocert/proxy-test-server")), + } + var allowTarget func(hostPort string) error + if *tailscaleOnly { + allowTarget = func(hostPort string) error { + host, port, err := net.SplitHostPort(hostPort) + if err != nil { + return fmt.Errorf("invalid target %q: %v", hostPort, err) + } + if port != "443" { + return fmt.Errorf("target %q must use port 443", hostPort) + } + for allowed := range strings.SplitSeq(*extraAllowedHosts, ",") { + if host == allowed { + return nil // explicitly allowed target + } + } + if !strings.HasSuffix(host, ".tailscale.com") { + return fmt.Errorf("target %q is not a Tailscale host", hostPort) + } + return nil // valid Tailscale target + } + } + + go func() { + if err := http.ListenAndServe(":http", am.HTTPHandler(nil)); err != nil { + log.Fatalf("autocert HTTP server failed: %v", err) + } + }() + hs := &http.Server{ + Addr: *listen, + Handler: &connectproxy.Handler{ + Check: allowTarget, + Logf: log.Printf, + }, + TLSConfig: &tls.Config{ + GetCertificate: am.GetCertificate, + NextProtos: []string{ + "http/1.1", // enable HTTP/2 + acme.ALPNProto, // enable tls-alpn ACME challenges + }, + }, + } + log.Printf("Starting proxy-test-server on %s (hostname: %q)\n", *listen, *hostname) + log.Fatal(hs.ListenAndServeTLS("", "")) // cert and key are provided by autocert +} diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index f8882a4e7..1107f76a4 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -4,13 +4,35 @@ package controlclient import ( + "context" + "crypto/tls" "errors" + "flag" "fmt" "io" + "net" + "net/http" + "net/netip" + "net/url" "reflect" "slices" + "sync/atomic" "testing" + "time" + "tailscale.com/control/controlknobs" + "tailscale.com/health" + "tailscale.com/net/bakedroots" + "tailscale.com/net/connectproxy" + "tailscale.com/net/netmon" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/tstest/tlstest" + "tailscale.com/tstime" + "tailscale.com/types/key" + "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/persist" ) @@ -188,3 +210,206 @@ func isRetryableErrorForTest(err error) bool { } return false } + +var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests") + +func TestDirectProxyManual(t *testing.T) { + if !*liveNetworkTest { + t.Skip("skipping without --live-network-test") + } + + dialer := &tsdial.Dialer{} + dialer.SetNetMon(netmon.NewStatic()) + + opts := Options{ + Persist: persist.Persist{}, + GetMachinePrivateKey: func() (key.MachinePrivate, error) { + return key.NewMachine(), nil + }, + ServerURL: "https://controlplane.tailscale.com", + Clock: tstime.StdClock{}, + Hostinfo: &tailcfg.Hostinfo{ + BackendLogID: "test-backend-log-id", + }, + DiscoPublicKey: key.NewDisco().Public(), + Logf: t.Logf, + HealthTracker: &health.Tracker{}, + PopBrowserURL: func(url string) { + t.Logf("PopBrowserURL: %q", url) + }, + Dialer: dialer, + ControlKnobs: &controlknobs.Knobs{}, + } + d, err := NewDirect(opts) + if err != nil { + t.Fatalf("NewDirect: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + url, err := d.TryLogin(ctx, LoginEphemeral) + if err != nil { + t.Fatalf("TryLogin: %v", err) + } + t.Logf("URL: %q", url) +} + +func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) } + +// TestTLSWithProxy verifies we can connect to the control plane via +// an HTTPS proxy. +func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) } + +func testHTTPS(t *testing.T, withProxy bool) { + bakedroots.ResetForTest(t, tlstest.TestRootCA()) + + controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlaneKeyPair.ServerTLSConfig()) + if err != nil { + t.Fatal(err) + } + defer controlLn.Close() + + proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServerKeyPair.ServerTLSConfig()) + if err != nil { + t.Fatal(err) + } + defer proxyLn.Close() + + const requiredAuthKey = "hunter2" + const someUsername = "testuser" + const somePassword = "testpass" + + testControl := &testcontrol.Server{ + Logf: tstest.WhileTestRunningLogger(t), + RequireAuthKey: requiredAuthKey, + } + controlSrv := &http.Server{ + Handler: testControl, + ErrorLog: logger.StdLogger(t.Logf), + } + go controlSrv.Serve(controlLn) + + const fakeControlIP = "1.2.3.4" + const fakeProxyIP = "5.6.7.8" + + dialer := &tsdial.Dialer{} + dialer.SetNetMon(netmon.NewStatic()) + dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err) + } + var d net.Dialer + if host == fakeControlIP { + return d.DialContext(ctx, network, controlLn.Addr().String()) + } + if host == fakeProxyIP { + return d.DialContext(ctx, network, proxyLn.Addr().String()) + } + return nil, fmt.Errorf("unexpected dial to %q", addr) + }) + + opts := Options{ + Persist: persist.Persist{}, + GetMachinePrivateKey: func() (key.MachinePrivate, error) { + return key.NewMachine(), nil + }, + AuthKey: requiredAuthKey, + ServerURL: "https://controlplane.tstest", + Clock: tstime.StdClock{}, + Hostinfo: &tailcfg.Hostinfo{ + BackendLogID: "test-backend-log-id", + }, + DiscoPublicKey: key.NewDisco().Public(), + Logf: t.Logf, + HealthTracker: &health.Tracker{}, + PopBrowserURL: func(url string) { + t.Logf("PopBrowserURL: %q", url) + }, + Dialer: dialer, + } + d, err := NewDirect(opts) + if err != nil { + t.Fatalf("NewDirect: %v", err) + } + + d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) { + switch host { + case "controlplane.tstest": + return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil + case "proxy.tstest": + if !withProxy { + t.Errorf("unexpected DNS lookup for %q with proxy disabled", host) + return nil, fmt.Errorf("unexpected DNS lookup for %q", host) + } + return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil + } + t.Errorf("unexpected DNS query for %q", host) + return []netip.Addr{}, nil + } + + var proxyReqs atomic.Int64 + if withProxy { + d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) { + t.Logf("using proxy for %q", req.URL) + u := &url.URL{ + Scheme: "https", + Host: "proxy.tstest:443", + User: url.UserPassword(someUsername, somePassword), + } + return u, nil + } + + connectProxy := &http.Server{ + Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs), + } + go connectProxy.Serve(proxyLn) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + url, err := d.TryLogin(ctx, LoginEphemeral) + if err != nil { + t.Fatalf("TryLogin: %v", err) + } + if url != "" { + t.Errorf("got URL %q, want empty", url) + } + + if withProxy { + if got, want := proxyReqs.Load(), int64(1); got != want { + t.Errorf("proxy CONNECT requests = %d; want %d", got, want) + } + } +} + +func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI != target { + t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target) + http.Error(w, "bad target", http.StatusBadRequest) + return + } + + r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy. + user, pass, ok := r.BasicAuth() + if !ok || user != "testuser" || pass != "testpass" { + t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass") + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + (&connectproxy.Handler{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + c, err := d.DialContext(ctx, network, backendAddrPort) + if err == nil { + reqs.Add(1) + } + return c, err + }, + Logf: t.Logf, + }).ServeHTTP(w, r) + }) +} diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 2d6dc6e36..4c9b04ce9 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -16,7 +16,6 @@ import ( "net" "net/http" "net/netip" - "net/url" "os" "reflect" "runtime" @@ -240,10 +239,6 @@ func NewDirect(opts Options) (*Direct, error) { opts.ControlKnobs = &controlknobs.Knobs{} } opts.ServerURL = strings.TrimRight(opts.ServerURL, "/") - serverURL, err := url.Parse(opts.ServerURL) - if err != nil { - return nil, err - } if opts.Clock == nil { opts.Clock = tstime.StdClock{} } @@ -273,7 +268,7 @@ func NewDirect(opts Options) (*Direct, error) { tr := http.DefaultTransport.(*http.Transport).Clone() tr.Proxy = tshttpproxy.ProxyFromEnvironment tshttpproxy.SetTransportGetProxyConnectHeader(tr) - tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), opts.HealthTracker, tr.TLSClientConfig) + tr.TLSClientConfig = tlsdial.Config(opts.HealthTracker, tr.TLSClientConfig) var dialFunc netx.DialFunc dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial) tr.DialContext = dnscache.Dialer(dialFunc, dnsCache) diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 869bcb599..1bb60d672 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -534,7 +534,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad // Disable HTTP2, since h2 can't do protocol switching. tr.TLSClientConfig.NextProtos = []string{} tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} - tr.TLSClientConfig = tlsdial.Config(a.Hostname, a.HealthTracker, tr.TLSClientConfig) + tr.TLSClientConfig = tlsdial.Config(a.HealthTracker, tr.TLSClientConfig) if !tr.TLSClientConfig.InsecureSkipVerify { panic("unexpected") // should be set by tlsdial.Config } diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 8c42e9070..7385f0ad1 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -647,12 +647,13 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C } func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn { - tlsConf := tlsdial.Config(c.tlsServerName(node), c.HealthTracker, c.TLSConfig) + tlsConf := tlsdial.Config(c.HealthTracker, c.TLSConfig) if node != nil { if node.InsecureForTests { tlsConf.InsecureSkipVerify = true tlsConf.VerifyConnection = nil } + tlsConf.ServerName = c.tlsServerName(node) if node.CertName != "" { if suf, ok := strings.CutPrefix(node.CertName, "sha256-raw:"); ok { tlsdial.SetConfigExpectedCertHash(tlsConf, suf) diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index 252549660..7f0a7e333 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -7,10 +7,14 @@ import ( "bytes" "context" "crypto/tls" + "encoding/json" + "flag" "fmt" + "maps" "net" "net/http" "net/http/httptest" + "slices" "strings" "sync" "testing" @@ -19,6 +23,7 @@ import ( "tailscale.com/derp" "tailscale.com/net/netmon" "tailscale.com/net/netx" + "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -556,3 +561,32 @@ func TestNotifyError(t *testing.T) { t.Fatalf("context done before receiving error: %v", ctx.Err()) } } + +var liveNetworkTest = flag.Bool("live-net-tests", false, "run live network tests") + +func TestManualDial(t *testing.T) { + if !*liveNetworkTest { + t.Skip("skipping live network test without --live-net-tests") + } + dm := &tailcfg.DERPMap{} + res, err := http.Get("https://controlplane.tailscale.com/derpmap/default") + if err != nil { + t.Fatalf("fetching DERPMap: %v", err) + } + defer res.Body.Close() + if err := json.NewDecoder(res.Body).Decode(dm); err != nil { + t.Fatalf("decoding DERPMap: %v", err) + } + + region := slices.Sorted(maps.Keys(dm.Regions))[0] + + netMon := netmon.NewStatic() + rc := NewRegionClient(key.NewNode(), t.Logf, netMon, func() *tailcfg.DERPRegion { + return dm.Regions[region] + }) + defer rc.Close() + + if err := rc.Connect(context.Background()); err != nil { + t.Fatalf("rc.Connect: %v", err) + } +} diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index fc259a417..b84528d7b 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -9,7 +9,6 @@ package logpolicy import ( "bufio" "bytes" - "cmp" "context" "crypto/tls" "encoding/json" @@ -911,8 +910,7 @@ func (opts TransportOptions) New() http.RoundTripper { tr.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{} } - host := cmp.Or(opts.Host, logtail.DefaultHost) - tr.TLSClientConfig = tlsdial.Config(host, opts.Health, tr.TLSClientConfig) + tr.TLSClientConfig = tlsdial.Config(opts.Health, tr.TLSClientConfig) // Force TLS 1.3 since we know log.tailscale.com supports it. tr.TLSClientConfig.MinVersion = tls.VersionTLS13 diff --git a/net/bakedroots/bakedroots.go b/net/bakedroots/bakedroots.go index 42e70c0dd..8787b4a6d 100644 --- a/net/bakedroots/bakedroots.go +++ b/net/bakedroots/bakedroots.go @@ -7,6 +7,7 @@ package bakedroots import ( "crypto/x509" + "fmt" "sync" "tailscale.com/util/testenv" @@ -14,7 +15,7 @@ import ( // Get returns the baked-in roots. // -// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 root. +// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 & X2 roots. func Get() *x509.CertPool { roots.once.Do(func() { roots.parsePEM(append( @@ -56,7 +57,7 @@ type rootsOnce struct { func (r *rootsOnce) parsePEM(caPEM []byte) { p := x509.NewCertPool() if !p.AppendCertsFromPEM(caPEM) { - panic("bogus PEM") + panic(fmt.Sprintf("bogus PEM: %q", caPEM)) } r.p = p } diff --git a/net/connectproxy/connectproxy.go b/net/connectproxy/connectproxy.go new file mode 100644 index 000000000..4bf687502 --- /dev/null +++ b/net/connectproxy/connectproxy.go @@ -0,0 +1,93 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package connectproxy contains some CONNECT proxy code. +package connectproxy + +import ( + "context" + "io" + "log" + "net" + "net/http" + "time" + + "tailscale.com/net/netx" + "tailscale.com/types/logger" +) + +// Handler is an HTTP CONNECT proxy handler. +type Handler struct { + // Dial, if non-nil, is an alternate dialer to use + // instead of the default dialer. + Dial netx.DialFunc + + // Logf, if non-nil, is an alterate logger to + // use instead of log.Printf. + Logf logger.Logf + + // Check, if non-nil, validates the CONNECT target. + Check func(hostPort string) error +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if r.Method != "CONNECT" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + dial := h.Dial + if dial == nil { + var d net.Dialer + dial = d.DialContext + } + logf := h.Logf + if logf == nil { + logf = log.Printf + } + + hostPort := r.RequestURI + if h.Check != nil { + if err := h.Check(hostPort); err != nil { + logf("CONNECT target %q not allowed: %v", hostPort, err) + http.Error(w, "Invalid CONNECT target", http.StatusForbidden) + return + } + } + + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + back, err := dial(ctx, "tcp", hostPort) + if err != nil { + logf("error CONNECT dialing %v: %v", hostPort, err) + http.Error(w, "Connect failure", http.StatusBadGateway) + return + } + defer back.Close() + + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "CONNECT hijack unavailable", http.StatusInternalServerError) + return + } + c, br, err := hj.Hijack() + if err != nil { + logf("CONNECT hijack: %v", err) + return + } + defer c.Close() + + io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n") + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(c, back) + errc <- err + }() + go func() { + _, err := io.Copy(back, br) + errc <- err + }() + <-errc +} diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 96550cbb1..d60e92f0b 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -24,6 +24,7 @@ import ( "tailscale.com/util/cloudenv" "tailscale.com/util/singleflight" "tailscale.com/util/slicesx" + "tailscale.com/util/testenv" ) var zaddr netip.Addr @@ -63,6 +64,10 @@ type Resolver struct { // If nil, net.DefaultResolver is used. Forward *net.Resolver + // LookupIPForTest, if non-nil and in tests, handles requests instead + // of the usual mechanisms. + LookupIPForTest func(ctx context.Context, host string) ([]netip.Addr, error) + // LookupIPFallback optionally provides a backup DNS mechanism // to use if Forward returns an error or no results. LookupIPFallback func(ctx context.Context, host string) ([]netip.Addr, error) @@ -284,7 +289,13 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (ip, ip6 netip.Add lookupCtx, lookupCancel := context.WithTimeout(ctx, r.lookupTimeoutForHost(host)) defer lookupCancel() - ips, err := r.fwd().LookupNetIP(lookupCtx, "ip", host) + + var ips []netip.Addr + if r.LookupIPForTest != nil && testenv.InTest() { + ips, err = r.LookupIPForTest(ctx, host) + } else { + ips, err = r.fwd().LookupNetIP(lookupCtx, "ip", host) + } if err != nil || len(ips) == 0 { if resolver, ok := r.cloudHostResolver(); ok { r.dlogf("resolving %q via cloud resolver", host) diff --git a/net/dnsfallback/dnsfallback.go b/net/dnsfallback/dnsfallback.go index 4c5d5fa2f..8e53c3b29 100644 --- a/net/dnsfallback/dnsfallback.go +++ b/net/dnsfallback/dnsfallback.go @@ -286,7 +286,7 @@ func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443")) } - tr.TLSClientConfig = tlsdial.Config(serverName, ht, tr.TLSClientConfig) + tr.TLSClientConfig = tlsdial.Config(ht, tr.TLSClientConfig) c := &http.Client{Transport: tr} req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+"/bootstrap-dns?q="+url.QueryEscape(queryName), nil) if err != nil { diff --git a/net/tlsdial/tlsdial.go b/net/tlsdial/tlsdial.go index 1bd2450aa..80f3bfc06 100644 --- a/net/tlsdial/tlsdial.go +++ b/net/tlsdial/tlsdial.go @@ -59,18 +59,26 @@ var mitmBlockWarnable = health.Register(&health.Warnable{ ImpactsConnectivity: true, }) -// Config returns a tls.Config for connecting to a server. +// Config returns a tls.Config for connecting to a server that +// uses system roots for validation but, if those fail, also tries +// the baked-in LetsEncrypt roots as a fallback validation method. +// // If base is non-nil, it's cloned as the base config before // being configured and returned. // If ht is non-nil, it's used to report health errors. -func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { +func Config(ht *health.Tracker, base *tls.Config) *tls.Config { var conf *tls.Config if base == nil { conf = new(tls.Config) } else { conf = base.Clone() } - conf.ServerName = host + + // Note: we do NOT set conf.ServerName here (as we accidentally did + // previously), as this path is also used when dialing an HTTPS proxy server + // (through which we'll send a CONNECT request to get a TCP connection to do + // the real TCP connection) because host is the ultimate hostname, but this + // tls.Config is used for both the proxy and the ultimate target. if n := sslKeyLogFile; n != "" { f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) @@ -93,7 +101,9 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { // (with the baked-in fallback root) in the VerifyConnection hook. conf.InsecureSkipVerify = true conf.VerifyConnection = func(cs tls.ConnectionState) (retErr error) { - if host == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() { + dialedHost := cs.ServerName + + if dialedHost == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() { // Allow log.tailscale.com TLS MITM for integration tests when // the client's running within a NATLab VM. return nil @@ -116,7 +126,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { // Show a dedicated warning. m, ok := blockblame.VerifyCertificate(cert) if ok { - log.Printf("tlsdial: server cert for %q looks like %q equipment (could be blocking Tailscale)", host, m.Name) + log.Printf("tlsdial: server cert seen while dialing %q looks like %q equipment (could be blocking Tailscale)", dialedHost, m.Name) ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name}) } else { ht.SetHealthy(mitmBlockWarnable) @@ -135,7 +145,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { ht.SetTLSConnectionError(cs.ServerName, nil) if selfSignedIssuer != "" { // Log the self-signed issuer, but don't treat it as an error. - log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", host, selfSignedIssuer) + log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", dialedHost, selfSignedIssuer) } } }() @@ -144,7 +154,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { // First try doing x509 verification with the system's // root CA pool. opts := x509.VerifyOptions{ - DNSName: cs.ServerName, + DNSName: dialedHost, Intermediates: x509.NewCertPool(), } for _, cert := range cs.PeerCertificates[1:] { @@ -152,7 +162,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { } _, errSys := cs.PeerCertificates[0].Verify(opts) if debug() { - log.Printf("tlsdial(sys %q): %v", host, errSys) + log.Printf("tlsdial(sys %q): %v", dialedHost, errSys) } // Always verify with our baked-in Let's Encrypt certificate, @@ -161,13 +171,11 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { opts.Roots = bakedroots.Get() _, bakedErr := cs.PeerCertificates[0].Verify(opts) if debug() { - log.Printf("tlsdial(bake %q): %v", host, bakedErr) + log.Printf("tlsdial(bake %q): %v", dialedHost, bakedErr) } else if bakedErr != nil { - if _, loaded := tlsdialWarningPrinted.LoadOrStore(host, true); !loaded { - if errSys == nil { - log.Printf("tlsdial: warning: server cert for %q is not a Let's Encrypt cert", host) - } else { - log.Printf("tlsdial: error: server cert for %q failed to verify and is not a Let's Encrypt cert", host) + if _, loaded := tlsdialWarningPrinted.LoadOrStore(dialedHost, true); !loaded { + if errSys != nil { + log.Printf("tlsdial: error: server cert for %q failed both system roots & Let's Encrypt root validation", dialedHost) } } } @@ -202,9 +210,6 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { c.ServerName = certDNSName return } - if c.VerifyPeerCertificate != nil { - panic("refusing to override tls.Config.VerifyPeerCertificate") - } // Set InsecureSkipVerify to prevent crypto/tls from doing its // own cert verification, but do the same work that it'd do // (but using certDNSName) in the VerifyPeerCertificate hook. @@ -257,29 +262,30 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) { if c.VerifyPeerCertificate != nil { panic("refusing to override tls.Config.VerifyPeerCertificate") } + // Set InsecureSkipVerify to prevent crypto/tls from doing its // own cert verification, but do the same work that it'd do - // (but using certDNSName) in the VerifyPeerCertificate hook. + // (but using certDNSName) in the VerifyConnection hook. c.InsecureSkipVerify = true - c.VerifyConnection = nil - c.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + + c.VerifyConnection = func(cs tls.ConnectionState) error { + dialedHost := cs.ServerName var sawGoodCert bool - for _, rawCert := range rawCerts { - cert, err := x509.ParseCertificate(rawCert) - if err != nil { - return fmt.Errorf("ParseCertificate: %w", err) - } + + for _, cert := range cs.PeerCertificates { if strings.HasPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix) { continue } if sawGoodCert { return errors.New("unexpected multiple certs presented") } - if fmt.Sprintf("%02x", sha256.Sum256(rawCert)) != wantFullCertSHA256Hex { + if fmt.Sprintf("%02x", sha256.Sum256(cert.Raw)) != wantFullCertSHA256Hex { return fmt.Errorf("cert hash does not match expected cert hash") } - if err := cert.VerifyHostname(c.ServerName); err != nil { - return fmt.Errorf("cert does not match server name %q: %w", c.ServerName, err) + if dialedHost != "" { // it's empty when dialing a derper by IP with no hostname + if err := cert.VerifyHostname(dialedHost); err != nil { + return fmt.Errorf("cert does not match server name %q: %w", dialedHost, err) + } } now := time.Now() if now.After(cert.NotAfter) { @@ -302,12 +308,8 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) { func NewTransport() *http.Transport { return &http.Transport{ DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } var d tls.Dialer - d.Config = Config(host, nil, nil) + d.Config = Config(nil, nil) return d.DialContext(ctx, network, addr) }, } diff --git a/net/tlsdial/tlsdial_test.go b/net/tlsdial/tlsdial_test.go index 6723b82e0..e2c4cdd4f 100644 --- a/net/tlsdial/tlsdial_test.go +++ b/net/tlsdial/tlsdial_test.go @@ -86,7 +86,7 @@ func TestFallbackRootWorks(t *testing.T) { DisableKeepAlives: true, // for test cleanup ease } ht := new(health.Tracker) - tr.TLSClientConfig = Config("tlsdial.test", ht, tr.TLSClientConfig) + tr.TLSClientConfig = Config(ht, tr.TLSClientConfig) c := &http.Client{Transport: tr} ctr0 := atomic.LoadInt32(&counterFallbackOK) diff --git a/tstest/tlstest/testdata/controlplane.tstest.key b/tstest/tlstest/testdata/controlplane.tstest.key new file mode 100644 index 000000000..dbe5ede34 --- /dev/null +++ b/tstest/tlstest/testdata/controlplane.tstest.key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIHcxOQNVyqvBSSlu7c93QW6OsyccjL+R1evW4acd32MWoAoGCCqGSM49 +AwEHoUQDQgAEIOY5/CQ8CMuKYPLf+r6OEneqfzQ5RfgPnLdkL22qhm8xb69ZCXxz +UecawU0KEDfHLYbUYXSuhAFxxuPh9I3x5Q== +-----END EC PRIVATE KEY----- diff --git a/tstest/tlstest/testdata/proxy.tstest.key b/tstest/tlstest/testdata/proxy.tstest.key new file mode 100644 index 000000000..067279089 --- /dev/null +++ b/tstest/tlstest/testdata/proxy.tstest.key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEING1XBDWFXQjqBmLjhp20hXOf2rk/I0N6W7muv9RVvk3oAoGCCqGSM49 +AwEHoUQDQgAE8lxnEEeLqYikwmXbXSsIQSw20R0oLA831s960KQZEgt0P9SbWcJc +QTk98rdfYT/QDdHn157Oh4FPcDtxmdQ4vw== +-----END EC PRIVATE KEY----- diff --git a/tstest/tlstest/testdata/root-ca.key b/tstest/tlstest/testdata/root-ca.key new file mode 100644 index 000000000..ece23ddf9 --- /dev/null +++ b/tstest/tlstest/testdata/root-ca.key @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIMl3xjqt1dnXBpYJSEqevirAcnSJ79I2tucdRazlrDG9oAoGCCqGSM49 +AwEHoUQDQgAEQ/+Jme+16hgO7TtPSIFHVV0Yt969ltVlARVcNUZmWc0upQaq7uiJ +Aur5KtzwxU3YI4bhNK0593OK2TLvEEWIdw== +-----END EC PRIVATE KEY----- diff --git a/tstest/tlstest/tlstest.go b/tstest/tlstest/tlstest.go new file mode 100644 index 000000000..f65c261e8 --- /dev/null +++ b/tstest/tlstest/tlstest.go @@ -0,0 +1,167 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tlstest contains code to help test Tailscale's client proxy support. +package tlstest + +import ( + "bytes" + "crypto/ecdsa" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + _ "embed" + "encoding/pem" + "fmt" + "math/big" + "sync" + "time" +) + +// Some baked-in ECDSA keys to speed up tests, not having to burn CPU to +// generate them each time. We only make the certs (which have expiry times) +// at runtime. +// +// They were made with: +// +// openssl ecparam -name prime256v1 -genkey -noout -out root-ca.key +var ( + //go:embed testdata/root-ca.key + rootCAKeyPEM []byte + + // TestProxyServerKey is the PEM private key for [TestProxyServerCert]. + // + //go:embed testdata/proxy.tstest.key + TestProxyServerKey []byte + + // TestControlPlaneKey is the PEM private key for [TestControlPlaneCert]. + // + //go:embed testdata/controlplane.tstest.key + TestControlPlaneKey []byte +) + +// TestRootCA returns a self-signed ECDSA root CA certificate (as PEM) for +// testing purposes. +func TestRootCA() []byte { + return bytes.Clone(testRootCAOncer()) +} + +var testRootCAOncer = sync.OnceValue(func() []byte { + key := rootCAKey() + now := time.Now().Add(-time.Hour) + tpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "Tailscale Unit Test ECDSA Root", + Organization: []string{"Tailscale Test Org"}, + }, + NotBefore: now, + NotAfter: now.AddDate(5, 0, 0), + + IsCA: true, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + SubjectKeyId: mustSKID(&key.PublicKey), + } + + der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key) + if err != nil { + panic(err) + } + return pemCert(der) +}) + +func pemCert(der []byte) []byte { + var buf bytes.Buffer + if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil { + panic(fmt.Sprintf("failed to encode PEM: %v", err)) + } + return buf.Bytes() +} + +var rootCAKey = sync.OnceValue(func() *ecdsa.PrivateKey { + return mustParsePEM(rootCAKeyPEM, x509.ParseECPrivateKey) +}) + +func mustParsePEM[T any](pemBytes []byte, parse func([]byte) (T, error)) T { + block, rest := pem.Decode(pemBytes) + if block == nil || len(rest) > 0 { + panic("invalid PEM") + } + v, err := parse(block.Bytes) + if err != nil { + panic(fmt.Sprintf("invalid PEM: %v", err)) + } + return v +} + +// KeyPair is a simple struct to hold a certificate and its private key. +type KeyPair struct { + Domain string + KeyPEM []byte // PEM-encoded private key +} + +// ServerTLSConfig returns a TLS configuration suitable for a server +// using the KeyPair's certificate and private key. +func (p KeyPair) ServerTLSConfig() *tls.Config { + cert, err := tls.X509KeyPair(p.CertPEM(), p.KeyPEM) + if err != nil { + panic("invalid TLS key pair: " + err.Error()) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + } +} + +// ProxyServerKeyPair is a KeyPair for a test control plane server +// with domain name "proxy.tstest". +var ProxyServerKeyPair = KeyPair{ + Domain: "proxy.tstest", + KeyPEM: TestProxyServerKey, +} + +// ControlPlaneKeyPair is a KeyPair for a test control plane server +// with domain name "controlplane.tstest". +var ControlPlaneKeyPair = KeyPair{ + Domain: "controlplane.tstest", + KeyPEM: TestControlPlaneKey, +} + +func (p KeyPair) CertPEM() []byte { + caCert := mustParsePEM(TestRootCA(), x509.ParseCertificate) + caPriv := mustParsePEM(rootCAKeyPEM, x509.ParseECPrivateKey) + leafKey := mustParsePEM(p.KeyPEM, x509.ParseECPrivateKey) + + serial, err := rand.Int(rand.Reader, big.NewInt(0).Lsh(big.NewInt(1), 128)) + if err != nil { + panic(err) + } + + now := time.Now().Add(-time.Hour) + tpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: p.Domain}, + NotBefore: now, + NotAfter: now.AddDate(2, 0, 0), + + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{p.Domain}, + } + + der, err := x509.CreateCertificate(rand.Reader, tpl, caCert, &leafKey.PublicKey, caPriv) + if err != nil { + panic(err) + } + return pemCert(der) +} + +func mustSKID(pub *ecdsa.PublicKey) []byte { + skid, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + panic(err) + } + return skid[:20] // same as x509 library +}