From 03b47a55c7956d872f7e3d54ca5c868e571517ff Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 8 Apr 2025 07:39:52 -0700 Subject: [PATCH] tstest/nettest: pull the non-test Network abstraction out to netx package We want to be able to use the netx.Network (and RealNetwork implemementation) outside of tests, without linking "testing". So split out the non-test stuff of nettest into its own package. We tend to use "foox" as the convention for things we wish were in the standard library's foo package, so "netx" seems consistent. Updates tailscale/corp#27636 Change-Id: I1911d361f4fbdf189837bf629a20f2ebfa863c44 Signed-off-by: Brad Fitzpatrick --- net/netx/netx.go | 120 ++++++++++++++++++++++++++++++++++++++ tstest/nettest/nettest.go | 117 +++---------------------------------- 2 files changed, 127 insertions(+), 110 deletions(-) create mode 100644 net/netx/netx.go diff --git a/net/netx/netx.go b/net/netx/netx.go new file mode 100644 index 000000000..0be277a15 --- /dev/null +++ b/net/netx/netx.go @@ -0,0 +1,120 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netx contains the Network type to abstract over either a real +// network or a virtual network for testing. +package netx + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "tailscale.com/net/memnet" +) + +// Network describes a network that can listen and dial. The two common +// implementations are [RealNetwork], using the net package to use the real +// network, or [MemNetwork], using an in-memory network (typically for testing) +type Network interface { + NewLocalTCPListener() net.Listener + Listen(network, address string) (net.Listener, error) + Dial(ctx context.Context, network, address string) (net.Conn, error) +} + +// RealNetwork returns a Network implementation that uses the real +// net package. +func RealNetwork() Network { return realNetwork{} } + +// realNetwork implements [Network] using the real net package. +type realNetwork struct{} + +func (realNetwork) Listen(network, address string) (net.Listener, error) { + return net.Listen(network, address) +} + +func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, address) +} + +func (realNetwork) NewLocalTCPListener() net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + return ln +} + +// MemNetwork returns a Network implementation that uses an in-memory +// network for testing. It is only suitable for tests that do not +// require real network access. +// +// As of 2025-04-08, it only supports TCP. +func MemNetwork() Network { return &memNetwork{} } + +// memNetwork implements [Network] using an in-memory network. +type memNetwork struct { + mu sync.Mutex + lns map[string]*memnet.Listener // address -> listener +} + +func (m *memNetwork) Listen(network, address string) (net.Listener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network) + } + ap, err := netip.ParseAddrPort(address) + if err != nil { + return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.lns == nil { + m.lns = make(map[string]*memnet.Listener) + } + port := ap.Port() + for { + if port == 0 { + port = 33000 + } + key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port)) + _, ok := m.lns[key] + if ok { + if ap.Port() != 0 { + return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address) + } + port++ + continue + } + ln := memnet.Listen(key) + m.lns[key] = ln + return ln, nil + } +} + +func (m *memNetwork) NewLocalTCPListener() net.Listener { + ln, err := m.Listen("tcp", "127.0.0.1:0") + if err != nil { + panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err)) + } + return ln +} + +func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network) + } + m.mu.Lock() + ln, ok := m.lns[address] + m.mu.Unlock() + if !ok { + return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address) + } + return ln.Dial(ctx, network, address) +} diff --git a/tstest/nettest/nettest.go b/tstest/nettest/nettest.go index f03d6987b..98662fe39 100644 --- a/tstest/nettest/nettest.go +++ b/tstest/nettest/nettest.go @@ -8,16 +8,14 @@ package nettest import ( "context" "flag" - "fmt" "net" "net/http" "net/http/httptest" - "net/netip" "sync" "testing" - "tailscale.com/net/memnet" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/util/testenv" ) @@ -32,14 +30,6 @@ func SkipIfNoNetwork(t testing.TB) { } } -// Network is an interface for use in tests that describes either [RealNetwork] -// or [MemNetwork]. -type Network interface { - NewLocalTCPListener() net.Listener - Listen(network, address string) (net.Listener, error) - Dial(ctx context.Context, network, address string) (net.Conn, error) -} - // PreferMemNetwork reports whether the --use-test-memnet flag is set. func PreferMemNetwork() bool { return *useMemNet @@ -49,12 +39,12 @@ func PreferMemNetwork() bool { // whether the --use-test-memnet flag is set. // // Each call generates a new network. -func GetNetwork(tb testing.TB) Network { - var n Network +func GetNetwork(tb testing.TB) netx.Network { + var n netx.Network if PreferMemNetwork() { - n = MemNetwork() + n = netx.MemNetwork() } else { - n = RealNetwork() + n = netx.RealNetwork() } detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb) @@ -68,102 +58,9 @@ func GetNetwork(tb testing.TB) Network { return n } -// RealNetwork returns a Network implementation that uses the real -// net package. -func RealNetwork() Network { return realNetwork{} } - -// realNetwork implements [Network] using the real net package. -type realNetwork struct{} - -func (realNetwork) Listen(network, address string) (net.Listener, error) { - return net.Listen(network, address) -} - -func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { - var d net.Dialer - return d.DialContext(ctx, network, address) -} - -func (realNetwork) NewLocalTCPListener() net.Listener { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil { - panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) - } - } - return ln -} - -// MemNetwork returns a Network implementation that uses an in-memory -// network for testing. It is only suitable for tests that do not -// require real network access. -func MemNetwork() Network { return &memNetwork{} } - -// memNetwork implements [Network] using an in-memory network. -type memNetwork struct { - mu sync.Mutex - lns map[string]*memnet.Listener // address -> listener -} - -func (m *memNetwork) Listen(network, address string) (net.Listener, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network) - } - ap, err := netip.ParseAddrPort(address) - if err != nil { - return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err) - } - - m.mu.Lock() - defer m.mu.Unlock() - - if m.lns == nil { - m.lns = make(map[string]*memnet.Listener) - } - port := ap.Port() - for { - if port == 0 { - port = 33000 - } - key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port)) - _, ok := m.lns[key] - if ok { - if ap.Port() != 0 { - return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address) - } - port++ - continue - } - ln := memnet.Listen(key) - m.lns[key] = ln - return ln, nil - } -} - -func (m *memNetwork) NewLocalTCPListener() net.Listener { - ln, err := m.Listen("tcp", "127.0.0.1:0") - if err != nil { - panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err)) - } - return ln -} - -func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network) - } - m.mu.Lock() - ln, ok := m.lns[address] - m.mu.Unlock() - if !ok { - return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address) - } - return ln.Dial(ctx, network, address) -} - // NewHTTPServer starts and returns a new [httptest.Server]. // The caller should call Close when finished, to shut it down. -func NewHTTPServer(net Network, handler http.Handler) *httptest.Server { +func NewHTTPServer(net netx.Network, handler http.Handler) *httptest.Server { ts := NewUnstartedHTTPServer(net, handler) ts.Start() return ts @@ -175,7 +72,7 @@ func NewHTTPServer(net Network, handler http.Handler) *httptest.Server { // StartTLS. // // The caller should call Close when finished, to shut it down. -func NewUnstartedHTTPServer(nw Network, handler http.Handler) *httptest.Server { +func NewUnstartedHTTPServer(nw netx.Network, handler http.Handler) *httptest.Server { s := &httptest.Server{ Config: &http.Server{Handler: handler}, }