tstest/nettest: pull the non-test Network abstraction out to netx package

We want to be able to use the netx.Network (and RealNetwork
implemementation) outside of tests, without linking "testing".

So split out the non-test stuff of nettest into its own package.

We tend to use "foox" as the convention for things we wish were in the
standard library's foo package, so "netx" seems consistent.

Updates tailscale/corp#27636

Change-Id: I1911d361f4fbdf189837bf629a20f2ebfa863c44
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2025-04-08 07:39:52 -07:00 committed by Brad Fitzpatrick
parent ad2b075d4f
commit 03b47a55c7
2 changed files with 127 additions and 110 deletions

120
net/netx/netx.go Normal file
View File

@ -0,0 +1,120 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package netx contains the Network type to abstract over either a real
// network or a virtual network for testing.
package netx
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"tailscale.com/net/memnet"
)
// Network describes a network that can listen and dial. The two common
// implementations are [RealNetwork], using the net package to use the real
// network, or [MemNetwork], using an in-memory network (typically for testing)
type Network interface {
NewLocalTCPListener() net.Listener
Listen(network, address string) (net.Listener, error)
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
// RealNetwork returns a Network implementation that uses the real
// net package.
func RealNetwork() Network { return realNetwork{} }
// realNetwork implements [Network] using the real net package.
type realNetwork struct{}
func (realNetwork) Listen(network, address string) (net.Listener, error) {
return net.Listen(network, address)
}
func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, address)
}
func (realNetwork) NewLocalTCPListener() net.Listener {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
}
}
return ln
}
// MemNetwork returns a Network implementation that uses an in-memory
// network for testing. It is only suitable for tests that do not
// require real network access.
//
// As of 2025-04-08, it only supports TCP.
func MemNetwork() Network { return &memNetwork{} }
// memNetwork implements [Network] using an in-memory network.
type memNetwork struct {
mu sync.Mutex
lns map[string]*memnet.Listener // address -> listener
}
func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
}
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.lns == nil {
m.lns = make(map[string]*memnet.Listener)
}
port := ap.Port()
for {
if port == 0 {
port = 33000
}
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
_, ok := m.lns[key]
if ok {
if ap.Port() != 0 {
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
}
port++
continue
}
ln := memnet.Listen(key)
m.lns[key] = ln
return ln, nil
}
}
func (m *memNetwork) NewLocalTCPListener() net.Listener {
ln, err := m.Listen("tcp", "127.0.0.1:0")
if err != nil {
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
}
return ln
}
func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
}
m.mu.Lock()
ln, ok := m.lns[address]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
}
return ln.Dial(ctx, network, address)
}

View File

@ -8,16 +8,14 @@ package nettest
import ( import (
"context" "context"
"flag" "flag"
"fmt"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"sync" "sync"
"testing" "testing"
"tailscale.com/net/memnet"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/net/netx"
"tailscale.com/util/testenv" "tailscale.com/util/testenv"
) )
@ -32,14 +30,6 @@ func SkipIfNoNetwork(t testing.TB) {
} }
} }
// Network is an interface for use in tests that describes either [RealNetwork]
// or [MemNetwork].
type Network interface {
NewLocalTCPListener() net.Listener
Listen(network, address string) (net.Listener, error)
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
// PreferMemNetwork reports whether the --use-test-memnet flag is set. // PreferMemNetwork reports whether the --use-test-memnet flag is set.
func PreferMemNetwork() bool { func PreferMemNetwork() bool {
return *useMemNet return *useMemNet
@ -49,12 +39,12 @@ func PreferMemNetwork() bool {
// whether the --use-test-memnet flag is set. // whether the --use-test-memnet flag is set.
// //
// Each call generates a new network. // Each call generates a new network.
func GetNetwork(tb testing.TB) Network { func GetNetwork(tb testing.TB) netx.Network {
var n Network var n netx.Network
if PreferMemNetwork() { if PreferMemNetwork() {
n = MemNetwork() n = netx.MemNetwork()
} else { } else {
n = RealNetwork() n = netx.RealNetwork()
} }
detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb) detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb)
@ -68,102 +58,9 @@ func GetNetwork(tb testing.TB) Network {
return n return n
} }
// RealNetwork returns a Network implementation that uses the real
// net package.
func RealNetwork() Network { return realNetwork{} }
// realNetwork implements [Network] using the real net package.
type realNetwork struct{}
func (realNetwork) Listen(network, address string) (net.Listener, error) {
return net.Listen(network, address)
}
func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, address)
}
func (realNetwork) NewLocalTCPListener() net.Listener {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
}
}
return ln
}
// MemNetwork returns a Network implementation that uses an in-memory
// network for testing. It is only suitable for tests that do not
// require real network access.
func MemNetwork() Network { return &memNetwork{} }
// memNetwork implements [Network] using an in-memory network.
type memNetwork struct {
mu sync.Mutex
lns map[string]*memnet.Listener // address -> listener
}
func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
}
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.lns == nil {
m.lns = make(map[string]*memnet.Listener)
}
port := ap.Port()
for {
if port == 0 {
port = 33000
}
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
_, ok := m.lns[key]
if ok {
if ap.Port() != 0 {
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
}
port++
continue
}
ln := memnet.Listen(key)
m.lns[key] = ln
return ln, nil
}
}
func (m *memNetwork) NewLocalTCPListener() net.Listener {
ln, err := m.Listen("tcp", "127.0.0.1:0")
if err != nil {
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
}
return ln
}
func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
}
m.mu.Lock()
ln, ok := m.lns[address]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
}
return ln.Dial(ctx, network, address)
}
// NewHTTPServer starts and returns a new [httptest.Server]. // NewHTTPServer starts and returns a new [httptest.Server].
// The caller should call Close when finished, to shut it down. // The caller should call Close when finished, to shut it down.
func NewHTTPServer(net Network, handler http.Handler) *httptest.Server { func NewHTTPServer(net netx.Network, handler http.Handler) *httptest.Server {
ts := NewUnstartedHTTPServer(net, handler) ts := NewUnstartedHTTPServer(net, handler)
ts.Start() ts.Start()
return ts return ts
@ -175,7 +72,7 @@ func NewHTTPServer(net Network, handler http.Handler) *httptest.Server {
// StartTLS. // StartTLS.
// //
// The caller should call Close when finished, to shut it down. // The caller should call Close when finished, to shut it down.
func NewUnstartedHTTPServer(nw Network, handler http.Handler) *httptest.Server { func NewUnstartedHTTPServer(nw netx.Network, handler http.Handler) *httptest.Server {
s := &httptest.Server{ s := &httptest.Server{
Config: &http.Server{Handler: handler}, Config: &http.Server{Handler: handler},
} }