// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !js
// +build !js

// Package controlhttp implements the Tailscale 2021 control protocol
// base transport over HTTP.
//
// This tunnels the protocol in control/controlbase over HTTP with a
// variety of compatibility fallbacks for handling picky or deep
// inspecting proxies.
//
// In the happy path, a client makes a single cleartext HTTP request
// to the server, the server responds with 101 Switching Protocols,
// and the control base protocol takes place over plain TCP.
//
// In the compatibility path, the client does the above over HTTPS,
// resulting in double encryption (once for the control transport, and
// once for the outer TLS layer).
package controlhttp

import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"net/http"
	"net/http/httptrace"
	"net/netip"
	"net/url"
	"sort"
	"sync/atomic"
	"time"

	"tailscale.com/control/controlbase"
	"tailscale.com/envknob"
	"tailscale.com/net/dnscache"
	"tailscale.com/net/dnsfallback"
	"tailscale.com/net/netutil"
	"tailscale.com/net/tlsdial"
	"tailscale.com/net/tshttpproxy"
	"tailscale.com/tailcfg"
	"tailscale.com/util/multierr"
)

// stdDialer is the zero-value net.Dialer whose DialContext is used for the
// underlying TCP connections when the Dialer has no custom Dialer func set.
var stdDialer net.Dialer

