mirror of
https://github.com/tailscale/tailscale.git
synced 2025-06-20 07:08:40 +00:00
net/tlsdial: fix TLS cert validation of HTTPS proxies
If you had HTTPS_PROXY=https://some-valid-cert.example.com running a CONNECT proxy, we should've been able to do a TLS CONNECT request to e.g. controlplane.tailscale.com:443 through that, and I'm pretty sure it used to work, but refactorings and lack of integration tests made it regress. It probably regressed when we added the baked-in LetsEncrypt root cert validation fallback code, which was testing against the wrong hostname (the ultimate one, not the one which we were being asked to validate) Fixes #16222 Change-Id: If014e395f830e2f87f056f588edacad5c15e91bc Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
4979ce7a94
commit
e92eb6b17b
81
cmd/proxy-test-server/proxy-test-server.go
Normal file
81
cmd/proxy-test-server/proxy-test-server.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// The proxy-test-server command is a simple HTTP proxy server for testing
|
||||||
|
// Tailscale's client proxy functionality.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
"tailscale.com/net/connectproxy"
|
||||||
|
"tailscale.com/tempfork/acme"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
listen = flag.String("listen", ":8080", "Address to listen on for HTTPS proxy requests")
|
||||||
|
hostname = flag.String("hostname", "localhost", "Hostname for the proxy server")
|
||||||
|
tailscaleOnly = flag.Bool("tailscale-only", true, "Restrict proxy to Tailscale targets only")
|
||||||
|
extraAllowedHosts = flag.String("allow-hosts", "", "Comma-separated list of allowed target hosts to additionally allow if --tailscale-only is true")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
am := &autocert.Manager{
|
||||||
|
HostPolicy: autocert.HostWhitelist(*hostname),
|
||||||
|
Prompt: autocert.AcceptTOS,
|
||||||
|
Cache: autocert.DirCache(os.ExpandEnv("$HOME/.cache/autocert/proxy-test-server")),
|
||||||
|
}
|
||||||
|
var allowTarget func(hostPort string) error
|
||||||
|
if *tailscaleOnly {
|
||||||
|
allowTarget = func(hostPort string) error {
|
||||||
|
host, port, err := net.SplitHostPort(hostPort)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid target %q: %v", hostPort, err)
|
||||||
|
}
|
||||||
|
if port != "443" {
|
||||||
|
return fmt.Errorf("target %q must use port 443", hostPort)
|
||||||
|
}
|
||||||
|
for allowed := range strings.SplitSeq(*extraAllowedHosts, ",") {
|
||||||
|
if host == allowed {
|
||||||
|
return nil // explicitly allowed target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(host, ".tailscale.com") {
|
||||||
|
return fmt.Errorf("target %q is not a Tailscale host", hostPort)
|
||||||
|
}
|
||||||
|
return nil // valid Tailscale target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := http.ListenAndServe(":http", am.HTTPHandler(nil)); err != nil {
|
||||||
|
log.Fatalf("autocert HTTP server failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
hs := &http.Server{
|
||||||
|
Addr: *listen,
|
||||||
|
Handler: &connectproxy.Handler{
|
||||||
|
Check: allowTarget,
|
||||||
|
Logf: log.Printf,
|
||||||
|
},
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
GetCertificate: am.GetCertificate,
|
||||||
|
NextProtos: []string{
|
||||||
|
"http/1.1", // enable HTTP/2
|
||||||
|
acme.ALPNProto, // enable tls-alpn ACME challenges
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
log.Printf("Starting proxy-test-server on %s (hostname: %q)\n", *listen, *hostname)
|
||||||
|
log.Fatal(hs.ListenAndServeTLS("", "")) // cert and key are provided by autocert
|
||||||
|
}
|
@ -4,13 +4,35 @@
|
|||||||
package controlclient
|
package controlclient
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"tailscale.com/control/controlknobs"
|
||||||
|
"tailscale.com/health"
|
||||||
|
"tailscale.com/net/bakedroots"
|
||||||
|
"tailscale.com/net/connectproxy"
|
||||||
|
"tailscale.com/net/netmon"
|
||||||
|
"tailscale.com/net/tsdial"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/tstest"
|
||||||
|
"tailscale.com/tstest/integration/testcontrol"
|
||||||
|
"tailscale.com/tstest/tlstest"
|
||||||
|
"tailscale.com/tstime"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
"tailscale.com/types/logger"
|
||||||
"tailscale.com/types/netmap"
|
"tailscale.com/types/netmap"
|
||||||
"tailscale.com/types/persist"
|
"tailscale.com/types/persist"
|
||||||
)
|
)
|
||||||
@ -188,3 +210,206 @@ func isRetryableErrorForTest(err error) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests")
|
||||||
|
|
||||||
|
func TestDirectProxyManual(t *testing.T) {
|
||||||
|
if !*liveNetworkTest {
|
||||||
|
t.Skip("skipping without --live-network-test")
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &tsdial.Dialer{}
|
||||||
|
dialer.SetNetMon(netmon.NewStatic())
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
Persist: persist.Persist{},
|
||||||
|
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
|
||||||
|
return key.NewMachine(), nil
|
||||||
|
},
|
||||||
|
ServerURL: "https://controlplane.tailscale.com",
|
||||||
|
Clock: tstime.StdClock{},
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
BackendLogID: "test-backend-log-id",
|
||||||
|
},
|
||||||
|
DiscoPublicKey: key.NewDisco().Public(),
|
||||||
|
Logf: t.Logf,
|
||||||
|
HealthTracker: &health.Tracker{},
|
||||||
|
PopBrowserURL: func(url string) {
|
||||||
|
t.Logf("PopBrowserURL: %q", url)
|
||||||
|
},
|
||||||
|
Dialer: dialer,
|
||||||
|
ControlKnobs: &controlknobs.Knobs{},
|
||||||
|
}
|
||||||
|
d, err := NewDirect(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDirect: %v", err)
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
url, err := d.TryLogin(ctx, LoginEphemeral)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TryLogin: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("URL: %q", url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) }
|
||||||
|
|
||||||
|
// TestTLSWithProxy verifies we can connect to the control plane via
|
||||||
|
// an HTTPS proxy.
|
||||||
|
func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) }
|
||||||
|
|
||||||
|
func testHTTPS(t *testing.T, withProxy bool) {
|
||||||
|
bakedroots.ResetForTest(t, tlstest.TestRootCA())
|
||||||
|
|
||||||
|
controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlaneKeyPair.ServerTLSConfig())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer controlLn.Close()
|
||||||
|
|
||||||
|
proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServerKeyPair.ServerTLSConfig())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer proxyLn.Close()
|
||||||
|
|
||||||
|
const requiredAuthKey = "hunter2"
|
||||||
|
const someUsername = "testuser"
|
||||||
|
const somePassword = "testpass"
|
||||||
|
|
||||||
|
testControl := &testcontrol.Server{
|
||||||
|
Logf: tstest.WhileTestRunningLogger(t),
|
||||||
|
RequireAuthKey: requiredAuthKey,
|
||||||
|
}
|
||||||
|
controlSrv := &http.Server{
|
||||||
|
Handler: testControl,
|
||||||
|
ErrorLog: logger.StdLogger(t.Logf),
|
||||||
|
}
|
||||||
|
go controlSrv.Serve(controlLn)
|
||||||
|
|
||||||
|
const fakeControlIP = "1.2.3.4"
|
||||||
|
const fakeProxyIP = "5.6.7.8"
|
||||||
|
|
||||||
|
dialer := &tsdial.Dialer{}
|
||||||
|
dialer.SetNetMon(netmon.NewStatic())
|
||||||
|
dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err)
|
||||||
|
}
|
||||||
|
var d net.Dialer
|
||||||
|
if host == fakeControlIP {
|
||||||
|
return d.DialContext(ctx, network, controlLn.Addr().String())
|
||||||
|
}
|
||||||
|
if host == fakeProxyIP {
|
||||||
|
return d.DialContext(ctx, network, proxyLn.Addr().String())
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected dial to %q", addr)
|
||||||
|
})
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
Persist: persist.Persist{},
|
||||||
|
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
|
||||||
|
return key.NewMachine(), nil
|
||||||
|
},
|
||||||
|
AuthKey: requiredAuthKey,
|
||||||
|
ServerURL: "https://controlplane.tstest",
|
||||||
|
Clock: tstime.StdClock{},
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
BackendLogID: "test-backend-log-id",
|
||||||
|
},
|
||||||
|
DiscoPublicKey: key.NewDisco().Public(),
|
||||||
|
Logf: t.Logf,
|
||||||
|
HealthTracker: &health.Tracker{},
|
||||||
|
PopBrowserURL: func(url string) {
|
||||||
|
t.Logf("PopBrowserURL: %q", url)
|
||||||
|
},
|
||||||
|
Dialer: dialer,
|
||||||
|
}
|
||||||
|
d, err := NewDirect(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDirect: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) {
|
||||||
|
switch host {
|
||||||
|
case "controlplane.tstest":
|
||||||
|
return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil
|
||||||
|
case "proxy.tstest":
|
||||||
|
if !withProxy {
|
||||||
|
t.Errorf("unexpected DNS lookup for %q with proxy disabled", host)
|
||||||
|
return nil, fmt.Errorf("unexpected DNS lookup for %q", host)
|
||||||
|
}
|
||||||
|
return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil
|
||||||
|
}
|
||||||
|
t.Errorf("unexpected DNS query for %q", host)
|
||||||
|
return []netip.Addr{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxyReqs atomic.Int64
|
||||||
|
if withProxy {
|
||||||
|
d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) {
|
||||||
|
t.Logf("using proxy for %q", req.URL)
|
||||||
|
u := &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "proxy.tstest:443",
|
||||||
|
User: url.UserPassword(someUsername, somePassword),
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
connectProxy := &http.Server{
|
||||||
|
Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs),
|
||||||
|
}
|
||||||
|
go connectProxy.Serve(proxyLn)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
url, err := d.TryLogin(ctx, LoginEphemeral)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TryLogin: %v", err)
|
||||||
|
}
|
||||||
|
if url != "" {
|
||||||
|
t.Errorf("got URL %q, want empty", url)
|
||||||
|
}
|
||||||
|
|
||||||
|
if withProxy {
|
||||||
|
if got, want := proxyReqs.Load(), int64(1); got != want {
|
||||||
|
t.Errorf("proxy CONNECT requests = %d; want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.RequestURI != target {
|
||||||
|
t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target)
|
||||||
|
http.Error(w, "bad target", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy.
|
||||||
|
user, pass, ok := r.BasicAuth()
|
||||||
|
if !ok || user != "testuser" || pass != "testpass" {
|
||||||
|
t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass")
|
||||||
|
http.Error(w, "bad auth", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
(&connectproxy.Handler{
|
||||||
|
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
var d net.Dialer
|
||||||
|
c, err := d.DialContext(ctx, network, backendAddrPort)
|
||||||
|
if err == nil {
|
||||||
|
reqs.Add(1)
|
||||||
|
}
|
||||||
|
return c, err
|
||||||
|
},
|
||||||
|
Logf: t.Logf,
|
||||||
|
}).ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -16,7 +16,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -240,10 +239,6 @@ func NewDirect(opts Options) (*Direct, error) {
|
|||||||
opts.ControlKnobs = &controlknobs.Knobs{}
|
opts.ControlKnobs = &controlknobs.Knobs{}
|
||||||
}
|
}
|
||||||
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
|
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
|
||||||
serverURL, err := url.Parse(opts.ServerURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if opts.Clock == nil {
|
if opts.Clock == nil {
|
||||||
opts.Clock = tstime.StdClock{}
|
opts.Clock = tstime.StdClock{}
|
||||||
}
|
}
|
||||||
@ -273,7 +268,7 @@ func NewDirect(opts Options) (*Direct, error) {
|
|||||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
tr.Proxy = tshttpproxy.ProxyFromEnvironment
|
tr.Proxy = tshttpproxy.ProxyFromEnvironment
|
||||||
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
|
tshttpproxy.SetTransportGetProxyConnectHeader(tr)
|
||||||
tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), opts.HealthTracker, tr.TLSClientConfig)
|
tr.TLSClientConfig = tlsdial.Config(opts.HealthTracker, tr.TLSClientConfig)
|
||||||
var dialFunc netx.DialFunc
|
var dialFunc netx.DialFunc
|
||||||
dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial)
|
dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial)
|
||||||
tr.DialContext = dnscache.Dialer(dialFunc, dnsCache)
|
tr.DialContext = dnscache.Dialer(dialFunc, dnsCache)
|
||||||
|
@ -534,7 +534,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad
|
|||||||
// Disable HTTP2, since h2 can't do protocol switching.
|
// Disable HTTP2, since h2 can't do protocol switching.
|
||||||
tr.TLSClientConfig.NextProtos = []string{}
|
tr.TLSClientConfig.NextProtos = []string{}
|
||||||
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
|
tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{}
|
||||||
tr.TLSClientConfig = tlsdial.Config(a.Hostname, a.HealthTracker, tr.TLSClientConfig)
|
tr.TLSClientConfig = tlsdial.Config(a.HealthTracker, tr.TLSClientConfig)
|
||||||
if !tr.TLSClientConfig.InsecureSkipVerify {
|
if !tr.TLSClientConfig.InsecureSkipVerify {
|
||||||
panic("unexpected") // should be set by tlsdial.Config
|
panic("unexpected") // should be set by tlsdial.Config
|
||||||
}
|
}
|
||||||
|
@ -647,12 +647,13 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
|
func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn {
|
||||||
tlsConf := tlsdial.Config(c.tlsServerName(node), c.HealthTracker, c.TLSConfig)
|
tlsConf := tlsdial.Config(c.HealthTracker, c.TLSConfig)
|
||||||
if node != nil {
|
if node != nil {
|
||||||
if node.InsecureForTests {
|
if node.InsecureForTests {
|
||||||
tlsConf.InsecureSkipVerify = true
|
tlsConf.InsecureSkipVerify = true
|
||||||
tlsConf.VerifyConnection = nil
|
tlsConf.VerifyConnection = nil
|
||||||
}
|
}
|
||||||
|
tlsConf.ServerName = c.tlsServerName(node)
|
||||||
if node.CertName != "" {
|
if node.CertName != "" {
|
||||||
if suf, ok := strings.CutPrefix(node.CertName, "sha256-raw:"); ok {
|
if suf, ok := strings.CutPrefix(node.CertName, "sha256-raw:"); ok {
|
||||||
tlsdial.SetConfigExpectedCertHash(tlsConf, suf)
|
tlsdial.SetConfigExpectedCertHash(tlsConf, suf)
|
||||||
|
@ -7,10 +7,14 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"maps"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
@ -19,6 +23,7 @@ import (
|
|||||||
"tailscale.com/derp"
|
"tailscale.com/derp"
|
||||||
"tailscale.com/net/netmon"
|
"tailscale.com/net/netmon"
|
||||||
"tailscale.com/net/netx"
|
"tailscale.com/net/netx"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -556,3 +561,32 @@ func TestNotifyError(t *testing.T) {
|
|||||||
t.Fatalf("context done before receiving error: %v", ctx.Err())
|
t.Fatalf("context done before receiving error: %v", ctx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var liveNetworkTest = flag.Bool("live-net-tests", false, "run live network tests")
|
||||||
|
|
||||||
|
func TestManualDial(t *testing.T) {
|
||||||
|
if !*liveNetworkTest {
|
||||||
|
t.Skip("skipping live network test without --live-net-tests")
|
||||||
|
}
|
||||||
|
dm := &tailcfg.DERPMap{}
|
||||||
|
res, err := http.Get("https://controlplane.tailscale.com/derpmap/default")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fetching DERPMap: %v", err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
if err := json.NewDecoder(res.Body).Decode(dm); err != nil {
|
||||||
|
t.Fatalf("decoding DERPMap: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
region := slices.Sorted(maps.Keys(dm.Regions))[0]
|
||||||
|
|
||||||
|
netMon := netmon.NewStatic()
|
||||||
|
rc := NewRegionClient(key.NewNode(), t.Logf, netMon, func() *tailcfg.DERPRegion {
|
||||||
|
return dm.Regions[region]
|
||||||
|
})
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
if err := rc.Connect(context.Background()); err != nil {
|
||||||
|
t.Fatalf("rc.Connect: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -9,7 +9,6 @@ package logpolicy
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -911,8 +910,7 @@ func (opts TransportOptions) New() http.RoundTripper {
|
|||||||
tr.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{}
|
tr.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{}
|
||||||
}
|
}
|
||||||
|
|
||||||
host := cmp.Or(opts.Host, logtail.DefaultHost)
|
tr.TLSClientConfig = tlsdial.Config(opts.Health, tr.TLSClientConfig)
|
||||||
tr.TLSClientConfig = tlsdial.Config(host, opts.Health, tr.TLSClientConfig)
|
|
||||||
// Force TLS 1.3 since we know log.tailscale.com supports it.
|
// Force TLS 1.3 since we know log.tailscale.com supports it.
|
||||||
tr.TLSClientConfig.MinVersion = tls.VersionTLS13
|
tr.TLSClientConfig.MinVersion = tls.VersionTLS13
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ package bakedroots
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"tailscale.com/util/testenv"
|
"tailscale.com/util/testenv"
|
||||||
@ -14,7 +15,7 @@ import (
|
|||||||
|
|
||||||
// Get returns the baked-in roots.
|
// Get returns the baked-in roots.
|
||||||
//
|
//
|
||||||
// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 root.
|
// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 & X2 roots.
|
||||||
func Get() *x509.CertPool {
|
func Get() *x509.CertPool {
|
||||||
roots.once.Do(func() {
|
roots.once.Do(func() {
|
||||||
roots.parsePEM(append(
|
roots.parsePEM(append(
|
||||||
@ -56,7 +57,7 @@ type rootsOnce struct {
|
|||||||
func (r *rootsOnce) parsePEM(caPEM []byte) {
|
func (r *rootsOnce) parsePEM(caPEM []byte) {
|
||||||
p := x509.NewCertPool()
|
p := x509.NewCertPool()
|
||||||
if !p.AppendCertsFromPEM(caPEM) {
|
if !p.AppendCertsFromPEM(caPEM) {
|
||||||
panic("bogus PEM")
|
panic(fmt.Sprintf("bogus PEM: %q", caPEM))
|
||||||
}
|
}
|
||||||
r.p = p
|
r.p = p
|
||||||
}
|
}
|
||||||
|
93
net/connectproxy/connectproxy.go
Normal file
93
net/connectproxy/connectproxy.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// Package connectproxy contains some CONNECT proxy code.
|
||||||
|
package connectproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"tailscale.com/net/netx"
|
||||||
|
"tailscale.com/types/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handler is an HTTP CONNECT proxy handler.
|
||||||
|
type Handler struct {
|
||||||
|
// Dial, if non-nil, is an alternate dialer to use
|
||||||
|
// instead of the default dialer.
|
||||||
|
Dial netx.DialFunc
|
||||||
|
|
||||||
|
// Logf, if non-nil, is an alterate logger to
|
||||||
|
// use instead of log.Printf.
|
||||||
|
Logf logger.Logf
|
||||||
|
|
||||||
|
// Check, if non-nil, validates the CONNECT target.
|
||||||
|
Check func(hostPort string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
if r.Method != "CONNECT" {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dial := h.Dial
|
||||||
|
if dial == nil {
|
||||||
|
var d net.Dialer
|
||||||
|
dial = d.DialContext
|
||||||
|
}
|
||||||
|
logf := h.Logf
|
||||||
|
if logf == nil {
|
||||||
|
logf = log.Printf
|
||||||
|
}
|
||||||
|
|
||||||
|
hostPort := r.RequestURI
|
||||||
|
if h.Check != nil {
|
||||||
|
if err := h.Check(hostPort); err != nil {
|
||||||
|
logf("CONNECT target %q not allowed: %v", hostPort, err)
|
||||||
|
http.Error(w, "Invalid CONNECT target", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
back, err := dial(ctx, "tcp", hostPort)
|
||||||
|
if err != nil {
|
||||||
|
logf("error CONNECT dialing %v: %v", hostPort, err)
|
||||||
|
http.Error(w, "Connect failure", http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer back.Close()
|
||||||
|
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "CONNECT hijack unavailable", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c, br, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
logf("CONNECT hijack: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n")
|
||||||
|
|
||||||
|
errc := make(chan error, 2)
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(c, back)
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(back, br)
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
<-errc
|
||||||
|
}
|
@ -24,6 +24,7 @@ import (
|
|||||||
"tailscale.com/util/cloudenv"
|
"tailscale.com/util/cloudenv"
|
||||||
"tailscale.com/util/singleflight"
|
"tailscale.com/util/singleflight"
|
||||||
"tailscale.com/util/slicesx"
|
"tailscale.com/util/slicesx"
|
||||||
|
"tailscale.com/util/testenv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var zaddr netip.Addr
|
var zaddr netip.Addr
|
||||||
@ -63,6 +64,10 @@ type Resolver struct {
|
|||||||
// If nil, net.DefaultResolver is used.
|
// If nil, net.DefaultResolver is used.
|
||||||
Forward *net.Resolver
|
Forward *net.Resolver
|
||||||
|
|
||||||
|
// LookupIPForTest, if non-nil and in tests, handles requests instead
|
||||||
|
// of the usual mechanisms.
|
||||||
|
LookupIPForTest func(ctx context.Context, host string) ([]netip.Addr, error)
|
||||||
|
|
||||||
// LookupIPFallback optionally provides a backup DNS mechanism
|
// LookupIPFallback optionally provides a backup DNS mechanism
|
||||||
// to use if Forward returns an error or no results.
|
// to use if Forward returns an error or no results.
|
||||||
LookupIPFallback func(ctx context.Context, host string) ([]netip.Addr, error)
|
LookupIPFallback func(ctx context.Context, host string) ([]netip.Addr, error)
|
||||||
@ -284,7 +289,13 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (ip, ip6 netip.Add
|
|||||||
|
|
||||||
lookupCtx, lookupCancel := context.WithTimeout(ctx, r.lookupTimeoutForHost(host))
|
lookupCtx, lookupCancel := context.WithTimeout(ctx, r.lookupTimeoutForHost(host))
|
||||||
defer lookupCancel()
|
defer lookupCancel()
|
||||||
ips, err := r.fwd().LookupNetIP(lookupCtx, "ip", host)
|
|
||||||
|
var ips []netip.Addr
|
||||||
|
if r.LookupIPForTest != nil && testenv.InTest() {
|
||||||
|
ips, err = r.LookupIPForTest(ctx, host)
|
||||||
|
} else {
|
||||||
|
ips, err = r.fwd().LookupNetIP(lookupCtx, "ip", host)
|
||||||
|
}
|
||||||
if err != nil || len(ips) == 0 {
|
if err != nil || len(ips) == 0 {
|
||||||
if resolver, ok := r.cloudHostResolver(); ok {
|
if resolver, ok := r.cloudHostResolver(); ok {
|
||||||
r.dlogf("resolving %q via cloud resolver", host)
|
r.dlogf("resolving %q via cloud resolver", host)
|
||||||
|
@ -286,7 +286,7 @@ func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr
|
|||||||
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||||
return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443"))
|
return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443"))
|
||||||
}
|
}
|
||||||
tr.TLSClientConfig = tlsdial.Config(serverName, ht, tr.TLSClientConfig)
|
tr.TLSClientConfig = tlsdial.Config(ht, tr.TLSClientConfig)
|
||||||
c := &http.Client{Transport: tr}
|
c := &http.Client{Transport: tr}
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+"/bootstrap-dns?q="+url.QueryEscape(queryName), nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+"/bootstrap-dns?q="+url.QueryEscape(queryName), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -59,18 +59,26 @@ var mitmBlockWarnable = health.Register(&health.Warnable{
|
|||||||
ImpactsConnectivity: true,
|
ImpactsConnectivity: true,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Config returns a tls.Config for connecting to a server.
|
// Config returns a tls.Config for connecting to a server that
|
||||||
|
// uses system roots for validation but, if those fail, also tries
|
||||||
|
// the baked-in LetsEncrypt roots as a fallback validation method.
|
||||||
|
//
|
||||||
// If base is non-nil, it's cloned as the base config before
|
// If base is non-nil, it's cloned as the base config before
|
||||||
// being configured and returned.
|
// being configured and returned.
|
||||||
// If ht is non-nil, it's used to report health errors.
|
// If ht is non-nil, it's used to report health errors.
|
||||||
func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
|
func Config(ht *health.Tracker, base *tls.Config) *tls.Config {
|
||||||
var conf *tls.Config
|
var conf *tls.Config
|
||||||
if base == nil {
|
if base == nil {
|
||||||
conf = new(tls.Config)
|
conf = new(tls.Config)
|
||||||
} else {
|
} else {
|
||||||
conf = base.Clone()
|
conf = base.Clone()
|
||||||
}
|
}
|
||||||
conf.ServerName = host
|
|
||||||
|
// Note: we do NOT set conf.ServerName here (as we accidentally did
|
||||||
|
// previously), as this path is also used when dialing an HTTPS proxy server
|
||||||
|
// (through which we'll send a CONNECT request to get a TCP connection to do
|
||||||
|
// the real TCP connection) because host is the ultimate hostname, but this
|
||||||
|
// tls.Config is used for both the proxy and the ultimate target.
|
||||||
|
|
||||||
if n := sslKeyLogFile; n != "" {
|
if n := sslKeyLogFile; n != "" {
|
||||||
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
|
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
|
||||||
@ -93,7 +101,9 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
|
|||||||
// (with the baked-in fallback root) in the VerifyConnection hook.
|
// (with the baked-in fallback root) in the VerifyConnection hook.
|
||||||
conf.InsecureSkipVerify = true
|
conf.InsecureSkipVerify = true
|
||||||
conf.VerifyConnection = func(cs tls.ConnectionState) (retErr error) {
|
conf.VerifyConnection = func(cs tls.ConnectionState) (retErr error) {
|
||||||
if host == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() {
|
dialedHost := cs.ServerName
|
||||||
|
|
||||||
|
if dialedHost == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() {
|
||||||
// Allow log.tailscale.com TLS MITM for integration tests when
|
// Allow log.tailscale.com TLS MITM for integration tests when
|
||||||
// the client's running within a NATLab VM.
|
// the client's running within a NATLab VM.
|
||||||
return nil
|
return nil
|
||||||
@ -116,7 +126,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
|
|||||||
// Show a dedicated warning.
|
// Show a dedicated warning.
|
||||||
m, ok := blockblame.VerifyCertificate(cert)
|
m, ok := blockblame.VerifyCertificate(cert)
|
||||||
if ok {
|
if ok {
|
||||||
log.Printf("tlsdial: server cert for %q looks like %q equipment (could be blocking Tailscale)", host, m.Name)
|
log.Printf("tlsdial: server cert seen while dialing %q looks like %q equipment (could be blocking Tailscale)", dialedHost, m.Name)
|
||||||
ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name})
|
ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name})
|
||||||
} else {
|
} else {
|
||||||
ht.SetHealthy(mitmBlockWarnable)
|
ht.SetHealthy(mitmBlockWarnable)
|
||||||
@ -135,7 +145,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
|
|||||||
ht.SetTLSConnectionError(cs.ServerName, nil)
|
ht.SetTLSConnectionError(cs.ServerName, nil)
|
||||||
if selfSignedIssuer != "" {
|
if selfSignedIssuer != "" {
|
||||||
// Log the self-signed issuer, but don't treat it as an error.
|
// Log the self-signed issuer, but don't treat it as an error.
|
||||||
log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", host, selfSignedIssuer)
|
log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", dialedHost, selfSignedIssuer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -144,7 +154,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
|
|||||||
// First try doing x509 verification with the system's
|
// First try doing x509 verification with the system's
|
||||||
// root CA pool.
|
// root CA pool.
|
||||||
opts := x509.VerifyOptions{
|
opts := x509.VerifyOptions{
|
||||||
DNSName: cs.ServerName,
|
DNSName: dialedHost,
|
||||||
Intermediates: x509.NewCertPool(),
|
Intermediates: x509.NewCertPool(),
|
||||||
}
|
}
|
||||||
for _, cert := range cs.PeerCertificates[1:] {
|
for _, cert := range cs.PeerCertificates[1:] {
|
||||||
@ -152,7 +162,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
|
|||||||
}
|
}
|
||||||
_, errSys := cs.PeerCertificates[0].Verify(opts)
|
_, errSys := cs.PeerCertificates[0].Verify(opts)
|
||||||
if debug() {
|
if debug() {
|
||||||
log.Printf("tlsdial(sys %q): %v", host, errSys)
|
log.Printf("tlsdial(sys %q): %v", dialedHost, errSys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always verify with our baked-in Let's Encrypt certificate,
|
// Always verify with our baked-in Let's Encrypt certificate,
|
||||||
@ -161,13 +171,11 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
|
|||||||
opts.Roots = bakedroots.Get()
|
opts.Roots = bakedroots.Get()
|
||||||
_, bakedErr := cs.PeerCertificates[0].Verify(opts)
|
_, bakedErr := cs.PeerCertificates[0].Verify(opts)
|
||||||
if debug() {
|
if debug() {
|
||||||
log.Printf("tlsdial(bake %q): %v", host, bakedErr)
|
log.Printf("tlsdial(bake %q): %v", dialedHost, bakedErr)
|
||||||
} else if bakedErr != nil {
|
} else if bakedErr != nil {
|
||||||
if _, loaded := tlsdialWarningPrinted.LoadOrStore(host, true); !loaded {
|
if _, loaded := tlsdialWarningPrinted.LoadOrStore(dialedHost, true); !loaded {
|
||||||
if errSys == nil {
|
if errSys != nil {
|
||||||
log.Printf("tlsdial: warning: server cert for %q is not a Let's Encrypt cert", host)
|
log.Printf("tlsdial: error: server cert for %q failed both system roots & Let's Encrypt root validation", dialedHost)
|
||||||
} else {
|
|
||||||
log.Printf("tlsdial: error: server cert for %q failed to verify and is not a Let's Encrypt cert", host)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -202,9 +210,6 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) {
|
|||||||
c.ServerName = certDNSName
|
c.ServerName = certDNSName
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.VerifyPeerCertificate != nil {
|
|
||||||
panic("refusing to override tls.Config.VerifyPeerCertificate")
|
|
||||||
}
|
|
||||||
// Set InsecureSkipVerify to prevent crypto/tls from doing its
|
// Set InsecureSkipVerify to prevent crypto/tls from doing its
|
||||||
// own cert verification, but do the same work that it'd do
|
// own cert verification, but do the same work that it'd do
|
||||||
// (but using certDNSName) in the VerifyPeerCertificate hook.
|
// (but using certDNSName) in the VerifyPeerCertificate hook.
|
||||||
@ -257,29 +262,30 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) {
|
|||||||
if c.VerifyPeerCertificate != nil {
|
if c.VerifyPeerCertificate != nil {
|
||||||
panic("refusing to override tls.Config.VerifyPeerCertificate")
|
panic("refusing to override tls.Config.VerifyPeerCertificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set InsecureSkipVerify to prevent crypto/tls from doing its
|
// Set InsecureSkipVerify to prevent crypto/tls from doing its
|
||||||
// own cert verification, but do the same work that it'd do
|
// own cert verification, but do the same work that it'd do
|
||||||
// (but using certDNSName) in the VerifyPeerCertificate hook.
|
// (but using certDNSName) in the VerifyConnection hook.
|
||||||
c.InsecureSkipVerify = true
|
c.InsecureSkipVerify = true
|
||||||
c.VerifyConnection = nil
|
|
||||||
c.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
c.VerifyConnection = func(cs tls.ConnectionState) error {
|
||||||
|
dialedHost := cs.ServerName
|
||||||
var sawGoodCert bool
|
var sawGoodCert bool
|
||||||
for _, rawCert := range rawCerts {
|
|
||||||
cert, err := x509.ParseCertificate(rawCert)
|
for _, cert := range cs.PeerCertificates {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ParseCertificate: %w", err)
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix) {
|
if strings.HasPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if sawGoodCert {
|
if sawGoodCert {
|
||||||
return errors.New("unexpected multiple certs presented")
|
return errors.New("unexpected multiple certs presented")
|
||||||
}
|
}
|
||||||
if fmt.Sprintf("%02x", sha256.Sum256(rawCert)) != wantFullCertSHA256Hex {
|
if fmt.Sprintf("%02x", sha256.Sum256(cert.Raw)) != wantFullCertSHA256Hex {
|
||||||
return fmt.Errorf("cert hash does not match expected cert hash")
|
return fmt.Errorf("cert hash does not match expected cert hash")
|
||||||
}
|
}
|
||||||
if err := cert.VerifyHostname(c.ServerName); err != nil {
|
if dialedHost != "" { // it's empty when dialing a derper by IP with no hostname
|
||||||
return fmt.Errorf("cert does not match server name %q: %w", c.ServerName, err)
|
if err := cert.VerifyHostname(dialedHost); err != nil {
|
||||||
|
return fmt.Errorf("cert does not match server name %q: %w", dialedHost, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if now.After(cert.NotAfter) {
|
if now.After(cert.NotAfter) {
|
||||||
@ -302,12 +308,8 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) {
|
|||||||
func NewTransport() *http.Transport {
|
func NewTransport() *http.Transport {
|
||||||
return &http.Transport{
|
return &http.Transport{
|
||||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
host, _, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var d tls.Dialer
|
var d tls.Dialer
|
||||||
d.Config = Config(host, nil, nil)
|
d.Config = Config(nil, nil)
|
||||||
return d.DialContext(ctx, network, addr)
|
return d.DialContext(ctx, network, addr)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -86,7 +86,7 @@ func TestFallbackRootWorks(t *testing.T) {
|
|||||||
DisableKeepAlives: true, // for test cleanup ease
|
DisableKeepAlives: true, // for test cleanup ease
|
||||||
}
|
}
|
||||||
ht := new(health.Tracker)
|
ht := new(health.Tracker)
|
||||||
tr.TLSClientConfig = Config("tlsdial.test", ht, tr.TLSClientConfig)
|
tr.TLSClientConfig = Config(ht, tr.TLSClientConfig)
|
||||||
c := &http.Client{Transport: tr}
|
c := &http.Client{Transport: tr}
|
||||||
|
|
||||||
ctr0 := atomic.LoadInt32(&counterFallbackOK)
|
ctr0 := atomic.LoadInt32(&counterFallbackOK)
|
||||||
|
5
tstest/tlstest/testdata/controlplane.tstest.key
vendored
Normal file
5
tstest/tlstest/testdata/controlplane.tstest.key
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MHcCAQEEIHcxOQNVyqvBSSlu7c93QW6OsyccjL+R1evW4acd32MWoAoGCCqGSM49
|
||||||
|
AwEHoUQDQgAEIOY5/CQ8CMuKYPLf+r6OEneqfzQ5RfgPnLdkL22qhm8xb69ZCXxz
|
||||||
|
UecawU0KEDfHLYbUYXSuhAFxxuPh9I3x5Q==
|
||||||
|
-----END EC PRIVATE KEY-----
|
5
tstest/tlstest/testdata/proxy.tstest.key
vendored
Normal file
5
tstest/tlstest/testdata/proxy.tstest.key
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MHcCAQEEING1XBDWFXQjqBmLjhp20hXOf2rk/I0N6W7muv9RVvk3oAoGCCqGSM49
|
||||||
|
AwEHoUQDQgAE8lxnEEeLqYikwmXbXSsIQSw20R0oLA831s960KQZEgt0P9SbWcJc
|
||||||
|
QTk98rdfYT/QDdHn157Oh4FPcDtxmdQ4vw==
|
||||||
|
-----END EC PRIVATE KEY-----
|
5
tstest/tlstest/testdata/root-ca.key
vendored
Normal file
5
tstest/tlstest/testdata/root-ca.key
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MHcCAQEEIMl3xjqt1dnXBpYJSEqevirAcnSJ79I2tucdRazlrDG9oAoGCCqGSM49
|
||||||
|
AwEHoUQDQgAEQ/+Jme+16hgO7TtPSIFHVV0Yt969ltVlARVcNUZmWc0upQaq7uiJ
|
||||||
|
Aur5KtzwxU3YI4bhNK0593OK2TLvEEWIdw==
|
||||||
|
-----END EC PRIVATE KEY-----
|
167
tstest/tlstest/tlstest.go
Normal file
167
tstest/tlstest/tlstest.go
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// Package tlstest contains code to help test Tailscale's client proxy support.
|
||||||
|
package tlstest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
_ "embed"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Some baked-in ECDSA keys to speed up tests, not having to burn CPU to
|
||||||
|
// generate them each time. We only make the certs (which have expiry times)
|
||||||
|
// at runtime.
|
||||||
|
//
|
||||||
|
// They were made with:
|
||||||
|
//
|
||||||
|
// openssl ecparam -name prime256v1 -genkey -noout -out root-ca.key
|
||||||
|
var (
|
||||||
|
//go:embed testdata/root-ca.key
|
||||||
|
rootCAKeyPEM []byte
|
||||||
|
|
||||||
|
// TestProxyServerKey is the PEM private key for [TestProxyServerCert].
|
||||||
|
//
|
||||||
|
//go:embed testdata/proxy.tstest.key
|
||||||
|
TestProxyServerKey []byte
|
||||||
|
|
||||||
|
// TestControlPlaneKey is the PEM private key for [TestControlPlaneCert].
|
||||||
|
//
|
||||||
|
//go:embed testdata/controlplane.tstest.key
|
||||||
|
TestControlPlaneKey []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestRootCA returns a self-signed ECDSA root CA certificate (as PEM) for
|
||||||
|
// testing purposes.
|
||||||
|
func TestRootCA() []byte {
|
||||||
|
return bytes.Clone(testRootCAOncer())
|
||||||
|
}
|
||||||
|
|
||||||
|
var testRootCAOncer = sync.OnceValue(func() []byte {
|
||||||
|
key := rootCAKey()
|
||||||
|
now := time.Now().Add(-time.Hour)
|
||||||
|
tpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "Tailscale Unit Test ECDSA Root",
|
||||||
|
Organization: []string{"Tailscale Test Org"},
|
||||||
|
},
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.AddDate(5, 0, 0),
|
||||||
|
|
||||||
|
IsCA: true,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
SubjectKeyId: mustSKID(&key.PublicKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return pemCert(der)
|
||||||
|
})
|
||||||
|
|
||||||
|
func pemCert(der []byte) []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil {
|
||||||
|
panic(fmt.Sprintf("failed to encode PEM: %v", err))
|
||||||
|
}
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
var rootCAKey = sync.OnceValue(func() *ecdsa.PrivateKey {
|
||||||
|
return mustParsePEM(rootCAKeyPEM, x509.ParseECPrivateKey)
|
||||||
|
})
|
||||||
|
|
||||||
|
func mustParsePEM[T any](pemBytes []byte, parse func([]byte) (T, error)) T {
|
||||||
|
block, rest := pem.Decode(pemBytes)
|
||||||
|
if block == nil || len(rest) > 0 {
|
||||||
|
panic("invalid PEM")
|
||||||
|
}
|
||||||
|
v, err := parse(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("invalid PEM: %v", err))
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyPair is a simple struct to hold a certificate and its private key.
|
||||||
|
type KeyPair struct {
|
||||||
|
Domain string
|
||||||
|
KeyPEM []byte // PEM-encoded private key
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerTLSConfig returns a TLS configuration suitable for a server
|
||||||
|
// using the KeyPair's certificate and private key.
|
||||||
|
func (p KeyPair) ServerTLSConfig() *tls.Config {
|
||||||
|
cert, err := tls.X509KeyPair(p.CertPEM(), p.KeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
panic("invalid TLS key pair: " + err.Error())
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyServerKeyPair is a KeyPair for a test control plane server
|
||||||
|
// with domain name "proxy.tstest".
|
||||||
|
var ProxyServerKeyPair = KeyPair{
|
||||||
|
Domain: "proxy.tstest",
|
||||||
|
KeyPEM: TestProxyServerKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ControlPlaneKeyPair is a KeyPair for a test control plane server
|
||||||
|
// with domain name "controlplane.tstest".
|
||||||
|
var ControlPlaneKeyPair = KeyPair{
|
||||||
|
Domain: "controlplane.tstest",
|
||||||
|
KeyPEM: TestControlPlaneKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p KeyPair) CertPEM() []byte {
|
||||||
|
caCert := mustParsePEM(TestRootCA(), x509.ParseCertificate)
|
||||||
|
caPriv := mustParsePEM(rootCAKeyPEM, x509.ParseECPrivateKey)
|
||||||
|
leafKey := mustParsePEM(p.KeyPEM, x509.ParseECPrivateKey)
|
||||||
|
|
||||||
|
serial, err := rand.Int(rand.Reader, big.NewInt(0).Lsh(big.NewInt(1), 128))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Add(-time.Hour)
|
||||||
|
tpl := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
Subject: pkix.Name{CommonName: p.Domain},
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.AddDate(2, 0, 0),
|
||||||
|
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
DNSNames: []string{p.Domain},
|
||||||
|
}
|
||||||
|
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tpl, caCert, &leafKey.PublicKey, caPriv)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return pemCert(der)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustSKID(pub *ecdsa.PublicKey) []byte {
|
||||||
|
skid, err := x509.MarshalPKIXPublicKey(pub)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return skid[:20] // same as x509 library
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user