control/controlhttp: don't assume port 80 upgrade response will work

Just because we get an HTTP upgrade response over port 80, don't
assume we'll be able to do bi-di Noise over it. There might be a MITM
corp proxy or anti-virus/firewall interfering. Do a bit more work to
validate the connection before proceeding to give up on the TLS port
443 dial.

Updates #4557 (probably fixes)

Change-Id: I0e1bcc195af21ad3d360ffe79daead730dfd86f1
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-04-28 08:10:26 -07:00 committed by Brad Fitzpatrick
parent 488e63979e
commit 1237000efe
2 changed files with 106 additions and 48 deletions

View File

@ -70,7 +70,6 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
return nil, err return nil, err
} }
a := &dialParams{ a := &dialParams{
ctx: ctx,
host: host, host: host,
httpPort: port, httpPort: port,
httpsPort: "443", httpsPort: "443",
@ -80,11 +79,10 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr
proxyFunc: tshttpproxy.ProxyFromEnvironment, proxyFunc: tshttpproxy.ProxyFromEnvironment,
dialer: dialer, dialer: dialer,
} }
return a.dial() return a.dial(ctx)
} }
type dialParams struct { type dialParams struct {
ctx context.Context
host string host string
httpPort string httpPort string
httpsPort string httpsPort string
@ -96,13 +94,23 @@ type dialParams struct {
// For tests only // For tests only
insecureTLS bool insecureTLS bool
testFallbackDelay time.Duration
} }
func (a *dialParams) dial() (*controlbase.Conn, error) { // httpsFallbackDelay is how long we'll wait for a.httpPort to work before
// starting to try a.httpsPort.
func (a *dialParams) httpsFallbackDelay() time.Duration {
if v := a.testFallbackDelay; v != 0 {
return v
}
return 500 * time.Millisecond
}
func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) {
// Create one shared context used by both port 80 and port 443 dials. // Create one shared context used by both port 80 and port 443 dials.
// If port 80 is still in flight when 443 returns, this deferred cancel // If port 80 is still in flight when 443 returns, this deferred cancel
// will stop the port 80 dial. // will stop the port 80 dial.
ctx, cancel := context.WithCancel(a.ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
// u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS, // u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS,
@ -118,26 +126,20 @@ func (a *dialParams) dial() (*controlbase.Conn, error) {
Host: net.JoinHostPort(a.host, a.httpsPort), Host: net.JoinHostPort(a.host, a.httpsPort),
Path: serverUpgradePath, Path: serverUpgradePath,
} }
type tryURLRes struct { type tryURLRes struct {
u *url.URL u *url.URL // input (the URL conn+err are for/from)
conn net.Conn conn *controlbase.Conn // result (mutually exclusive with err)
cont controlbase.HandshakeContinuation
err error err error
} }
ch := make(chan tryURLRes) // must be unbuffered ch := make(chan tryURLRes) // must be unbuffered
try := func(u *url.URL) { try := func(u *url.URL) {
res := tryURLRes{u: u} cbConn, err := a.dialURL(ctx, u)
var init []byte
init, res.cont, res.err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
if res.err == nil {
res.conn, res.err = a.tryURL(ctx, u, init)
}
select { select {
case ch <- res: case ch <- tryURLRes{u, cbConn, err}:
case <-ctx.Done(): case <-ctx.Done():
if res.conn != nil { if cbConn != nil {
res.conn.Close() cbConn.Close()
} }
} }
} }
@ -147,7 +149,7 @@ type tryURLRes struct {
// In case outbound port 80 blocked or MITM'ed poorly, start a backup timer // In case outbound port 80 blocked or MITM'ed poorly, start a backup timer
// to dial port 443 if port 80 doesn't either succeed or fail quickly. // to dial port 443 if port 80 doesn't either succeed or fail quickly.
try443Timer := time.AfterFunc(500*time.Millisecond, func() { try(u443) }) try443Timer := time.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) })
defer try443Timer.Stop() defer try443Timer.Stop()
var err80, err443 error var err80, err443 error
@ -157,12 +159,7 @@ type tryURLRes struct {
return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err()) return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err())
case res := <-ch: case res := <-ch:
if res.err == nil { if res.err == nil {
ret, err := res.cont(ctx, res.conn) return res.conn, nil
if err != nil {
res.conn.Close()
return nil, err
}
return ret, nil
} }
switch res.u { switch res.u {
case u80: case u80:
@ -187,10 +184,28 @@ type tryURLRes struct {
} }
} }
// tryURL connects to u, and tries to upgrade it to a net.Conn. // dialURL attempts to connect to the given URL.
func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version)
if err != nil {
return nil, err
}
netConn, err := a.tryURLUpgrade(ctx, u, init)
if err != nil {
return nil, err
}
cbConn, err := cont(ctx, netConn)
if err != nil {
netConn.Close()
return nil, err
}
return cbConn, nil
}
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
// //
// Only the provided ctx is used, not a.ctx. // Only the provided ctx is used, not a.ctx.
func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) { func (a *dialParams) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
dns := &dnscache.Resolver{ dns := &dnscache.Resolver{
Forward: dnscache.Get().Forward, Forward: dnscache.Get().Forward,
LookupIPFallback: dnsfallback.Lookup, LookupIPFallback: dnsfallback.Lookup,

View File

@ -17,6 +17,7 @@
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
"time"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/net/socks5" "tailscale.com/net/socks5"
@ -24,16 +25,28 @@
"tailscale.com/types/key" "tailscale.com/types/key"
) )
func TestControlHTTP(t *testing.T) { type httpTestParam struct {
tests := []struct {
name string name string
proxy proxy proxy proxy
}{
// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
// 101 switching protocols.
makeHTTPHangAfterUpgrade bool
}
func TestControlHTTP(t *testing.T) {
tests := []httpTestParam{
// direct connection // direct connection
{ {
name: "no_proxy", name: "no_proxy",
proxy: nil, proxy: nil,
}, },
// direct connection but port 80 is MITM'ed and broken
{
name: "port80_broken_mitm",
proxy: nil,
makeHTTPHangAfterUpgrade: true,
},
// SOCKS5 // SOCKS5
{ {
name: "socks5", name: "socks5",
@ -97,12 +110,13 @@ func TestControlHTTP(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
testControlHTTP(t, test.proxy) testControlHTTP(t, test)
}) })
} }
} }
func testControlHTTP(t *testing.T, proxy proxy) { func testControlHTTP(t *testing.T, param httpTestParam) {
proxy := param.proxy
client, server := key.NewMachine(), key.NewMachine() client, server := key.NewMachine(), key.NewMachine()
const testProtocolVersion = 1 const testProtocolVersion = 1
@ -133,7 +147,11 @@ func testControlHTTP(t *testing.T, proxy proxy) {
t.Fatalf("HTTPS listen: %v", err) t.Fatalf("HTTPS listen: %v", err)
} }
httpServer := &http.Server{Handler: handler} var httpHandler http.Handler = handler
if param.makeHTTPHangAfterUpgrade {
httpHandler = http.HandlerFunc(brokenMITMHandler)
}
httpServer := &http.Server{Handler: httpHandler}
go httpServer.Serve(httpLn) go httpServer.Serve(httpLn)
defer httpServer.Close() defer httpServer.Close()
@ -144,11 +162,15 @@ func testControlHTTP(t *testing.T, proxy proxy) {
go httpsServer.ServeTLS(httpsLn, "", "") go httpsServer.ServeTLS(httpsLn, "", "")
defer httpsServer.Close() defer httpsServer.Close()
//ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx := context.Background()
//defer cancel() const debugTimeout = false
if debugTimeout {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
}
a := dialParams{ a := dialParams{
ctx: context.Background(), //ctx,
host: "localhost", host: "localhost",
httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port), httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port), httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
@ -157,6 +179,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
version: testProtocolVersion, version: testProtocolVersion,
insecureTLS: true, insecureTLS: true,
dialer: new(tsdial.Dialer).SystemDial, dialer: new(tsdial.Dialer).SystemDial,
testFallbackDelay: 50 * time.Millisecond,
} }
if proxy != nil { if proxy != nil {
@ -175,7 +198,7 @@ func testControlHTTP(t *testing.T, proxy proxy) {
} }
} }
conn, err := a.dial() conn, err := a.dial(ctx)
if err != nil { if err != nil {
t.Fatalf("dialing controlhttp: %v", err) t.Fatalf("dialing controlhttp: %v", err)
} }
@ -217,6 +240,7 @@ type proxy interface {
type socksProxy struct { type socksProxy struct {
sync.Mutex sync.Mutex
closed bool
proxy socks5.Server proxy socks5.Server
ln net.Listener ln net.Listener
clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
@ -232,7 +256,14 @@ func (s *socksProxy) Start(t *testing.T) (url string) {
} }
s.ln = ln s.ln = ln
s.clientConnAddrs = map[string]bool{} s.clientConnAddrs = map[string]bool{}
s.proxy.Logf = t.Logf s.proxy.Logf = func(format string, a ...any) {
s.Lock()
defer s.Unlock()
if s.closed {
return
}
t.Logf(format, a...)
}
s.proxy.Dialer = s.dialAndRecord s.proxy.Dialer = s.dialAndRecord
go s.proxy.Serve(ln) go s.proxy.Serve(ln)
return fmt.Sprintf("socks5://%s", ln.Addr().String()) return fmt.Sprintf("socks5://%s", ln.Addr().String())
@ -241,6 +272,10 @@ func (s *socksProxy) Start(t *testing.T) (url string) {
func (s *socksProxy) Close() { func (s *socksProxy) Close() {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
if s.closed {
return
}
s.closed = true
s.ln.Close() s.ln.Close()
} }
@ -400,3 +435,11 @@ func tlsConfig(t *testing.T) *tls.Config {
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
} }
} }
func brokenMITMHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Upgrade", upgradeHeaderValue)
w.Header().Set("Connection", "upgrade")
w.WriteHeader(http.StatusSwitchingProtocols)
w.(http.Flusher).Flush()
<-r.Context().Done()
}