// Dial connects to the HTTP server at this Dialer's Hostname:HTTPPort, requests to
// switch to the Tailscale control protocol, and returns an established control
// protocol connection.
//
// If Dial fails to connect using HTTP, it also tries to tunnel over TLS to the
// Dialer's Hostname:HTTPSPort as a compatibility fallback.
//
// If the Dialer has a non-empty DialPlan (and TS_USE_CONTROL_DIAL_PLAN is not
// disabled), the plan's candidate IPs are dialed directly instead of resolving
// Hostname via DNS, with a fallback to DNS if none of the candidates succeed.
//
// Dial returns an error if the Dialer's Hostname is empty.
//
// The provided ctx is only used for the initial connection, until
// Dial returns. It does not affect the connection once established.
func (a *Dialer) Dial(ctx context.Context) (*ClientConn, error) { if a.Hostname == "" { return nil, errors.New("required Dialer.Hostname empty") } return a.dial(ctx) } func (a *Dialer) logf(format string, args ...any) { if a.Logf != nil { a.Logf(format, args...) } } func (a *Dialer) getProxyFunc() func(*http.Request) (*url.URL, error) { if a.proxyFunc != nil { return a.proxyFunc } return tshttpproxy.ProxyFromEnvironment } // httpsFallbackDelay is how long we'll wait for a.HTTPPort to work before // starting to try a.HTTPSPort. func (a *Dialer) httpsFallbackDelay() time.Duration { if v := a.testFallbackDelay; v != 0 { return v } return 500 * time.Millisecond } var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { // If we don't have a dial plan, just fall back to dialing the single // host we know about. useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN") if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 { return a.dialHost(ctx, netip.Addr{}) } candidates := a.DialPlan.Candidates // Otherwise, we try dialing per the plan. Store the highest priority // in the list, so that if we get a connection to one of those // candidates we can return quickly. var highestPriority int = math.MinInt for _, c := range candidates { if c.Priority > highestPriority { highestPriority = c.Priority } } // This context allows us to cancel in-flight connections if we get a // highest-priority connection before we're all done. ctx, cancel := context.WithCancel(ctx) defer cancel() // Now, for each candidate, kick off a dial in parallel. 
type dialResult struct { conn *ClientConn err error addr netip.Addr priority int } resultsCh := make(chan dialResult, len(candidates)) var pending atomic.Int32 pending.Store(int32(len(candidates))) for _, c := range candidates { go func(ctx context.Context, c tailcfg.ControlIPCandidate) { var ( conn *ClientConn err error ) // Always send results back to our channel. defer func() { resultsCh <- dialResult{conn, err, c.IP, c.Priority} if pending.Add(-1) == 0 { close(resultsCh) } }() // If non-zero, wait the configured start timeout // before we do anything. if c.DialStartDelaySec > 0 { a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP) tmr := time.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second))) defer tmr.Stop() select { case <-ctx.Done(): err = ctx.Err() return case <-tmr.C: } } // Now, create a sub-context with the given timeout and // try dialing the provided host. ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second))) defer cancel() // This will dial, and the defer above sends it back to our parent. a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP) conn, err = a.dialHost(ctx, c.IP) }(ctx, c) } var results []dialResult for res := range resultsCh { // If we get a response that has the highest priority, we don't // need to wait for any of the other connections to finish; we // can just return this connection. // // TODO(andrew): we could make this better by keeping track of // the highest remaining priority dynamically, instead of just // checking for the highest total if res.priority == highestPriority && res.conn != nil { a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr) // Drain the channel and any existing connections in // the background. 
go func() { for _, res := range results { if res.conn != nil { res.conn.Close() } } for res := range resultsCh { if res.conn != nil { res.conn.Close() } } if a.drainFinished != nil { close(a.drainFinished) } }() return res.conn, nil } // This isn't a highest-priority result, so just store it until // we're done. results = append(results, res) } // After we finish this function, close any remaining open connections. defer func() { for _, result := range results { // Note: below, we nil out the returned connection (if // any) in the slice so we don't close it. if result.conn != nil { result.conn.Close() } } // We don't drain asynchronously after this point, so notify our // channel when we return. if a.drainFinished != nil { close(a.drainFinished) } }() // Sort by priority, then take the first non-error response. sort.Slice(results, func(i, j int) bool { // NOTE: intentionally inverted so that the highest priority // item comes first return results[i].priority > results[j].priority }) var ( conn *ClientConn errs []error ) for i, result := range results { if result.err != nil { errs = append(errs, result.err) continue } a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr) conn = result.conn results[i].conn = nil // so we don't close it in the defer return conn, nil } merr := multierr.New(errs...) // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error()) return a.dialHost(ctx, netip.Addr{}) } // dialHost connects to the configured Dialer.Hostname and upgrades the // connection into a controlbase.Conn. If addr is valid, then no DNS is used // and the connection will be made to the provided address. func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*ClientConn, error) { // Create one shared context used by both port 80 and port 443 dials. 
// If port 80 is still in flight when 443 returns, this deferred cancel // will stop the port 80 dial. ctx, cancel := context.WithCancel(ctx) defer cancel() // u80 and u443 are the URLs we'll try to hit over HTTP or HTTPS, // respectively, in order to do the HTTP upgrade to a net.Conn over which // we'll speak Noise. u80 := &url.URL{ Scheme: "http", Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPPort, "80")), Path: serverUpgradePath, } u443 := &url.URL{ Scheme: "https", Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")), Path: serverUpgradePath, } type tryURLRes struct { u *url.URL // input (the URL conn+err are for/from) conn *ClientConn // result (mutually exclusive with err) err error } ch := make(chan tryURLRes) // must be unbuffered try := func(u *url.URL) { cbConn, err := a.dialURL(ctx, u, addr) select { case ch <- tryURLRes{u, cbConn, err}: case <-ctx.Done(): if cbConn != nil { cbConn.Close() } } } // Start the plaintext HTTP attempt first. go try(u80) // In case outbound port 80 blocked or MITM'ed poorly, start a backup timer // to dial port 443 if port 80 doesn't either succeed or fail quickly. try443Timer := time.AfterFunc(a.httpsFallbackDelay(), func() { try(u443) }) defer try443Timer.Stop() var err80, err443 error for { select { case <-ctx.Done(): return nil, fmt.Errorf("connection attempts aborted by context: %w", ctx.Err()) case res := <-ch: if res.err == nil { return res.conn, nil } switch res.u { case u80: // Connecting over plain HTTP failed; assume it's an HTTP proxy // being difficult and see if we can get through over HTTPS. err80 = res.err // Stop the fallback timer and run it immediately. We don't use // Timer.Reset(0) here because on AfterFuncs, that can run it // again. 
if try443Timer.Stop() { go try(u443) } // else we lost the race and it started already which is what we want case u443: err443 = res.err default: panic("invalid") } if err80 != nil && err443 != nil { return nil, fmt.Errorf("all connection attempts failed (HTTP: %v, HTTPS: %v)", err80, err443) } } } } // dialURL attempts to connect to the given URL. func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*ClientConn, error) { init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion) if err != nil { return nil, err } netConn, err := a.tryURLUpgrade(ctx, u, addr, init) if err != nil { return nil, err } cbConn, err := cont(ctx, netConn) if err != nil { netConn.Close() return nil, err } return &ClientConn{ Conn: cbConn, }, nil } // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr // is valid, then no DNS is used and the connection will be made to the // provided address. // // Only the provided ctx is used, not a.ctx. func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) { var dns *dnscache.Resolver // If we were provided an address to dial, then create a resolver that just // returns that value; otherwise, fall back to DNS. if addr.IsValid() { dns = &dnscache.Resolver{ SingleHostStaticResult: []netip.Addr{addr}, SingleHost: u.Hostname(), } } else { dns = &dnscache.Resolver{ Forward: dnscache.Get().Forward, LookupIPFallback: dnsfallback.Lookup, UseLastGood: true, } } var dialer dnscache.DialContextFunc if a.Dialer != nil { dialer = a.Dialer } else { dialer = stdDialer.DialContext } tr := http.DefaultTransport.(*http.Transport).Clone() defer tr.CloseIdleConnections() tr.Proxy = a.getProxyFunc() tshttpproxy.SetTransportGetProxyConnectHeader(tr) tr.DialContext = dnscache.Dialer(dialer, dns) // Disable HTTP2, since h2 can't do protocol switching. 
tr.TLSClientConfig.NextProtos = []string{} tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} tr.TLSClientConfig = tlsdial.Config(a.Hostname, tr.TLSClientConfig) if a.insecureTLS { tr.TLSClientConfig.InsecureSkipVerify = true tr.TLSClientConfig.VerifyConnection = nil } tr.DialTLSContext = dnscache.TLSDialer(dialer, dns, tr.TLSClientConfig) tr.DisableCompression = true // (mis)use httptrace to extract the underlying net.Conn from the // transport. We make exactly 1 request using this transport, so // there will be exactly 1 GotConn call. Additionally, the // transport handles 101 Switching Protocols correctly, such that // the Conn will not be reused or kept alive by the transport once // the response has been handed back from RoundTrip. // // In theory, the machinery of net/http should make it such that // the trace callback happens-before we get the response, but // there's no promise of that. So, to make sure, we use a buffered // channel as a synchronization step to avoid data races. // // Note that even though we're able to extract a net.Conn via this // mechanism, we must still keep using the eventual resp.Body to // read from, because it includes a buffer we can't get rid of. If // the server never sends any data after sending the HTTP // response, we could get away with it, but violating this // assumption leads to very mysterious transport errors (lockups, // unexpected EOFs...), and we're bound to forget someday and // introduce a protocol optimization at a higher level that starts // eagerly transmitting from the server. 
connCh := make(chan net.Conn, 1) trace := httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { connCh <- info.Conn }, } ctx = httptrace.WithClientTrace(ctx, &trace) req := &http.Request{ Method: "POST", URL: u, Header: http.Header{ "Upgrade": []string{upgradeHeaderValue}, "Connection": []string{"upgrade"}, handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, }, } req = req.WithContext(ctx) resp, err := tr.RoundTrip(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("unexpected HTTP response: %s", resp.Status) } // From here on, the underlying net.Conn is ours to use, but there // is still a read buffer attached to it within resp.Body. So, we // must direct I/O through resp.Body, but we can still use the // underlying net.Conn for stuff like deadlines. var switchedConn net.Conn select { case switchedConn = <-connCh: default: } if switchedConn == nil { resp.Body.Close() return nil, fmt.Errorf("httptrace didn't provide a connection") } if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue { resp.Body.Close() return nil, fmt.Errorf("server switched to unexpected protocol %q", next) } rwc, ok := resp.Body.(io.ReadWriteCloser) if !ok { resp.Body.Close() return nil, errors.New("http Transport did not provide a writable body") } return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), nil }