mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-22 11:01:54 +00:00
tailcfg, control/controlhttp, control/controlclient: add ControlDialPlan field (#5648)
* tailcfg, control/controlhttp, control/controlclient: add ControlDialPlan field This field allows the control server to provide explicit information about how to connect to it; useful if the client's link status can change after the initial connection, or if the DNS settings pushed by the control server break future connections. Change-Id: I720afe6289ec27d40a41b3dcb310ec45bd7e5f3e Signed-off-by: Andrew Dunham <andrew@tailscale.com>
This commit is contained in:
@@ -13,16 +13,21 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/control/controlbase"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/socks5"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
type httpTestParam struct {
|
||||
@@ -444,3 +449,263 @@ func brokenMITMHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.(http.Flusher).Flush()
|
||||
<-r.Context().Done()
|
||||
}
|
||||
|
||||
func TestDialPlan(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("only works on Linux due to multiple localhost addresses")
|
||||
}
|
||||
|
||||
client, server := key.NewMachine(), key.NewMachine()
|
||||
|
||||
const (
|
||||
testProtocolVersion = 1
|
||||
|
||||
// We need consistent ports for each address; these are chosen
|
||||
// randomly and we hope that they won't conflict during this test.
|
||||
httpPort = "40080"
|
||||
httpsPort = "40443"
|
||||
)
|
||||
|
||||
makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := AcceptHTTP(context.Background(), w, r, server)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
} else {
|
||||
defer conn.Close()
|
||||
}
|
||||
w.Header().Set("X-Handler-Name", name)
|
||||
<-done
|
||||
})
|
||||
if wrap != nil {
|
||||
handler = wrap(handler)
|
||||
}
|
||||
|
||||
httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
|
||||
if err != nil {
|
||||
t.Fatalf("HTTP listen: %v", err)
|
||||
}
|
||||
httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
|
||||
if err != nil {
|
||||
t.Fatalf("HTTPS listen: %v", err)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{Handler: handler}
|
||||
go httpServer.Serve(httpLn)
|
||||
t.Cleanup(func() {
|
||||
httpServer.Close()
|
||||
})
|
||||
|
||||
httpsServer := &http.Server{
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig(t),
|
||||
ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")),
|
||||
}
|
||||
go httpsServer.ServeTLS(httpsLn, "", "")
|
||||
t.Cleanup(func() {
|
||||
httpsServer.Close()
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
fallbackAddr := netip.MustParseAddr("127.0.0.1")
|
||||
goodAddr := netip.MustParseAddr("127.0.0.2")
|
||||
otherAddr := netip.MustParseAddr("127.0.0.3")
|
||||
other2Addr := netip.MustParseAddr("127.0.0.4")
|
||||
brokenAddr := netip.MustParseAddr("127.0.0.10")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
plan *tailcfg.ControlDialPlan
|
||||
wrap func(http.Handler) http.Handler
|
||||
want netip.Addr
|
||||
|
||||
allowFallback bool
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
||||
}},
|
||||
want: goodAddr,
|
||||
},
|
||||
{
|
||||
name: "broken-then-good",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
// Dials the broken one, which fails, and then
|
||||
// eventually dials the good one and succeeds
|
||||
{IP: brokenAddr, Priority: 2, DialTimeoutSec: 10},
|
||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1},
|
||||
}},
|
||||
want: goodAddr,
|
||||
},
|
||||
{
|
||||
name: "multiple-priority-fast-path",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
// Dials some good IPs and our bad one (which
|
||||
// hangs forever), which then hits the fast
|
||||
// path where we bail without waiting.
|
||||
{IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
|
||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
||||
{IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
|
||||
{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
|
||||
}},
|
||||
want: otherAddr,
|
||||
},
|
||||
{
|
||||
name: "multiple-priority-slow-path",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
// Our broken address is the highest priority,
|
||||
// so we don't hit our fast path.
|
||||
{IP: brokenAddr, Priority: 10, DialTimeoutSec: 10},
|
||||
{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
|
||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
||||
}},
|
||||
want: otherAddr,
|
||||
},
|
||||
{
|
||||
name: "fallback",
|
||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||
{IP: brokenAddr, Priority: 1, DialTimeoutSec: 1},
|
||||
}},
|
||||
want: fallbackAddr,
|
||||
allowFallback: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
makeHandler(t, "fallback", fallbackAddr, nil)
|
||||
makeHandler(t, "good", goodAddr, nil)
|
||||
makeHandler(t, "other", otherAddr, nil)
|
||||
makeHandler(t, "other2", other2Addr, nil)
|
||||
makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(brokenMITMHandler)
|
||||
})
|
||||
|
||||
dialer := closeTrackDialer{
|
||||
t: t,
|
||||
inner: new(tsdial.Dialer).SystemDial,
|
||||
conns: make(map[*closeTrackConn]bool),
|
||||
}
|
||||
defer dialer.Done()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// By default, we intentionally point to something that
|
||||
// we know won't connect, since we want a fallback to
|
||||
// DNS to be an error.
|
||||
host := "example.com"
|
||||
if tt.allowFallback {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
drained := make(chan struct{})
|
||||
a := &Dialer{
|
||||
Hostname: host,
|
||||
HTTPPort: httpPort,
|
||||
HTTPSPort: httpsPort,
|
||||
MachineKey: client,
|
||||
ControlKey: server.Public(),
|
||||
ProtocolVersion: testProtocolVersion,
|
||||
Dialer: dialer.Dial,
|
||||
Logf: t.Logf,
|
||||
DialPlan: tt.plan,
|
||||
proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
|
||||
drainFinished: drained,
|
||||
insecureTLS: true,
|
||||
testFallbackDelay: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
conn, err := a.dial(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("dialing controlhttp: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
raddr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
|
||||
got, ok := netip.AddrFromSlice(raddr.IP)
|
||||
if !ok {
|
||||
t.Errorf("invalid remote IP: %v", raddr.IP)
|
||||
} else if got != tt.want {
|
||||
t.Errorf("got connection from %q; want %q", got, tt.want)
|
||||
} else {
|
||||
t.Logf("successfully connected to %q", raddr.String())
|
||||
}
|
||||
|
||||
// Wait until our dialer drains so we can verify that
|
||||
// all connections are closed.
|
||||
<-drained
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type closeTrackDialer struct {
|
||||
t testing.TB
|
||||
inner dnscache.DialContextFunc
|
||||
mu sync.Mutex
|
||||
conns map[*closeTrackConn]bool
|
||||
}
|
||||
|
||||
func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
c, err := d.inner(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ct := &closeTrackConn{Conn: c, d: d}
|
||||
|
||||
d.mu.Lock()
|
||||
d.conns[ct] = true
|
||||
d.mu.Unlock()
|
||||
return ct, nil
|
||||
}
|
||||
|
||||
func (d *closeTrackDialer) Done() {
|
||||
// Unfortunately, tsdial.Dialer.SystemDial closes connections
|
||||
// asynchronously in a goroutine, so we can't assume that everything is
|
||||
// closed by the time we get here.
|
||||
//
|
||||
// Sleep/wait a few times on the assumption that things will close
|
||||
// "eventually".
|
||||
const iters = 100
|
||||
for i := 0; i < iters; i++ {
|
||||
d.mu.Lock()
|
||||
if len(d.conns) == 0 {
|
||||
d.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Only error on last iteration
|
||||
if i != iters-1 {
|
||||
d.mu.Unlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
for conn := range d.conns {
|
||||
d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String())
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *closeTrackDialer) noteClose(c *closeTrackConn) {
|
||||
d.mu.Lock()
|
||||
delete(d.conns, c) // safe if already deleted
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
type closeTrackConn struct {
|
||||
net.Conn
|
||||
d *closeTrackDialer
|
||||
}
|
||||
|
||||
func (c *closeTrackConn) Close() error {
|
||||
c.d.noteClose(c)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
Reference in New Issue
Block a user