diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 9dcf0351b..fcd65b836 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -210,7 +210,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - conn, err := (&controlhttp.Dialer{ + clientConn, err := (&controlhttp.Dialer{ Hostname: nc.host, HTTPPort: nc.httpPort, HTTPSPort: nc.httpsPort, @@ -226,7 +226,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) { nc.mu.Lock() defer nc.mu.Unlock() - ncc := &noiseConn{Conn: conn, id: connID, pool: nc} + ncc := &noiseConn{Conn: clientConn.Conn, id: connID, pool: nc} mak.Set(&nc.connPool, ncc.id, ncc) return ncc, nil } diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 1abb57e89..b27ef6279 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -60,7 +60,7 @@ // // The provided ctx is only used for the initial connection, until // Dial returns. It does not affect the connection once established. -func (a *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) { +func (a *Dialer) Dial(ctx context.Context) (*ClientConn, error) { if a.Hostname == "" { return nil, errors.New("required Dialer.Hostname empty") } @@ -91,7 +91,7 @@ func (a *Dialer) httpsFallbackDelay() time.Duration { var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use -func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) { +func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { // If we don't have a dial plan, just fall back to dialing the single // host we know about. useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN") @@ -117,7 +117,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) { // Now, for each candidate, kick off a dial in parallel. type dialResult struct { - conn *controlbase.Conn + conn *ClientConn err error addr netip.Addr priority int @@ -129,7 +129,7 @@ type dialResult struct { for _, c := range candidates { go func(ctx context.Context, c tailcfg.ControlIPCandidate) { var ( - conn *controlbase.Conn + conn *ClientConn err error ) @@ -228,7 +228,7 @@ type dialResult struct { }) var ( - conn *controlbase.Conn + conn *ClientConn errs []error ) for i, result := range results { @@ -252,7 +252,7 @@ type dialResult struct { // dialHost connects to the configured Dialer.Hostname and upgrades the // connection into a controlbase.Conn. If addr is valid, then no DNS is used // and the connection will be made to the provided address. -func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Conn, error) { +func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*ClientConn, error) { // Create one shared context used by both port 80 and port 443 dials. // If port 80 is still in flight when 443 returns, this deferred cancel // will stop the port 80 dial. @@ -274,8 +274,8 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Co } type tryURLRes struct { - u *url.URL // input (the URL conn+err are for/from) - conn *controlbase.Conn // result (mutually exclusive with err) + u *url.URL // input (the URL conn+err are for/from) + conn *ClientConn // result (mutually exclusive with err) err error } ch := make(chan tryURLRes) // must be unbuffered @@ -331,12 +331,12 @@ type tryURLRes struct { } // dialURL attempts to connect to the given URL. -func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*controlbase.Conn, error) { +func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*ClientConn, error) { init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion) if err != nil { return nil, err } - netConn, err := a.tryURLUpgrade(ctx, u, addr, init) + netConn, untrustedUpgradeHeaders, err := a.tryURLUpgrade(ctx, u, addr, init) if err != nil { return nil, err } @@ -345,7 +345,10 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*con netConn.Close() return nil, err } - return cbConn, nil + return &ClientConn{ + Conn: cbConn, + UntrustedUpgradeHeaders: untrustedUpgradeHeaders, + }, nil } // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr @@ -353,7 +356,7 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*con // provided address. // // Only the provided ctx is used, not a.ctx. -func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) { +func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (_ net.Conn, untrustedUpgradeHeaders http.Header, _ error) { var dns *dnscache.Resolver // If we were provided an address to dial, then create a resolver that just @@ -435,11 +438,11 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, resp, err := tr.RoundTrip(req) if err != nil { - return nil, err + return nil, nil, err } if resp.StatusCode != http.StatusSwitchingProtocols { - return nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status) + return nil, nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status) } // From here on, the underlying net.Conn is ours to use, but there @@ -453,19 +456,19 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, } if switchedConn == nil { resp.Body.Close() - return nil, fmt.Errorf("httptrace didn't provide a connection") + return nil, nil, fmt.Errorf("httptrace didn't provide a connection") } if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue { resp.Body.Close() - return nil, fmt.Errorf("server switched to unexpected protocol %q", next) + return nil, nil, fmt.Errorf("server switched to unexpected protocol %q", next) } rwc, ok := resp.Body.(io.ReadWriteCloser) if !ok { resp.Body.Close() - return nil, errors.New("http Transport did not provide a writable body") + return nil, nil, errors.New("http Transport did not provide a writable body") } - return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), nil + return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), resp.Header, nil } diff --git a/control/controlhttp/client_common.go b/control/controlhttp/client_common.go new file mode 100644 index 000000000..dfc32b639 --- /dev/null +++ b/control/controlhttp/client_common.go @@ -0,0 +1,26 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlhttp + +import ( + "net/http" + + "tailscale.com/control/controlbase" +) + +// ClientConn is a Tailscale control client as returned by the Dialer. +// +// It's effectively just a *controlbase.Conn (which it embeds) with +// optional metadata. +type ClientConn struct { + // Conn is the noise connection. + *controlbase.Conn + + // UntrustedUpgradeHeaders are the HTTP headers seen in the + // 101 Switching Protocols upgrade response. They may be nil + // or even might've been tampered with by a middlebox. + // They should not be trusted. + UntrustedUpgradeHeaders http.Header +} diff --git a/control/controlhttp/client_js.go b/control/controlhttp/client_js.go index d31cae731..5bc2cda76 100644 --- a/control/controlhttp/client_js.go +++ b/control/controlhttp/client_js.go @@ -17,7 +17,7 @@ // Variant of Dial that tunnels the request over WebSockets, since we cannot do // bi-directional communication over an HTTP connection when in JS. -func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) { +func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) { if d.Hostname == "" { return nil, errors.New("required Dialer.Hostname empty") } @@ -45,7 +45,7 @@ func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) { handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, }.Encode(), } - wsConn, _, err := websocket.Dial(ctx, wsURL.String(), &websocket.DialOptions{ + wsConn, httpRes, err := websocket.Dial(ctx, wsURL.String(), &websocket.DialOptions{ Subprotocols: []string{upgradeHeaderValue}, }) if err != nil { @@ -57,5 +57,8 @@ func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) { netConn.Close() return nil, err } - return cbConn, nil + return &ClientConn{ + Conn: cbConn, + UntrustedUpgradeHeaders: httpRes.Header, + }, nil } diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 323327cbb..d8ce4b43a 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -127,7 +127,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { const testProtocolVersion = 1 sch := make(chan serverResult, 1) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server) + conn, err := AcceptHTTP(context.Background(), w, r, server, nil) if err != nil { log.Print(err) } @@ -485,7 +485,7 @@ func TestDialPlan(t *testing.T) { close(done) }) var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server) + conn, err := AcceptHTTP(context.Background(), w, r, server, nil) if err != nil { log.Print(err) } else { diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go index 1e9ccd0ca..23a8cf8ff 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/server.go @@ -22,7 +22,11 @@ // // AcceptHTTP always writes an HTTP response to w. The caller must not // attempt their own response after calling AcceptHTTP. -func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) { +// +// extraHeader optionally specifies extra header(s) to send in the +// 101 Switching Protocols Upgrade response. It must not include the "Upgrade" +// or "Connection" headers; they will be replaced. +func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, extraHeader http.Header) (*controlbase.Conn, error) { next := r.Header.Get("Upgrade") if next == "" { http.Error(w, "missing next protocol", http.StatusBadRequest) @@ -53,6 +57,9 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri return nil, errors.New("can't hijack client connection") } + for k, vv := range extraHeader { + w.Header()[k] = vv + } w.Header().Set("Upgrade", upgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols)