mirror of
https://github.com/tailscale/tailscale.git
synced 2025-02-18 02:48:40 +00:00
control/controlhttp: allow setting, getting Upgrade headers in Noise upgrade
Not currently used, but will allow us to usually remove a round-trip for a future feature. Updates #5972 Change-Id: I2770ea28e3e6ec9626d1cbb505a38ba51df7fba2 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
03ecf335f7
commit
246274b8e9
@ -210,7 +210,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := (&controlhttp.Dialer{
|
clientConn, err := (&controlhttp.Dialer{
|
||||||
Hostname: nc.host,
|
Hostname: nc.host,
|
||||||
HTTPPort: nc.httpPort,
|
HTTPPort: nc.httpPort,
|
||||||
HTTPSPort: nc.httpsPort,
|
HTTPSPort: nc.httpsPort,
|
||||||
@ -226,7 +226,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
|
|||||||
|
|
||||||
nc.mu.Lock()
|
nc.mu.Lock()
|
||||||
defer nc.mu.Unlock()
|
defer nc.mu.Unlock()
|
||||||
ncc := &noiseConn{Conn: conn, id: connID, pool: nc}
|
ncc := &noiseConn{Conn: clientConn.Conn, id: connID, pool: nc}
|
||||||
mak.Set(&nc.connPool, ncc.id, ncc)
|
mak.Set(&nc.connPool, ncc.id, ncc)
|
||||||
return ncc, nil
|
return ncc, nil
|
||||||
}
|
}
|
||||||
|
@ -60,7 +60,7 @@ var stdDialer net.Dialer
|
|||||||
//
|
//
|
||||||
// The provided ctx is only used for the initial connection, until
|
// The provided ctx is only used for the initial connection, until
|
||||||
// Dial returns. It does not affect the connection once established.
|
// Dial returns. It does not affect the connection once established.
|
||||||
func (a *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
|
func (a *Dialer) Dial(ctx context.Context) (*ClientConn, error) {
|
||||||
if a.Hostname == "" {
|
if a.Hostname == "" {
|
||||||
return nil, errors.New("required Dialer.Hostname empty")
|
return nil, errors.New("required Dialer.Hostname empty")
|
||||||
}
|
}
|
||||||
@ -91,7 +91,7 @@ func (a *Dialer) httpsFallbackDelay() time.Duration {
|
|||||||
|
|
||||||
var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
|
var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
|
||||||
|
|
||||||
func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
|
||||||
// If we don't have a dial plan, just fall back to dialing the single
|
// If we don't have a dial plan, just fall back to dialing the single
|
||||||
// host we know about.
|
// host we know about.
|
||||||
useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
|
useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
|
||||||
@ -117,7 +117,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
|||||||
|
|
||||||
// Now, for each candidate, kick off a dial in parallel.
|
// Now, for each candidate, kick off a dial in parallel.
|
||||||
type dialResult struct {
|
type dialResult struct {
|
||||||
conn *controlbase.Conn
|
conn *ClientConn
|
||||||
err error
|
err error
|
||||||
addr netip.Addr
|
addr netip.Addr
|
||||||
priority int
|
priority int
|
||||||
@ -129,7 +129,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
|||||||
for _, c := range candidates {
|
for _, c := range candidates {
|
||||||
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
|
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
|
||||||
var (
|
var (
|
||||||
conn *controlbase.Conn
|
conn *ClientConn
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -228,7 +228,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
var (
|
var (
|
||||||
conn *controlbase.Conn
|
conn *ClientConn
|
||||||
errs []error
|
errs []error
|
||||||
)
|
)
|
||||||
for i, result := range results {
|
for i, result := range results {
|
||||||
@ -252,7 +252,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
|||||||
// dialHost connects to the configured Dialer.Hostname and upgrades the
|
// dialHost connects to the configured Dialer.Hostname and upgrades the
|
||||||
// connection into a controlbase.Conn. If addr is valid, then no DNS is used
|
// connection into a controlbase.Conn. If addr is valid, then no DNS is used
|
||||||
// and the connection will be made to the provided address.
|
// and the connection will be made to the provided address.
|
||||||
func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Conn, error) {
|
func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*ClientConn, error) {
|
||||||
// Create one shared context used by both port 80 and port 443 dials.
|
// Create one shared context used by both port 80 and port 443 dials.
|
||||||
// If port 80 is still in flight when 443 returns, this deferred cancel
|
// If port 80 is still in flight when 443 returns, this deferred cancel
|
||||||
// will stop the port 80 dial.
|
// will stop the port 80 dial.
|
||||||
@ -274,8 +274,8 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
type tryURLRes struct {
|
type tryURLRes struct {
|
||||||
u *url.URL // input (the URL conn+err are for/from)
|
u *url.URL // input (the URL conn+err are for/from)
|
||||||
conn *controlbase.Conn // result (mutually exclusive with err)
|
conn *ClientConn // result (mutually exclusive with err)
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
ch := make(chan tryURLRes) // must be unbuffered
|
ch := make(chan tryURLRes) // must be unbuffered
|
||||||
@ -331,12 +331,12 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dialURL attempts to connect to the given URL.
|
// dialURL attempts to connect to the given URL.
|
||||||
func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*controlbase.Conn, error) {
|
func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*ClientConn, error) {
|
||||||
init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
|
init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
netConn, err := a.tryURLUpgrade(ctx, u, addr, init)
|
netConn, untrustedUpgradeHeaders, err := a.tryURLUpgrade(ctx, u, addr, init)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -345,7 +345,10 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*con
|
|||||||
netConn.Close()
|
netConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return cbConn, nil
|
return &ClientConn{
|
||||||
|
Conn: cbConn,
|
||||||
|
UntrustedUpgradeHeaders: untrustedUpgradeHeaders,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
|
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
|
||||||
@ -353,7 +356,7 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*con
|
|||||||
// provided address.
|
// provided address.
|
||||||
//
|
//
|
||||||
// Only the provided ctx is used, not a.ctx.
|
// Only the provided ctx is used, not a.ctx.
|
||||||
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) {
|
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (_ net.Conn, untrustedUpgradeHeaders http.Header, _ error) {
|
||||||
var dns *dnscache.Resolver
|
var dns *dnscache.Resolver
|
||||||
|
|
||||||
// If we were provided an address to dial, then create a resolver that just
|
// If we were provided an address to dial, then create a resolver that just
|
||||||
@ -435,11 +438,11 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
|
|||||||
|
|
||||||
resp, err := tr.RoundTrip(req)
|
resp, err := tr.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
return nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status)
|
return nil, nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// From here on, the underlying net.Conn is ours to use, but there
|
// From here on, the underlying net.Conn is ours to use, but there
|
||||||
@ -453,19 +456,19 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
|
|||||||
}
|
}
|
||||||
if switchedConn == nil {
|
if switchedConn == nil {
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return nil, fmt.Errorf("httptrace didn't provide a connection")
|
return nil, nil, fmt.Errorf("httptrace didn't provide a connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue {
|
if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue {
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return nil, fmt.Errorf("server switched to unexpected protocol %q", next)
|
return nil, nil, fmt.Errorf("server switched to unexpected protocol %q", next)
|
||||||
}
|
}
|
||||||
|
|
||||||
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
||||||
if !ok {
|
if !ok {
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return nil, errors.New("http Transport did not provide a writable body")
|
return nil, nil, errors.New("http Transport did not provide a writable body")
|
||||||
}
|
}
|
||||||
|
|
||||||
return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), nil
|
return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), resp.Header, nil
|
||||||
}
|
}
|
||||||
|
26
control/controlhttp/client_common.go
Normal file
26
control/controlhttp/client_common.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package controlhttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"tailscale.com/control/controlbase"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientConn is a Tailscale control client as returned by the Dialer.
|
||||||
|
//
|
||||||
|
// It's effectively just a *controlbase.Conn (which it embeds) with
|
||||||
|
// optional metadata.
|
||||||
|
type ClientConn struct {
|
||||||
|
// Conn is the noise connection.
|
||||||
|
*controlbase.Conn
|
||||||
|
|
||||||
|
// UntrustedUpgradeHeaders are the HTTP headers seen in the
|
||||||
|
// 101 Switching Protocols upgrade response. They may be nil
|
||||||
|
// or even might've been tampered with by a middlebox.
|
||||||
|
// They should not be trusted.
|
||||||
|
UntrustedUpgradeHeaders http.Header
|
||||||
|
}
|
@ -17,7 +17,7 @@ import (
|
|||||||
|
|
||||||
// Variant of Dial that tunnels the request over WebSockets, since we cannot do
|
// Variant of Dial that tunnels the request over WebSockets, since we cannot do
|
||||||
// bi-directional communication over an HTTP connection when in JS.
|
// bi-directional communication over an HTTP connection when in JS.
|
||||||
func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
|
func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) {
|
||||||
if d.Hostname == "" {
|
if d.Hostname == "" {
|
||||||
return nil, errors.New("required Dialer.Hostname empty")
|
return nil, errors.New("required Dialer.Hostname empty")
|
||||||
}
|
}
|
||||||
@ -45,7 +45,7 @@ func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
|
|||||||
handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)},
|
handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)},
|
||||||
}.Encode(),
|
}.Encode(),
|
||||||
}
|
}
|
||||||
wsConn, _, err := websocket.Dial(ctx, wsURL.String(), &websocket.DialOptions{
|
wsConn, httpRes, err := websocket.Dial(ctx, wsURL.String(), &websocket.DialOptions{
|
||||||
Subprotocols: []string{upgradeHeaderValue},
|
Subprotocols: []string{upgradeHeaderValue},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -57,5 +57,8 @@ func (d *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
|
|||||||
netConn.Close()
|
netConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return cbConn, nil
|
return &ClientConn{
|
||||||
|
Conn: cbConn,
|
||||||
|
UntrustedUpgradeHeaders: httpRes.Header,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) {
|
|||||||
const testProtocolVersion = 1
|
const testProtocolVersion = 1
|
||||||
sch := make(chan serverResult, 1)
|
sch := make(chan serverResult, 1)
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
conn, err := AcceptHTTP(context.Background(), w, r, server)
|
conn, err := AcceptHTTP(context.Background(), w, r, server, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
}
|
}
|
||||||
@ -485,7 +485,7 @@ func TestDialPlan(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
conn, err := AcceptHTTP(context.Background(), w, r, server)
|
conn, err := AcceptHTTP(context.Background(), w, r, server, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
} else {
|
} else {
|
||||||
|
@ -22,7 +22,11 @@ import (
|
|||||||
//
|
//
|
||||||
// AcceptHTTP always writes an HTTP response to w. The caller must not
|
// AcceptHTTP always writes an HTTP response to w. The caller must not
|
||||||
// attempt their own response after calling AcceptHTTP.
|
// attempt their own response after calling AcceptHTTP.
|
||||||
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) {
|
//
|
||||||
|
// extraHeader optionally specifies extra header(s) to send in the
|
||||||
|
// 101 Switching Protocols Upgrade response. It must not include the "Upgrade"
|
||||||
|
// or "Connection" headers; they will be replaced.
|
||||||
|
func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate, extraHeader http.Header) (*controlbase.Conn, error) {
|
||||||
next := r.Header.Get("Upgrade")
|
next := r.Header.Get("Upgrade")
|
||||||
if next == "" {
|
if next == "" {
|
||||||
http.Error(w, "missing next protocol", http.StatusBadRequest)
|
http.Error(w, "missing next protocol", http.StatusBadRequest)
|
||||||
@ -53,6 +57,9 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
|
|||||||
return nil, errors.New("can't hijack client connection")
|
return nil, errors.New("can't hijack client connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for k, vv := range extraHeader {
|
||||||
|
w.Header()[k] = vv
|
||||||
|
}
|
||||||
w.Header().Set("Upgrade", upgradeHeaderValue)
|
w.Header().Set("Upgrade", upgradeHeaderValue)
|
||||||
w.Header().Set("Connection", "upgrade")
|
w.Header().Set("Connection", "upgrade")
|
||||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user