net/tlsdial: fix TLS cert validation of HTTPS proxies

If you had HTTPS_PROXY=https://some-valid-cert.example.com running a
CONNECT proxy, we should've been able to do a TLS CONNECT request to
e.g. controlplane.tailscale.com:443 through that, and I'm pretty sure
it used to work, but refactorings and lack of integration tests made
it regress.

It probably regressed when we added the baked-in LetsEncrypt root cert
validation fallback code, which was testing against the wrong hostname
(the ultimate one, not the one which we were being asked to validate)

Fixes #16222

Change-Id: If014e395f830e2f87f056f588edacad5c15e91bc
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2025-06-08 18:51:41 -07:00
committed by Brad Fitzpatrick
parent 4979ce7a94
commit e92eb6b17b
17 changed files with 672 additions and 49 deletions

View File

@@ -4,13 +4,35 @@
package controlclient
import (
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"reflect"
"slices"
"sync/atomic"
"testing"
"time"
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/net/bakedroots"
"tailscale.com/net/connectproxy"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/tstest/tlstest"
"tailscale.com/tstime"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/types/persist"
)
@@ -188,3 +210,206 @@ func isRetryableErrorForTest(err error) bool {
}
return false
}
var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
func TestDirectProxyManual(t *testing.T) {
if !*liveNetworkTest {
t.Skip("skipping without --live-network-test")
}
dialer := &tsdial.Dialer{}
dialer.SetNetMon(netmon.NewStatic())
opts := Options{
Persist: persist.Persist{},
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
return key.NewMachine(), nil
},
ServerURL: "https://controlplane.tailscale.com",
Clock: tstime.StdClock{},
Hostinfo: &tailcfg.Hostinfo{
BackendLogID: "test-backend-log-id",
},
DiscoPublicKey: key.NewDisco().Public(),
Logf: t.Logf,
HealthTracker: &health.Tracker{},
PopBrowserURL: func(url string) {
t.Logf("PopBrowserURL: %q", url)
},
Dialer: dialer,
ControlKnobs: &controlknobs.Knobs{},
}
d, err := NewDirect(opts)
if err != nil {
t.Fatalf("NewDirect: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
url, err := d.TryLogin(ctx, LoginEphemeral)
if err != nil {
t.Fatalf("TryLogin: %v", err)
}
t.Logf("URL: %q", url)
}
func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
// TestTLSWithProxy verifies we can connect to the control plane via
// an HTTPS proxy.
func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
func testHTTPS(t *testing.T, withProxy bool) {
bakedroots.ResetForTest(t, tlstest.TestRootCA())
controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlaneKeyPair.ServerTLSConfig())
if err != nil {
t.Fatal(err)
}
defer controlLn.Close()
proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServerKeyPair.ServerTLSConfig())
if err != nil {
t.Fatal(err)
}
defer proxyLn.Close()
const requiredAuthKey = "hunter2"
const someUsername = "testuser"
const somePassword = "testpass"
testControl := &testcontrol.Server{
Logf: tstest.WhileTestRunningLogger(t),
RequireAuthKey: requiredAuthKey,
}
controlSrv := &http.Server{
Handler: testControl,
ErrorLog: logger.StdLogger(t.Logf),
}
go controlSrv.Serve(controlLn)
const fakeControlIP = "1.2.3.4"
const fakeProxyIP = "5.6.7.8"
dialer := &tsdial.Dialer{}
dialer.SetNetMon(netmon.NewStatic())
dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
}
var d net.Dialer
if host == fakeControlIP {
return d.DialContext(ctx, network, controlLn.Addr().String())
}
if host == fakeProxyIP {
return d.DialContext(ctx, network, proxyLn.Addr().String())
}
return nil, fmt.Errorf("unexpected dial to %q", addr)
})
opts := Options{
Persist: persist.Persist{},
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
return key.NewMachine(), nil
},
AuthKey: requiredAuthKey,
ServerURL: "https://controlplane.tstest",
Clock: tstime.StdClock{},
Hostinfo: &tailcfg.Hostinfo{
BackendLogID: "test-backend-log-id",
},
DiscoPublicKey: key.NewDisco().Public(),
Logf: t.Logf,
HealthTracker: &health.Tracker{},
PopBrowserURL: func(url string) {
t.Logf("PopBrowserURL: %q", url)
},
Dialer: dialer,
}
d, err := NewDirect(opts)
if err != nil {
t.Fatalf("NewDirect: %v", err)
}
d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
switch host {
case "controlplane.tstest":
return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
case "proxy.tstest":
if !withProxy {
t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
}
return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
}
t.Errorf("unexpected DNS query for %q", host)
return []netip.Addr{}, nil
}
var proxyReqs atomic.Int64
if withProxy {
d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) {
t.Logf("using proxy for %q", req.URL)
u := &url.URL{
Scheme: "https",
Host: "proxy.tstest:443",
User: url.UserPassword(someUsername, somePassword),
}
return u, nil
}
connectProxy := &http.Server{
Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
}
go connectProxy.Serve(proxyLn)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
url, err := d.TryLogin(ctx, LoginEphemeral)
if err != nil {
t.Fatalf("TryLogin: %v", err)
}
if url != "" {
t.Errorf("got URL %q, want empty", url)
}
if withProxy {
if got, want := proxyReqs.Load(), int64(1); got != want {
t.Errorf("proxy CONNECT requests = %d; want %d", got, want)
}
}
}
func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI != target {
t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target)
http.Error(w, "bad target", http.StatusBadRequest)
return
}
r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy.
user, pass, ok := r.BasicAuth()
if !ok || user != "testuser" || pass != "testpass" {
t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass")
http.Error(w, "bad auth", http.StatusUnauthorized)
return
}
(&connectproxy.Handler{
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
var d net.Dialer
c, err := d.DialContext(ctx, network, backendAddrPort)
if err == nil {
reqs.Add(1)
}
return c, err
},
Logf: t.Logf,
}).ServeHTTP(w, r)
})
}

View File

@@ -16,7 +16,6 @@ import (
"net"
"net/http"
"net/netip"
"net/url"
"os"
"reflect"
"runtime"
@@ -240,10 +239,6 @@ func NewDirect(opts Options) (*Direct, error) {
opts.ControlKnobs = &controlknobs.Knobs{}
}
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
serverURL, err := url.Parse(opts.ServerURL)
if err != nil {
return nil, err
}
if opts.Clock == nil {
opts.Clock = tstime.StdClock{}
}
@@ -273,7 +268,7 @@ func NewDirect(opts Options) (*Direct, error) {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), opts.HealthTracker, tr.TLSClientConfig)
tr.TLSClientConfig = tlsdial.Config(opts.HealthTracker, tr.TLSClientConfig)
var dialFunc netx.DialFunc
dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial)
tr.DialContext = dnscache.Dialer(dialFunc, dnsCache)

View File

@@ -534,7 +534,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad
// Disable HTTP2, since h2 can't do protocol switching.
tr.TLSClientConfig.NextProtos = []string{}
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
tr.TLSClientConfig = tlsdial.Config(a.Hostname, a.HealthTracker, tr.TLSClientConfig)
tr.TLSClientConfig = tlsdial.Config(a.HealthTracker, tr.TLSClientConfig)
if !tr.TLSClientConfig.InsecureSkipVerify {
panic("unexpected") // should be set by tlsdial.Config
}