mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-13 14:43:19 +00:00
control/controlhttp: allow setting, getting Upgrade headers in Noise upgrade
Not currently used, but will allow us to usually remove a round-trip for a future feature. Updates #5972 Change-Id: I2770ea28e3e6ec9626d1cbb505a38ba51df7fba2 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:

committed by
Brad Fitzpatrick

parent
03ecf335f7
commit
246274b8e9
@@ -60,7 +60,7 @@ var stdDialer net.Dialer
|
||||
//
|
||||
// The provided ctx is only used for the initial connection, until
|
||||
// Dial returns. It does not affect the connection once established.
|
||||
func (a *Dialer) Dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
func (a *Dialer) Dial(ctx context.Context) (*ClientConn, error) {
|
||||
if a.Hostname == "" {
|
||||
return nil, errors.New("required Dialer.Hostname empty")
|
||||
}
|
||||
@@ -91,7 +91,7 @@ func (a *Dialer) httpsFallbackDelay() time.Duration {
|
||||
|
||||
var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
|
||||
|
||||
func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
|
||||
// If we don't have a dial plan, just fall back to dialing the single
|
||||
// host we know about.
|
||||
useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
|
||||
@@ -117,7 +117,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
|
||||
// Now, for each candidate, kick off a dial in parallel.
|
||||
type dialResult struct {
|
||||
conn *controlbase.Conn
|
||||
conn *ClientConn
|
||||
err error
|
||||
addr netip.Addr
|
||||
priority int
|
||||
@@ -129,7 +129,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
for _, c := range candidates {
|
||||
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
|
||||
var (
|
||||
conn *controlbase.Conn
|
||||
conn *ClientConn
|
||||
err error
|
||||
)
|
||||
|
||||
@@ -228,7 +228,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
})
|
||||
|
||||
var (
|
||||
conn *controlbase.Conn
|
||||
conn *ClientConn
|
||||
errs []error
|
||||
)
|
||||
for i, result := range results {
|
||||
@@ -252,7 +252,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
// dialHost connects to the configured Dialer.Hostname and upgrades the
|
||||
// connection into a controlbase.Conn. If addr is valid, then no DNS is used
|
||||
// and the connection will be made to the provided address.
|
||||
func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Conn, error) {
|
||||
func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*ClientConn, error) {
|
||||
// Create one shared context used by both port 80 and port 443 dials.
|
||||
// If port 80 is still in flight when 443 returns, this deferred cancel
|
||||
// will stop the port 80 dial.
|
||||
@@ -274,8 +274,8 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Co
|
||||
}
|
||||
|
||||
type tryURLRes struct {
|
||||
u *url.URL // input (the URL conn+err are for/from)
|
||||
conn *controlbase.Conn // result (mutually exclusive with err)
|
||||
u *url.URL // input (the URL conn+err are for/from)
|
||||
conn *ClientConn // result (mutually exclusive with err)
|
||||
err error
|
||||
}
|
||||
ch := make(chan tryURLRes) // must be unbuffered
|
||||
@@ -331,12 +331,12 @@ func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Co
|
||||
}
|
||||
|
||||
// dialURL attempts to connect to the given URL.
|
||||
func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*controlbase.Conn, error) {
|
||||
func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*ClientConn, error) {
|
||||
init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
netConn, err := a.tryURLUpgrade(ctx, u, addr, init)
|
||||
netConn, untrustedUpgradeHeaders, err := a.tryURLUpgrade(ctx, u, addr, init)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -345,7 +345,10 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*con
|
||||
netConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return cbConn, nil
|
||||
return &ClientConn{
|
||||
Conn: cbConn,
|
||||
UntrustedUpgradeHeaders: untrustedUpgradeHeaders,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
|
||||
@@ -353,7 +356,7 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*con
|
||||
// provided address.
|
||||
//
|
||||
// Only the provided ctx is used, not a.ctx.
|
||||
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) {
|
||||
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (_ net.Conn, untrustedUpgradeHeaders http.Header, _ error) {
|
||||
var dns *dnscache.Resolver
|
||||
|
||||
// If we were provided an address to dial, then create a resolver that just
|
||||
@@ -435,11 +438,11 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
|
||||
|
||||
resp, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status)
|
||||
return nil, nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status)
|
||||
}
|
||||
|
||||
// From here on, the underlying net.Conn is ours to use, but there
|
||||
@@ -453,19 +456,19 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr,
|
||||
}
|
||||
if switchedConn == nil {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("httptrace didn't provide a connection")
|
||||
return nil, nil, fmt.Errorf("httptrace didn't provide a connection")
|
||||
}
|
||||
|
||||
if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue {
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("server switched to unexpected protocol %q", next)
|
||||
return nil, nil, fmt.Errorf("server switched to unexpected protocol %q", next)
|
||||
}
|
||||
|
||||
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
resp.Body.Close()
|
||||
return nil, errors.New("http Transport did not provide a writable body")
|
||||
return nil, nil, errors.New("http Transport did not provide a writable body")
|
||||
}
|
||||
|
||||
return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), nil
|
||||
return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), resp.Header, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user