diff --git a/client/local/local_test.go b/client/local/local_test.go index 4322e4dde..0e01e74cd 100644 --- a/client/local/local_test.go +++ b/client/local/local_test.go @@ -9,10 +9,10 @@ import ( "context" "net" "net/http" - "net/http/httptest" "testing" "tailscale.com/tstest/deptest" + "tailscale.com/tstest/nettest" "tailscale.com/types/key" ) @@ -36,15 +36,15 @@ func TestGetServeConfigFromJSON(t *testing.T) { } func TestWhoIsPeerNotFound(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nw := nettest.GetNetwork(t) + ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) })) defer ts.Close() lc := &Client{ Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { - var std net.Dialer - return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String()) + return nw.Dial(ctx, network, ts.Listener.Addr().String()) }, } var k key.NodePublic diff --git a/client/web/web_test.go b/client/web/web_test.go index 334b403a6..2a6bc787a 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -28,6 +28,7 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/net/memnet" "tailscale.com/tailcfg" + "tailscale.com/tstest/nettest" "tailscale.com/types/views" "tailscale.com/util/httpm" ) @@ -1508,7 +1509,7 @@ func TestCSRFProtect(t *testing.T) { } }) h := s.withCSRF(mux) - ser := httptest.NewServer(h) + ser := nettest.NewHTTPServer(nettest.GetNetwork(t), h) defer ser.Close() jar, err := cookiejar.New(nil) diff --git a/control/controlclient/noise_test.go b/control/controlclient/noise_test.go index dadf237df..4904016f2 100644 --- a/control/controlclient/noise_test.go +++ b/control/controlclient/noise_test.go @@ -10,7 +10,6 @@ import ( "io" "math" "net/http" - "net/http/httptest" "testing" "time" @@ -20,6 +19,7 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" + "tailscale.com/tstest/nettest" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -178,7 +178,8 @@ func (tt noiseClientTest) run(t *testing.T) { const msg = "Hello, client" h2 := &http2.Server{} - hs := httptest.NewServer(&Upgrader{ + nw := nettest.GetNetwork(t) + hs := nettest.NewHTTPServer(nw, &Upgrader{ h2srv: h2, noiseKeyPriv: serverPrivate, sendEarlyPayload: tt.sendEarlyPayload, @@ -193,6 +194,10 @@ func (tt noiseClientTest) run(t *testing.T) { defer hs.Close() dialer := tsdial.NewDialer(netmon.NewStatic()) + if nettest.PreferMemNetwork() { + dialer.SetSystemDialerForTest(nw.Dial) + } + nc, err := NewNoiseClient(NoiseOpts{ PrivKey: clientPrivate, ServerPubKey: serverPrivate.Public(), diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 8d287fdb0..8fddd63f2 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -71,6 +71,7 @@ type Dialer struct { netnsDialerOnce sync.Once netnsDialer netns.Dialer + sysDialForTest func(_ context.Context, network, addr string) (net.Conn, error) // or nil routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface @@ -361,6 +362,13 @@ func (d *Dialer) logf(format string, args ...any) { } } +// SetSystemDialerForTest sets an alternate function to use for SystemDial +// instead of netns.Dialer. This is intended for use with nettest.MemoryNetwork. +func (d *Dialer) SetSystemDialerForTest(fn func(ctx context.Context, network, addr string) (net.Conn, error)) { + testenv.AssertInTest() + d.sysDialForTest = fn +} + // SystemDial connects to the provided network address without going over // Tailscale. It prefers going over the default interface and closes existing // connections if the default interface changes. It is used to connect to @@ -380,10 +388,16 @@ func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn return nil, net.ErrClosed } - d.netnsDialerOnce.Do(func() { - d.netnsDialer = netns.NewDialer(d.logf, d.netMon) - }) - c, err := d.netnsDialer.DialContext(ctx, network, addr) + var c net.Conn + var err error + if d.sysDialForTest != nil { + c, err = d.sysDialForTest(ctx, network, addr) + } else { + d.netnsDialerOnce.Do(func() { + d.netnsDialer = netns.NewDialer(d.logf, d.netMon) + }) + c, err = d.netnsDialer.DialContext(ctx, network, addr) + } if err != nil { return nil, err } diff --git a/tstest/nettest/nettest.go b/tstest/nettest/nettest.go index 47c8857a5..f03d6987b 100644 --- a/tstest/nettest/nettest.go +++ b/tstest/nettest/nettest.go @@ -6,11 +6,23 @@ package nettest import ( + "context" + "flag" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "sync" "testing" + "tailscale.com/net/memnet" "tailscale.com/net/netmon" + "tailscale.com/util/testenv" ) +var useMemNet = flag.Bool("use-test-memnet", false, "prefer using in-memory network for tests") + // SkipIfNoNetwork skips the test if it looks like there's no network // access. func SkipIfNoNetwork(t testing.TB) { @@ -19,3 +31,190 @@ func SkipIfNoNetwork(t testing.TB) { t.Skip("skipping; test requires network but no interface is up") } } + +// Network is an interface for use in tests that describes either [RealNetwork] +// or [MemNetwork]. +type Network interface { + NewLocalTCPListener() net.Listener + Listen(network, address string) (net.Listener, error) + Dial(ctx context.Context, network, address string) (net.Conn, error) +} + +// PreferMemNetwork reports whether the --use-test-memnet flag is set. +func PreferMemNetwork() bool { + return *useMemNet +} + +// GetNetwork returns the appropriate Network implementation based on +// whether the --use-test-memnet flag is set. +// +// Each call generates a new network. +func GetNetwork(tb testing.TB) Network { + var n Network + if PreferMemNetwork() { + n = MemNetwork() + } else { + n = RealNetwork() + } + + detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb) + if detectLeaks { + tb.Cleanup(func() { + // TODO: leak detection, making sure no connections + // remain at the end of the test. For real network, + // snapshot conns in pid table before & after. + }) + } + return n +} + +// RealNetwork returns a Network implementation that uses the real +// net package. +func RealNetwork() Network { return realNetwork{} } + +// realNetwork implements [Network] using the real net package. +type realNetwork struct{} + +func (realNetwork) Listen(network, address string) (net.Listener, error) { + return net.Listen(network, address) +} + +func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, address) +} + +func (realNetwork) NewLocalTCPListener() net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + return ln +} + +// MemNetwork returns a Network implementation that uses an in-memory +// network for testing. It is only suitable for tests that do not +// require real network access. +func MemNetwork() Network { return &memNetwork{} } + +// memNetwork implements [Network] using an in-memory network. +type memNetwork struct { + mu sync.Mutex + lns map[string]*memnet.Listener // address -> listener +} + +func (m *memNetwork) Listen(network, address string) (net.Listener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network) + } + ap, err := netip.ParseAddrPort(address) + if err != nil { + return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.lns == nil { + m.lns = make(map[string]*memnet.Listener) + } + port := ap.Port() + for { + if port == 0 { + port = 33000 + } + key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port)) + _, ok := m.lns[key] + if ok { + if ap.Port() != 0 { + return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address) + } + port++ + continue + } + ln := memnet.Listen(key) + m.lns[key] = ln + return ln, nil + } +} + +func (m *memNetwork) NewLocalTCPListener() net.Listener { + ln, err := m.Listen("tcp", "127.0.0.1:0") + if err != nil { + panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err)) + } + return ln +} + +func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network) + } + m.mu.Lock() + ln, ok := m.lns[address] + m.mu.Unlock() + if !ok { + return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address) + } + return ln.Dial(ctx, network, address) +} + +// NewHTTPServer starts and returns a new [httptest.Server]. +// The caller should call Close when finished, to shut it down. +func NewHTTPServer(net Network, handler http.Handler) *httptest.Server { + ts := NewUnstartedHTTPServer(net, handler) + ts.Start() + return ts +} + +// NewUnstartedHTTPServer returns a new [httptest.Server] but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedHTTPServer(nw Network, handler http.Handler) *httptest.Server { + s := &httptest.Server{ + Config: &http.Server{Handler: handler}, + } + ln := nw.NewLocalTCPListener() + s.Listener = &listenerOnAddrOnce{ + Listener: ln, + fn: func() { + c := s.Client() + if c == nil { + // This httptest.Server.Start initialization order has been true + // for over 10 years. Let's keep counting on it. + panic("httptest.Server: Client not initialized before Addr called") + } + if c.Transport == nil { + c.Transport = &http.Transport{} + } + tr := c.Transport.(*http.Transport) + if tr.Dial != nil || tr.DialContext != nil { + panic("unexpected non-nil Dial or DialContext in httptest.Server.Client.Transport") + } + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return nw.Dial(ctx, network, addr) + } + }, + } + return s +} + +// listenerOnAddrOnce is a net.Listener that wraps another net.Listener +// and calls a function the first time its Addr is called. +type listenerOnAddrOnce struct { + net.Listener + once sync.Once + fn func() +} + +func (ln *listenerOnAddrOnce) Addr() net.Addr { + ln.once.Do(func() { + ln.fn() + }) + return ln.Listener.Addr() +} diff --git a/util/testenv/testenv.go b/util/testenv/testenv.go index 3e23baef4..aa6660411 100644 --- a/util/testenv/testenv.go +++ b/util/testenv/testenv.go @@ -58,3 +58,10 @@ func InParallelTest(t TB) (isParallel bool) { t.Chdir(".") // panics in a t.Parallel test return false } + +// AssertInTest panics if called outside of a test binary. +func AssertInTest() { + if !InTest() { + panic("func called outside of test binary") + } +}