mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-22 17:01:43 +00:00
nettest, *: add option to run HTTP tests with in-memory network
To avoid ephemeral port / TIME_WAIT exhaustion with high --count values, and to eventually detect leaked connections in tests. (Later the memory network will register a Cleanup on the TB to verify that everything's been shut down) Updates tailscale/corp#27636 Change-Id: Id06f1ae750d8719c5a75d871654574a8226d2733 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
6d117d64a2
commit
c76d075472
@ -9,10 +9,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"tailscale.com/tstest/deptest"
|
"tailscale.com/tstest/deptest"
|
||||||
|
"tailscale.com/tstest/nettest"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,15 +36,15 @@ func TestGetServeConfigFromJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWhoIsPeerNotFound(t *testing.T) {
|
func TestWhoIsPeerNotFound(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
nw := nettest.GetNetwork(t)
|
||||||
|
ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(404)
|
w.WriteHeader(404)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
lc := &Client{
|
lc := &Client{
|
||||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
var std net.Dialer
|
return nw.Dial(ctx, network, ts.Listener.Addr().String())
|
||||||
return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String())
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
var k key.NodePublic
|
var k key.NodePublic
|
||||||
|
@ -28,6 +28,7 @@ import (
|
|||||||
"tailscale.com/ipn/ipnstate"
|
"tailscale.com/ipn/ipnstate"
|
||||||
"tailscale.com/net/memnet"
|
"tailscale.com/net/memnet"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/tstest/nettest"
|
||||||
"tailscale.com/types/views"
|
"tailscale.com/types/views"
|
||||||
"tailscale.com/util/httpm"
|
"tailscale.com/util/httpm"
|
||||||
)
|
)
|
||||||
@ -1508,7 +1509,7 @@ func TestCSRFProtect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
h := s.withCSRF(mux)
|
h := s.withCSRF(mux)
|
||||||
ser := httptest.NewServer(h)
|
ser := nettest.NewHTTPServer(nettest.GetNetwork(t), h)
|
||||||
defer ser.Close()
|
defer ser.Close()
|
||||||
|
|
||||||
jar, err := cookiejar.New(nil)
|
jar, err := cookiejar.New(nil)
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -20,6 +19,7 @@ import (
|
|||||||
"tailscale.com/net/netmon"
|
"tailscale.com/net/netmon"
|
||||||
"tailscale.com/net/tsdial"
|
"tailscale.com/net/tsdial"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/tstest/nettest"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
)
|
)
|
||||||
@ -178,7 +178,8 @@ func (tt noiseClientTest) run(t *testing.T) {
|
|||||||
|
|
||||||
const msg = "Hello, client"
|
const msg = "Hello, client"
|
||||||
h2 := &http2.Server{}
|
h2 := &http2.Server{}
|
||||||
hs := httptest.NewServer(&Upgrader{
|
nw := nettest.GetNetwork(t)
|
||||||
|
hs := nettest.NewHTTPServer(nw, &Upgrader{
|
||||||
h2srv: h2,
|
h2srv: h2,
|
||||||
noiseKeyPriv: serverPrivate,
|
noiseKeyPriv: serverPrivate,
|
||||||
sendEarlyPayload: tt.sendEarlyPayload,
|
sendEarlyPayload: tt.sendEarlyPayload,
|
||||||
@ -193,6 +194,10 @@ func (tt noiseClientTest) run(t *testing.T) {
|
|||||||
defer hs.Close()
|
defer hs.Close()
|
||||||
|
|
||||||
dialer := tsdial.NewDialer(netmon.NewStatic())
|
dialer := tsdial.NewDialer(netmon.NewStatic())
|
||||||
|
if nettest.PreferMemNetwork() {
|
||||||
|
dialer.SetSystemDialerForTest(nw.Dial)
|
||||||
|
}
|
||||||
|
|
||||||
nc, err := NewNoiseClient(NoiseOpts{
|
nc, err := NewNoiseClient(NoiseOpts{
|
||||||
PrivKey: clientPrivate,
|
PrivKey: clientPrivate,
|
||||||
ServerPubKey: serverPrivate.Public(),
|
ServerPubKey: serverPrivate.Public(),
|
||||||
|
@ -71,6 +71,7 @@ type Dialer struct {
|
|||||||
|
|
||||||
netnsDialerOnce sync.Once
|
netnsDialerOnce sync.Once
|
||||||
netnsDialer netns.Dialer
|
netnsDialer netns.Dialer
|
||||||
|
sysDialForTest func(_ context.Context, network, addr string) (net.Conn, error) // or nil
|
||||||
|
|
||||||
routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface
|
routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface
|
||||||
|
|
||||||
@ -361,6 +362,13 @@ func (d *Dialer) logf(format string, args ...any) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSystemDialerForTest sets an alternate function to use for SystemDial
|
||||||
|
// instead of netns.Dialer. This is intended for use with nettest.MemoryNetwork.
|
||||||
|
func (d *Dialer) SetSystemDialerForTest(fn func(ctx context.Context, network, addr string) (net.Conn, error)) {
|
||||||
|
testenv.AssertInTest()
|
||||||
|
d.sysDialForTest = fn
|
||||||
|
}
|
||||||
|
|
||||||
// SystemDial connects to the provided network address without going over
|
// SystemDial connects to the provided network address without going over
|
||||||
// Tailscale. It prefers going over the default interface and closes existing
|
// Tailscale. It prefers going over the default interface and closes existing
|
||||||
// connections if the default interface changes. It is used to connect to
|
// connections if the default interface changes. It is used to connect to
|
||||||
@ -380,10 +388,16 @@ func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn
|
|||||||
return nil, net.ErrClosed
|
return nil, net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var c net.Conn
|
||||||
|
var err error
|
||||||
|
if d.sysDialForTest != nil {
|
||||||
|
c, err = d.sysDialForTest(ctx, network, addr)
|
||||||
|
} else {
|
||||||
d.netnsDialerOnce.Do(func() {
|
d.netnsDialerOnce.Do(func() {
|
||||||
d.netnsDialer = netns.NewDialer(d.logf, d.netMon)
|
d.netnsDialer = netns.NewDialer(d.logf, d.netMon)
|
||||||
})
|
})
|
||||||
c, err := d.netnsDialer.DialContext(ctx, network, addr)
|
c, err = d.netnsDialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -6,11 +6,23 @@
|
|||||||
package nettest
|
package nettest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"tailscale.com/net/memnet"
|
||||||
"tailscale.com/net/netmon"
|
"tailscale.com/net/netmon"
|
||||||
|
"tailscale.com/util/testenv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var useMemNet = flag.Bool("use-test-memnet", false, "prefer using in-memory network for tests")
|
||||||
|
|
||||||
// SkipIfNoNetwork skips the test if it looks like there's no network
|
// SkipIfNoNetwork skips the test if it looks like there's no network
|
||||||
// access.
|
// access.
|
||||||
func SkipIfNoNetwork(t testing.TB) {
|
func SkipIfNoNetwork(t testing.TB) {
|
||||||
@ -19,3 +31,190 @@ func SkipIfNoNetwork(t testing.TB) {
|
|||||||
t.Skip("skipping; test requires network but no interface is up")
|
t.Skip("skipping; test requires network but no interface is up")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Network is an interface for use in tests that describes either [RealNetwork]
|
||||||
|
// or [MemNetwork].
|
||||||
|
type Network interface {
|
||||||
|
NewLocalTCPListener() net.Listener
|
||||||
|
Listen(network, address string) (net.Listener, error)
|
||||||
|
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreferMemNetwork reports whether the --use-test-memnet flag is set.
|
||||||
|
func PreferMemNetwork() bool {
|
||||||
|
return *useMemNet
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNetwork returns the appropriate Network implementation based on
|
||||||
|
// whether the --use-test-memnet flag is set.
|
||||||
|
//
|
||||||
|
// Each call generates a new network.
|
||||||
|
func GetNetwork(tb testing.TB) Network {
|
||||||
|
var n Network
|
||||||
|
if PreferMemNetwork() {
|
||||||
|
n = MemNetwork()
|
||||||
|
} else {
|
||||||
|
n = RealNetwork()
|
||||||
|
}
|
||||||
|
|
||||||
|
detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb)
|
||||||
|
if detectLeaks {
|
||||||
|
tb.Cleanup(func() {
|
||||||
|
// TODO: leak detection, making sure no connections
|
||||||
|
// remain at the end of the test. For real network,
|
||||||
|
// snapshot conns in pid table before & after.
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// RealNetwork returns a Network implementation that uses the real
|
||||||
|
// net package.
|
||||||
|
func RealNetwork() Network { return realNetwork{} }
|
||||||
|
|
||||||
|
// realNetwork implements [Network] using the real net package.
|
||||||
|
type realNetwork struct{}
|
||||||
|
|
||||||
|
func (realNetwork) Listen(network, address string) (net.Listener, error) {
|
||||||
|
return net.Listen(network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
var d net.Dialer
|
||||||
|
return d.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realNetwork) NewLocalTCPListener() net.Listener {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
||||||
|
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ln
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemNetwork returns a Network implementation that uses an in-memory
|
||||||
|
// network for testing. It is only suitable for tests that do not
|
||||||
|
// require real network access.
|
||||||
|
func MemNetwork() Network { return &memNetwork{} }
|
||||||
|
|
||||||
|
// memNetwork implements [Network] using an in-memory network.
|
||||||
|
type memNetwork struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
lns map[string]*memnet.Listener // address -> listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
|
||||||
|
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
|
||||||
|
}
|
||||||
|
ap, err := netip.ParseAddrPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.lns == nil {
|
||||||
|
m.lns = make(map[string]*memnet.Listener)
|
||||||
|
}
|
||||||
|
port := ap.Port()
|
||||||
|
for {
|
||||||
|
if port == 0 {
|
||||||
|
port = 33000
|
||||||
|
}
|
||||||
|
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
|
||||||
|
_, ok := m.lns[key]
|
||||||
|
if ok {
|
||||||
|
if ap.Port() != 0 {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
|
||||||
|
}
|
||||||
|
port++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ln := memnet.Listen(key)
|
||||||
|
m.lns[key] = ln
|
||||||
|
return ln, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *memNetwork) NewLocalTCPListener() net.Listener {
|
||||||
|
ln, err := m.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
|
||||||
|
}
|
||||||
|
return ln
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
ln, ok := m.lns[address]
|
||||||
|
m.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
|
||||||
|
}
|
||||||
|
return ln.Dial(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPServer starts and returns a new [httptest.Server].
|
||||||
|
// The caller should call Close when finished, to shut it down.
|
||||||
|
func NewHTTPServer(net Network, handler http.Handler) *httptest.Server {
|
||||||
|
ts := NewUnstartedHTTPServer(net, handler)
|
||||||
|
ts.Start()
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUnstartedHTTPServer returns a new [httptest.Server] but doesn't start it.
|
||||||
|
//
|
||||||
|
// After changing its configuration, the caller should call Start or
|
||||||
|
// StartTLS.
|
||||||
|
//
|
||||||
|
// The caller should call Close when finished, to shut it down.
|
||||||
|
func NewUnstartedHTTPServer(nw Network, handler http.Handler) *httptest.Server {
|
||||||
|
s := &httptest.Server{
|
||||||
|
Config: &http.Server{Handler: handler},
|
||||||
|
}
|
||||||
|
ln := nw.NewLocalTCPListener()
|
||||||
|
s.Listener = &listenerOnAddrOnce{
|
||||||
|
Listener: ln,
|
||||||
|
fn: func() {
|
||||||
|
c := s.Client()
|
||||||
|
if c == nil {
|
||||||
|
// This httptest.Server.Start initialization order has been true
|
||||||
|
// for over 10 years. Let's keep counting on it.
|
||||||
|
panic("httptest.Server: Client not initialized before Addr called")
|
||||||
|
}
|
||||||
|
if c.Transport == nil {
|
||||||
|
c.Transport = &http.Transport{}
|
||||||
|
}
|
||||||
|
tr := c.Transport.(*http.Transport)
|
||||||
|
if tr.Dial != nil || tr.DialContext != nil {
|
||||||
|
panic("unexpected non-nil Dial or DialContext in httptest.Server.Client.Transport")
|
||||||
|
}
|
||||||
|
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return nw.Dial(ctx, network, addr)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// listenerOnAddrOnce is a net.Listener that wraps another net.Listener
|
||||||
|
// and calls a function the first time its Addr is called.
|
||||||
|
type listenerOnAddrOnce struct {
|
||||||
|
net.Listener
|
||||||
|
once sync.Once
|
||||||
|
fn func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *listenerOnAddrOnce) Addr() net.Addr {
|
||||||
|
ln.once.Do(func() {
|
||||||
|
ln.fn()
|
||||||
|
})
|
||||||
|
return ln.Listener.Addr()
|
||||||
|
}
|
||||||
|
@ -58,3 +58,10 @@ func InParallelTest(t TB) (isParallel bool) {
|
|||||||
t.Chdir(".") // panics in a t.Parallel test
|
t.Chdir(".") // panics in a t.Parallel test
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AssertInTest panics if called outside of a test binary.
|
||||||
|
func AssertInTest() {
|
||||||
|
if !InTest() {
|
||||||
|
panic("func called outside of test binary")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user