mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-11 21:27:31 +00:00
tailcfg, control/controlhttp, control/controlclient: add ControlDialPlan field (#5648)
* tailcfg, control/controlhttp, control/controlclient: add ControlDialPlan field

This field allows the control server to provide explicit information about how to connect to it; this is useful if the client's link status can change after the initial connection, or if the DNS settings pushed by the control server break future connections.

Change-Id: I720afe6289ec27d40a41b3dcb310ec45bd7e5f3e
Signed-off-by: Andrew Dunham <andrew@tailscale.com>
This commit is contained in:
@@ -76,6 +76,8 @@ type Direct struct {
|
||||
popBrowser func(url string) // or nil
|
||||
c2nHandler http.Handler // or nil
|
||||
|
||||
dialPlan ControlDialPlanner // can be nil
|
||||
|
||||
mu sync.Mutex // mutex guards the following fields
|
||||
serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key
|
||||
serverNoiseKey key.MachinePublic
|
||||
@@ -133,6 +135,34 @@ type Options struct {
|
||||
// MapResponse.PingRequest queries from the control plane.
|
||||
// If nil, PingRequest queries are not answered.
|
||||
Pinger Pinger
|
||||
|
||||
// DialPlan contains and stores a previous dial plan that we received
|
||||
// from the control server; if nil, we fall back to using DNS.
|
||||
//
|
||||
// If we receive a new DialPlan from the server, this value will be
|
||||
// updated.
|
||||
DialPlan ControlDialPlanner
|
||||
}
|
||||
|
||||
// ControlDialPlanner is the interface optionally supplied when creating a
|
||||
// control client to control exactly how TCP connections to the control plane
|
||||
// are dialed.
|
||||
//
|
||||
// It is usually implemented by an atomic.Pointer.
|
||||
type ControlDialPlanner interface {
|
||||
// Load returns the current plan for how to connect to control.
|
||||
//
|
||||
// The returned plan can be nil. If so, connections should be made by
|
||||
// resolving the control URL using DNS.
|
||||
Load() *tailcfg.ControlDialPlan
|
||||
|
||||
// Store updates the dial plan with new directions from the control
|
||||
// server.
|
||||
//
|
||||
// The dial plan can span multiple connections to the control server.
|
||||
// That is, a dial plan received when connected over Wi-Fi is still
|
||||
// valid for a subsequent connection over LTE after a network switch.
|
||||
Store(*tailcfg.ControlDialPlan)
|
||||
}
|
||||
|
||||
// Pinger is the LocalBackend.Ping method.
|
||||
@@ -216,6 +246,7 @@ func NewDirect(opts Options) (*Direct, error) {
|
||||
popBrowser: opts.PopBrowserURL,
|
||||
c2nHandler: opts.C2NHandler,
|
||||
dialer: opts.Dialer,
|
||||
dialPlan: opts.DialPlan,
|
||||
}
|
||||
if opts.Hostinfo == nil {
|
||||
c.SetHostinfo(hostinfo.New())
|
||||
@@ -915,6 +946,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
|
||||
} else {
|
||||
vlogf("netmap: got new map")
|
||||
}
|
||||
if resp.ControlDialPlan != nil {
|
||||
if c.dialPlan != nil {
|
||||
c.logf("netmap: got new dial plan from control")
|
||||
c.dialPlan.Store(resp.ControlDialPlan)
|
||||
} else {
|
||||
c.logf("netmap: [unexpected] new dial plan; nowhere to store it")
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case timeoutReset <- struct{}{}:
|
||||
@@ -1365,12 +1404,17 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) {
|
||||
if nc != nil {
|
||||
return nc, nil
|
||||
}
|
||||
var dp func() *tailcfg.ControlDialPlan
|
||||
if c.dialPlan != nil {
|
||||
dp = c.dialPlan.Load
|
||||
}
|
||||
nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*noiseClient, error) {
|
||||
k, err := c.getMachinePrivKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer)
|
||||
c.logf("creating new noise client")
|
||||
nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -53,6 +53,11 @@ type noiseClient struct {
|
||||
httpPort string // the default port to call
|
||||
httpsPort string // the fallback Noise-over-https port
|
||||
|
||||
// dialPlan optionally returns a ControlDialPlan previously received
|
||||
// from the control server; either the function or the return value can
|
||||
// be nil.
|
||||
dialPlan func() *tailcfg.ControlDialPlan
|
||||
|
||||
// mu only protects the following variables.
|
||||
mu sync.Mutex
|
||||
nextID int
|
||||
@@ -61,7 +66,9 @@ type noiseClient struct {
|
||||
|
||||
// newNoiseClient returns a new noiseClient for the provided server and machine key.
|
||||
// serverURL is of the form https://<host>:<port> (no trailing slash).
|
||||
func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer) (*noiseClient, error) {
|
||||
//
|
||||
// dialPlan may be nil
|
||||
func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) {
|
||||
u, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -89,6 +96,7 @@ func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, s
|
||||
httpPort: httpPort,
|
||||
httpsPort: httpsPort,
|
||||
dialer: dialer,
|
||||
dialPlan: dialPlan,
|
||||
}
|
||||
|
||||
// Create the HTTP/2 Transport using a net/http.Transport
|
||||
@@ -155,16 +163,51 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
|
||||
nc.nextID++
|
||||
nc.mu.Unlock()
|
||||
|
||||
// Timeout is a little arbitrary, but plenty long enough for even the
|
||||
// highest latency links.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if tailcfg.CurrentCapabilityVersion > math.MaxUint16 {
|
||||
// Panic, because a test should have started failing several
|
||||
// thousand version numbers before getting to this point.
|
||||
panic("capability version is too high to fit in the wire protocol")
|
||||
}
|
||||
|
||||
var dialPlan *tailcfg.ControlDialPlan
|
||||
if nc.dialPlan != nil {
|
||||
dialPlan = nc.dialPlan()
|
||||
}
|
||||
|
||||
// If we have a dial plan, then set our timeout as slightly longer than
|
||||
// the maximum amount of time contained therein; we assume that
|
||||
// explicit instructions on timeouts are more useful than a single
|
||||
// hard-coded timeout.
|
||||
//
|
||||
// The default value of 5 is chosen so that, when there's no dial plan,
|
||||
// we retain the previous behaviour of 10 seconds end-to-end timeout.
|
||||
timeoutSec := 5.0
|
||||
if dialPlan != nil {
|
||||
for _, c := range dialPlan.Candidates {
|
||||
if v := c.DialStartDelaySec + c.DialTimeoutSec; v > timeoutSec {
|
||||
timeoutSec = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// After we establish a connection, we need some time to actually
|
||||
// upgrade it into a Noise connection. With a ballpark worst-case RTT
|
||||
// of 1000ms, give ourselves an extra 5 seconds to complete the
|
||||
// handshake.
|
||||
timeoutSec += 5
|
||||
|
||||
// Be extremely defensive and ensure that the timeout is in the range
|
||||
// [5, 60] seconds (e.g. if we accidentally get a negative number).
|
||||
if timeoutSec > 60 {
|
||||
timeoutSec = 60
|
||||
} else if timeoutSec < 5 {
|
||||
timeoutSec = 5
|
||||
}
|
||||
|
||||
timeout := time.Duration(timeoutSec * float64(time.Second))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := (&controlhttp.Dialer{
|
||||
Hostname: nc.host,
|
||||
HTTPPort: nc.httpPort,
|
||||
@@ -173,6 +216,7 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
|
||||
ControlKey: nc.serverPubKey,
|
||||
ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion),
|
||||
Dialer: nc.dialer.SystemDial,
|
||||
DialPlan: dialPlan,
|
||||
}).Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -28,18 +28,25 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sort"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"tailscale.com/control/controlbase"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/dnsfallback"
|
||||
"tailscale.com/net/netutil"
|
||||
"tailscale.com/net/tlsdial"
|
||||
"tailscale.com/net/tshttpproxy"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
var stdDialer net.Dialer
|
||||
@@ -82,7 +89,170 @@ func (a *Dialer) httpsFallbackDelay() time.Duration {
|
||||
return 500 * time.Millisecond
|
||||
}
|
||||
|
||||
var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use
|
||||
|
||||
func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
// If we don't have a dial plan, just fall back to dialing the single
|
||||
// host we know about.
|
||||
useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN")
|
||||
if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 {
|
||||
return a.dialHost(ctx, netip.Addr{})
|
||||
}
|
||||
candidates := a.DialPlan.Candidates
|
||||
|
||||
// Otherwise, we try dialing per the plan. Store the highest priority
|
||||
// in the list, so that if we get a connection to one of those
|
||||
// candidates we can return quickly.
|
||||
var highestPriority int = math.MinInt
|
||||
for _, c := range candidates {
|
||||
if c.Priority > highestPriority {
|
||||
highestPriority = c.Priority
|
||||
}
|
||||
}
|
||||
|
||||
// This context allows us to cancel in-flight connections if we get a
|
||||
// highest-priority connection before we're all done.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Now, for each candidate, kick off a dial in parallel.
|
||||
type dialResult struct {
|
||||
conn *controlbase.Conn
|
||||
err error
|
||||
addr netip.Addr
|
||||
priority int
|
||||
}
|
||||
resultsCh := make(chan dialResult, len(candidates))
|
||||
|
||||
var pending atomic.Int32
|
||||
pending.Store(int32(len(candidates)))
|
||||
for _, c := range candidates {
|
||||
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
|
||||
var (
|
||||
conn *controlbase.Conn
|
||||
err error
|
||||
)
|
||||
|
||||
// Always send results back to our channel.
|
||||
defer func() {
|
||||
resultsCh <- dialResult{conn, err, c.IP, c.Priority}
|
||||
if pending.Add(-1) == 0 {
|
||||
close(resultsCh)
|
||||
}
|
||||
}()
|
||||
|
||||
// If non-zero, wait the configured start timeout
|
||||
// before we do anything.
|
||||
if c.DialStartDelaySec > 0 {
|
||||
a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
|
||||
tmr := time.NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
|
||||
defer tmr.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
return
|
||||
case <-tmr.C:
|
||||
}
|
||||
}
|
||||
|
||||
// Now, create a sub-context with the given timeout and
|
||||
// try dialing the provided host.
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
|
||||
defer cancel()
|
||||
|
||||
// This will dial, and the defer above sends it back to our parent.
|
||||
a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
|
||||
conn, err = a.dialHost(ctx, c.IP)
|
||||
}(ctx, c)
|
||||
}
|
||||
|
||||
var results []dialResult
|
||||
for res := range resultsCh {
|
||||
// If we get a response that has the highest priority, we don't
|
||||
// need to wait for any of the other connections to finish; we
|
||||
// can just return this connection.
|
||||
//
|
||||
// TODO(andrew): we could make this better by keeping track of
|
||||
// the highest remaining priority dynamically, instead of just
|
||||
// checking for the highest total
|
||||
if res.priority == highestPriority && res.conn != nil {
|
||||
a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr)
|
||||
|
||||
// Drain the channel and any existing connections in
|
||||
// the background.
|
||||
go func() {
|
||||
for _, res := range results {
|
||||
if res.conn != nil {
|
||||
res.conn.Close()
|
||||
}
|
||||
}
|
||||
for res := range resultsCh {
|
||||
if res.conn != nil {
|
||||
res.conn.Close()
|
||||
}
|
||||
}
|
||||
if a.drainFinished != nil {
|
||||
close(a.drainFinished)
|
||||
}
|
||||
}()
|
||||
return res.conn, nil
|
||||
}
|
||||
|
||||
// This isn't a highest-priority result, so just store it until
|
||||
// we're done.
|
||||
results = append(results, res)
|
||||
}
|
||||
|
||||
// After we finish this function, close any remaining open connections.
|
||||
defer func() {
|
||||
for _, result := range results {
|
||||
// Note: below, we nil out the returned connection (if
|
||||
// any) in the slice so we don't close it.
|
||||
if result.conn != nil {
|
||||
result.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// We don't drain asynchronously after this point, so notify our
|
||||
// channel when we return.
|
||||
if a.drainFinished != nil {
|
||||
close(a.drainFinished)
|
||||
}
|
||||
}()
|
||||
|
||||
// Sort by priority, then take the first non-error response.
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
// NOTE: intentionally inverted so that the highest priority
|
||||
// item comes first
|
||||
return results[i].priority > results[j].priority
|
||||
})
|
||||
|
||||
var (
|
||||
conn *controlbase.Conn
|
||||
errs []error
|
||||
)
|
||||
for i, result := range results {
|
||||
if result.err != nil {
|
||||
errs = append(errs, result.err)
|
||||
continue
|
||||
}
|
||||
|
||||
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr)
|
||||
conn = result.conn
|
||||
results[i].conn = nil // so we don't close it in the defer
|
||||
return conn, nil
|
||||
}
|
||||
merr := multierr.New(errs...)
|
||||
|
||||
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
|
||||
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
|
||||
return a.dialHost(ctx, netip.Addr{})
|
||||
}
|
||||
|
||||
// dialHost connects to the configured Dialer.Hostname and upgrades the
|
||||
// connection into a controlbase.Conn. If addr is valid, then no DNS is used
|
||||
// and the connection will be made to the provided address.
|
||||
func (a *Dialer) dialHost(ctx context.Context, addr netip.Addr) (*controlbase.Conn, error) {
|
||||
// Create one shared context used by both port 80 and port 443 dials.
|
||||
// If port 80 is still in flight when 443 returns, this deferred cancel
|
||||
// will stop the port 80 dial.
|
||||
@@ -110,7 +280,7 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
}
|
||||
ch := make(chan tryURLRes) // must be unbuffered
|
||||
try := func(u *url.URL) {
|
||||
cbConn, err := a.dialURL(ctx, u)
|
||||
cbConn, err := a.dialURL(ctx, u, addr)
|
||||
select {
|
||||
case ch <- tryURLRes{u, cbConn, err}:
|
||||
case <-ctx.Done():
|
||||
@@ -161,12 +331,12 @@ func (a *Dialer) dial(ctx context.Context) (*controlbase.Conn, error) {
|
||||
}
|
||||
|
||||
// dialURL attempts to connect to the given URL.
|
||||
func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, error) {
|
||||
func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*controlbase.Conn, error) {
|
||||
init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
netConn, err := a.tryURLUpgrade(ctx, u, init)
|
||||
netConn, err := a.tryURLUpgrade(ctx, u, addr, init)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -178,14 +348,27 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL) (*controlbase.Conn, er
|
||||
return cbConn, nil
|
||||
}
|
||||
|
||||
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn.
|
||||
// tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr
|
||||
// is valid, then no DNS is used and the connection will be made to the
|
||||
// provided address.
|
||||
//
|
||||
// Only the provided ctx is used, not a.ctx.
|
||||
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, init []byte) (net.Conn, error) {
|
||||
dns := &dnscache.Resolver{
|
||||
Forward: dnscache.Get().Forward,
|
||||
LookupIPFallback: dnsfallback.Lookup,
|
||||
UseLastGood: true,
|
||||
func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, init []byte) (net.Conn, error) {
|
||||
var dns *dnscache.Resolver
|
||||
|
||||
// If we were provided an address to dial, then create a resolver that just
|
||||
// returns that value; otherwise, fall back to DNS.
|
||||
if addr.IsValid() {
|
||||
dns = &dnscache.Resolver{
|
||||
SingleHostStaticResult: []netip.Addr{addr},
|
||||
SingleHost: u.Hostname(),
|
||||
}
|
||||
} else {
|
||||
dns = &dnscache.Resolver{
|
||||
Forward: dnscache.Get().Forward,
|
||||
LookupIPFallback: dnsfallback.Lookup,
|
||||
UseLastGood: true,
|
||||
}
|
||||
}
|
||||
|
||||
var dialer dnscache.DialContextFunc
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
@@ -70,9 +71,15 @@ type Dialer struct {
|
||||
// dropped.
|
||||
Logf logger.Logf
|
||||
|
||||
// DialPlan, if set, contains instructions from the control server on
|
||||
// how to connect to it. If present, we will try the methods in this
|
||||
// plan before falling back to DNS.
|
||||
DialPlan *tailcfg.ControlDialPlan
|
||||
|
||||
proxyFunc func(*http.Request) (*url.URL, error) // or nil
|
||||
|
||||
// For tests only
|
||||
drainFinished chan struct{}
|
||||
insecureTLS bool
|
||||
testFallbackDelay time.Duration
|
||||
}
|
||||
|
@@ -13,16 +13,21 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/control/controlbase"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/socks5"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
type httpTestParam struct {
|
||||
@@ -444,3 +449,263 @@ func brokenMITMHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.(http.Flusher).Flush()
|
||||
<-r.Context().Done()
|
||||
}
|
||||
|
||||
func TestDialPlan(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("only works on Linux due to multiple localhost addresses")
|
||||
}
|
||||
|
||||
client, server := key.NewMachine(), key.NewMachine()
|
||||
|
||||
const (
|
||||
testProtocolVersion = 1
|
||||
|
||||
// We need consistent ports for each address; these are chosen
|
||||
// randomly and we hope that they won't conflict during this test.
|
||||
httpPort = "40080"
|
||||
httpsPort = "40443"
|
||||
)
|
||||
|
||||
makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := AcceptHTTP(context.Background(), w, r, server)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
} else {
|
||||
defer conn.Close()
|
||||
}
|
||||
w.Header().Set("X-Handler-Name", name)
|
||||
<-done
|
||||
})
|
||||
if wrap != nil {
|
||||
handler = wrap(handler)
|
||||
}
|
||||
|
||||
httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
|
||||
if err != nil {
|
||||
t.Fatalf("HTTP listen: %v", err)
|
||||
}
|
||||
httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
|
||||
if err != nil {
|
||||
t.Fatalf("HTTPS listen: %v", err)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{Handler: handler}
|
||||
go httpServer.Serve(httpLn)
|
||||
t.Cleanup(func() {
|
||||
httpServer.Close()
|
||||
})
|
||||
|
||||
httpsServer := &http.Server{
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig(t),
|
||||
ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")),
|
||||
}
|
||||
go httpsServer.ServeTLS(httpsLn, "", "")
|
||||
t.Cleanup(func() {
|
||||
httpsServer.Close()
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
fallbackAddr := netip.MustParseAddr("127.0.0.1")
|
||||
goodAddr := netip.MustParseAddr("127.0.0.2")
|
||||
otherAddr := netip.MustParseAddr("127.0.0.3")
|
||||
other2Addr := netip.MustParseAddr("127.0.0.4")
|
||||
brokenAddr := netip.MustParseAddr("127.0.0.10")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
plan *tailcfg.ControlDialPlan
|
||||
wrap func(http.Handler) http.Handler
|
||||
want netip.Addr
|
||||
|
||||
allowFallback bool
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
||||
}},
|
||||
want: goodAddr,
|
||||
},
|
||||
{
|
||||
name: "broken-then-good",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
// Dials the broken one, which fails, and then
|
||||
// eventually dials the good one and succeeds
|
||||
{IP: brokenAddr, Priority: 2, DialTimeoutSec: 10},
|
||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1},
|
||||
}},
|
||||
want: goodAddr,
|
||||
},
|
||||
{
|
||||
name: "multiple-priority-fast-path",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
// Dials some good IPs and our bad one (which
|
||||
// hangs forever), which then hits the fast
|
||||
// path where we bail without waiting.
|
||||
{IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
|
||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
||||
{IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
|
||||
{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
|
||||
}},
|
||||
want: otherAddr,
|
||||
},
|
||||
{
|
||||
name: "multiple-priority-slow-path",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
// Our broken address is the highest priority,
|
||||
// so we don't hit our fast path.
|
||||
{IP: brokenAddr, Priority: 10, DialTimeoutSec: 10},
|
||||
{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
|
||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
||||
}},
|
||||
want: otherAddr,
|
||||
},
|
||||
{
|
||||
name: "fallback",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
{IP: brokenAddr, Priority: 1, DialTimeoutSec: 1},
|
||||
}},
|
||||
want: fallbackAddr,
|
||||
allowFallback: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
makeHandler(t, "fallback", fallbackAddr, nil)
|
||||
makeHandler(t, "good", goodAddr, nil)
|
||||
makeHandler(t, "other", otherAddr, nil)
|
||||
makeHandler(t, "other2", other2Addr, nil)
|
||||
makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(brokenMITMHandler)
|
||||
})
|
||||
|
||||
dialer := closeTrackDialer{
|
||||
t: t,
|
||||
inner: new(tsdial.Dialer).SystemDial,
|
||||
conns: make(map[*closeTrackConn]bool),
|
||||
}
|
||||
defer dialer.Done()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// By default, we intentionally point to something that
|
||||
// we know won't connect, since we want a fallback to
|
||||
// DNS to be an error.
|
||||
host := "example.com"
|
||||
if tt.allowFallback {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
drained := make(chan struct{})
|
||||
a := &Dialer{
|
||||
Hostname: host,
|
||||
HTTPPort: httpPort,
|
||||
HTTPSPort: httpsPort,
|
||||
MachineKey: client,
|
||||
ControlKey: server.Public(),
|
||||
ProtocolVersion: testProtocolVersion,
|
||||
Dialer: dialer.Dial,
|
||||
Logf: t.Logf,
|
||||
DialPlan: tt.plan,
|
||||
proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
|
||||
drainFinished: drained,
|
||||
insecureTLS: true,
|
||||
testFallbackDelay: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
conn, err := a.dial(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("dialing controlhttp: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
raddr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
|
||||
got, ok := netip.AddrFromSlice(raddr.IP)
|
||||
if !ok {
|
||||
t.Errorf("invalid remote IP: %v", raddr.IP)
|
||||
} else if got != tt.want {
|
||||
t.Errorf("got connection from %q; want %q", got, tt.want)
|
||||
} else {
|
||||
t.Logf("successfully connected to %q", raddr.String())
|
||||
}
|
||||
|
||||
// Wait until our dialer drains so we can verify that
|
||||
// all connections are closed.
|
||||
<-drained
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type closeTrackDialer struct {
|
||||
t testing.TB
|
||||
inner dnscache.DialContextFunc
|
||||
mu sync.Mutex
|
||||
conns map[*closeTrackConn]bool
|
||||
}
|
||||
|
||||
func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
c, err := d.inner(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ct := &closeTrackConn{Conn: c, d: d}
|
||||
|
||||
d.mu.Lock()
|
||||
d.conns[ct] = true
|
||||
d.mu.Unlock()
|
||||
return ct, nil
|
||||
}
|
||||
|
||||
func (d *closeTrackDialer) Done() {
|
||||
// Unfortunately, tsdial.Dialer.SystemDial closes connections
|
||||
// asynchronously in a goroutine, so we can't assume that everything is
|
||||
// closed by the time we get here.
|
||||
//
|
||||
// Sleep/wait a few times on the assumption that things will close
|
||||
// "eventually".
|
||||
const iters = 100
|
||||
for i := 0; i < iters; i++ {
|
||||
d.mu.Lock()
|
||||
if len(d.conns) == 0 {
|
||||
d.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Only error on last iteration
|
||||
if i != iters-1 {
|
||||
d.mu.Unlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
for conn := range d.conns {
|
||||
d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String())
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *closeTrackDialer) noteClose(c *closeTrackConn) {
|
||||
d.mu.Lock()
|
||||
delete(d.conns, c) // safe if already deleted
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
type closeTrackConn struct {
|
||||
net.Conn
|
||||
d *closeTrackDialer
|
||||
}
|
||||
|
||||
func (c *closeTrackConn) Close() error {
|
||||
c.d.noteClose(c)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
Reference in New Issue
Block a user