diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index ca407c101..7a12ba017 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -70,7 +70,6 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr return nil, err } a := &dialParams{ - ctx: ctx, host: host, httpPort: port, httpsPort: "443", @@ -80,11 +79,10 @@ func Dial(ctx context.Context, addr string, machineKey key.MachinePrivate, contr proxyFunc: tshttpproxy.ProxyFromEnvironment, dialer: dialer, } - return a.dial() + return a.dial(ctx) } type dialParams struct { - ctx context.Context host string httpPort string httpsPort string @@ -95,14 +93,24 @@ type dialParams struct { dialer dnscache.DialContextFunc // For tests only - insecureTLS bool + insecureTLS bool + testFallbackDelay time.Duration } -func (a *dialParams) dial() (*controlbase.Conn, error) { +// httpsFallbackDelay is how long we'll wait for a.httpPort to work before +// starting to try a.httpsPort. +func (a *dialParams) httpsFallbackDelay() time.Duration { + if v := a.testFallbackDelay; v != 0 { + return v + } + return 500 * time.Millisecond +} + +func (a *dialParams) dial(ctx context.Context) (*controlbase.Conn, error) { // Create one shared context used by both port 80 and port 443 dials. // If port 80 is still in flight when 443 returns, this deferred cancel // will stop the port 80 dial. - ctx, cancel := context.WithCancel(a.ctx) + ctx, cancel := context.WithCancel(ctx) defer cancel() // u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS, @@ -118,26 +126,20 @@ func (a *dialParams) dial() (*controlbase.Conn, error) { Host: net.JoinHostPort(a.host, a.httpsPort), Path: serverUpgradePath, } + type tryURLRes struct { - u *url.URL - conn net.Conn - cont controlbase.HandshakeContinuation + u *url.URL // input (the URL conn+err are for/from) + conn *controlbase.Conn // result (mutually exclusive with err) err error } ch := make(chan tryURLRes) // must be unbuffered - try := func(u *url.URL) { - res := tryURLRes{u: u} - var init []byte - init, res.cont, res.err = controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version) - if res.err == nil { - res.conn, res.err = a.tryURL(ctx, u, init) - } + cbConn, err := a.dialURL(ctx, u) select { - case ch <- res: + case ch <- tryURLRes{u, cbConn, err}: case <-ctx.Done(): - if res.conn != nil { - res.conn.Close() + if cbConn != nil { + cbConn.Close() } } } @@ -147,7 +149,7 @@ type tryURLRes struct { // In case outbound port 80 blocked or MITM'ed poorly, start a backup timer // to dial port 443 if port 80 doesn't either succeed or fail quickly. - try443Timer := time.AfterFunc(500*time.Millisecond, func() { try(u443) }) + try443Timer := time.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) }) defer try443Timer.Stop() var err80, err443 error @@ -157,12 +159,7 @@ type tryURLRes struct { return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err()) case res := <-ch: if res.err == nil { - ret, err := res.cont(ctx, res.conn) - if err != nil { - res.conn.Close() - return nil, err - } - return ret, nil + return res.conn, nil } switch res.u { case u80: @@ -187,10 +184,28 @@ type tryURLRes struct { } } -// tryURL connects to u, and tries to upgrade it to a net.Conn. +// dialURL attempts to connect to the given URL. +func (a *dialParams) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) { + init, cont, err := controlbase.ClientDeferred(a.machineKey, a.controlKey, a.version) + if err != nil { + return nil, err + } + netConn, err := a.tryURLUpgrade(ctx, u, init) + if err != nil { + return nil, err + } + cbConn, err := cont(ctx, netConn) + if err != nil { + netConn.Close() + return nil, err + } + return cbConn, nil +} + +// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. // // Only the provided ctx is used, not a.ctx. -func (a *dialParams) tryURL(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) { +func (a *dialParams) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) { dns := &dnscache.Resolver{ Forward: dnscache.Get().Forward, LookupIPFallback: dnsfallback.Lookup, diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 5f942c895..545b4c303 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -17,6 +17,7 @@ "strconv" "sync" "testing" + "time" "tailscale.com/control/controlbase" "tailscale.com/net/socks5" @@ -24,16 +25,28 @@ "tailscale.com/types/key" ) +type httpTestParam struct { + name string + proxy proxy + + // makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a + // 101 switching protocols. + makeHTTPHangAfterUpgrade bool +} + func TestControlHTTP(t *testing.T) { - tests := []struct { - name string - proxy proxy - }{ + tests := []httpTestParam{ // direct connection { name: "no_proxy", proxy: nil, }, + // direct connection but port 80 is MITM'ed and broken + { + name: "port80_broken_mitm", + proxy: nil, + makeHTTPHangAfterUpgrade: true, + }, // SOCKS5 { name: "socks5", @@ -97,12 +110,13 @@ func TestControlHTTP(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - testControlHTTP(t, test.proxy) + testControlHTTP(t, test) }) } } -func testControlHTTP(t *testing.T, proxy proxy) { +func testControlHTTP(t *testing.T, param httpTestParam) { + proxy := param.proxy client, server := key.NewMachine(), key.NewMachine() const testProtocolVersion = 1 @@ -133,7 +147,11 @@ func testControlHTTP(t *testing.T, proxy proxy) { t.Fatalf("HTTPS listen: %v", err) } - httpServer := &http.Server{Handler: handler} + var httpHandler http.Handler = handler + if param.makeHTTPHangAfterUpgrade { + httpHandler = http.HandlerFunc(brokenMITMHandler) + } + httpServer := &http.Server{Handler: httpHandler} go httpServer.Serve(httpLn) defer httpServer.Close() @@ -144,19 +162,24 @@ func testControlHTTP(t *testing.T, proxy proxy) { go httpsServer.ServeTLS(httpsLn, "", "") defer httpsServer.Close() - //ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - //defer cancel() + ctx := context.Background() + const debugTimeout = false + if debugTimeout { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } a := dialParams{ - ctx: context.Background(), //ctx, - host: "localhost", - httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port), - httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port), - machineKey: client, - controlKey: server.Public(), - version: testProtocolVersion, - insecureTLS: true, - dialer: new(tsdial.Dialer).SystemDial, + host: "localhost", + httpPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port), + httpsPort: strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port), + machineKey: client, + controlKey: server.Public(), + version: testProtocolVersion, + insecureTLS: true, + dialer: new(tsdial.Dialer).SystemDial, + testFallbackDelay: 50 * time.Millisecond, } if proxy != nil { @@ -175,7 +198,7 @@ func testControlHTTP(t *testing.T, proxy proxy) { } } - conn, err := a.dial() + conn, err := a.dial(ctx) if err != nil { t.Fatalf("dialing controlhttp: %v", err) } @@ -217,6 +240,7 @@ type proxy interface { type socksProxy struct { sync.Mutex + closed bool proxy socks5.Server ln net.Listener clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy @@ -232,7 +256,14 @@ func (s *socksProxy) Start(t *testing.T) (url string) { } s.ln = ln s.clientConnAddrs = map[string]bool{} - s.proxy.Logf = t.Logf + s.proxy.Logf = func(format string, a ...any) { + s.Lock() + defer s.Unlock() + if s.closed { + return + } + t.Logf(format, a...) + } s.proxy.Dialer = s.dialAndRecord go s.proxy.Serve(ln) return fmt.Sprintf("socks5://%s", ln.Addr().String()) @@ -241,6 +272,10 @@ func (s *socksProxy) Start(t *testing.T) (url string) { func (s *socksProxy) Close() { s.Lock() defer s.Unlock() + if s.closed { + return + } + s.closed = true s.ln.Close() } @@ -400,3 +435,11 @@ func tlsConfig(t *testing.T) *tls.Config { Certificates: []tls.Certificate{cert}, } } + +func brokenMITMHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Connection", "upgrade") + w.WriteHeader(http.StatusSwitchingProtocols) + w.(http.Flusher).Flush() + <-r.Context().Done() +}