mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-21 14:11:56 +00:00
tstest/nettest: pull the non-test Network abstraction out to netx package
We want to be able to use the netx.Network (and RealNetwork implemementation) outside of tests, without linking "testing". So split out the non-test stuff of nettest into its own package. We tend to use "foox" as the convention for things we wish were in the standard library's foo package, so "netx" seems consistent. Updates tailscale/corp#27636 Change-Id: I1911d361f4fbdf189837bf629a20f2ebfa863c44 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
ad2b075d4f
commit
03b47a55c7
120
net/netx/netx.go
Normal file
120
net/netx/netx.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// Package netx contains the Network type to abstract over either a real
|
||||||
|
// network or a virtual network for testing.
|
||||||
|
package netx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"tailscale.com/net/memnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Network describes a network that can listen and dial. The two common
|
||||||
|
// implementations are [RealNetwork], using the net package to use the real
|
||||||
|
// network, or [MemNetwork], using an in-memory network (typically for testing)
|
||||||
|
type Network interface {
|
||||||
|
NewLocalTCPListener() net.Listener
|
||||||
|
Listen(network, address string) (net.Listener, error)
|
||||||
|
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RealNetwork returns a Network implementation that uses the real
|
||||||
|
// net package.
|
||||||
|
func RealNetwork() Network { return realNetwork{} }
|
||||||
|
|
||||||
|
// realNetwork implements [Network] using the real net package.
|
||||||
|
type realNetwork struct{}
|
||||||
|
|
||||||
|
func (realNetwork) Listen(network, address string) (net.Listener, error) {
|
||||||
|
return net.Listen(network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
var d net.Dialer
|
||||||
|
return d.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (realNetwork) NewLocalTCPListener() net.Listener {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
||||||
|
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ln
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemNetwork returns a Network implementation that uses an in-memory
|
||||||
|
// network for testing. It is only suitable for tests that do not
|
||||||
|
// require real network access.
|
||||||
|
//
|
||||||
|
// As of 2025-04-08, it only supports TCP.
|
||||||
|
func MemNetwork() Network { return &memNetwork{} }
|
||||||
|
|
||||||
|
// memNetwork implements [Network] using an in-memory network.
|
||||||
|
type memNetwork struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
lns map[string]*memnet.Listener // address -> listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
|
||||||
|
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
|
||||||
|
}
|
||||||
|
ap, err := netip.ParseAddrPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if m.lns == nil {
|
||||||
|
m.lns = make(map[string]*memnet.Listener)
|
||||||
|
}
|
||||||
|
port := ap.Port()
|
||||||
|
for {
|
||||||
|
if port == 0 {
|
||||||
|
port = 33000
|
||||||
|
}
|
||||||
|
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
|
||||||
|
_, ok := m.lns[key]
|
||||||
|
if ok {
|
||||||
|
if ap.Port() != 0 {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
|
||||||
|
}
|
||||||
|
port++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ln := memnet.Listen(key)
|
||||||
|
m.lns[key] = ln
|
||||||
|
return ln, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *memNetwork) NewLocalTCPListener() net.Listener {
|
||||||
|
ln, err := m.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
|
||||||
|
}
|
||||||
|
return ln
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
ln, ok := m.lns[address]
|
||||||
|
m.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
|
||||||
|
}
|
||||||
|
return ln.Dial(ctx, network, address)
|
||||||
|
}
|
@ -8,16 +8,14 @@ package nettest
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/netip"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"tailscale.com/net/memnet"
|
|
||||||
"tailscale.com/net/netmon"
|
"tailscale.com/net/netmon"
|
||||||
|
"tailscale.com/net/netx"
|
||||||
"tailscale.com/util/testenv"
|
"tailscale.com/util/testenv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,14 +30,6 @@ func SkipIfNoNetwork(t testing.TB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network is an interface for use in tests that describes either [RealNetwork]
|
|
||||||
// or [MemNetwork].
|
|
||||||
type Network interface {
|
|
||||||
NewLocalTCPListener() net.Listener
|
|
||||||
Listen(network, address string) (net.Listener, error)
|
|
||||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreferMemNetwork reports whether the --use-test-memnet flag is set.
|
// PreferMemNetwork reports whether the --use-test-memnet flag is set.
|
||||||
func PreferMemNetwork() bool {
|
func PreferMemNetwork() bool {
|
||||||
return *useMemNet
|
return *useMemNet
|
||||||
@ -49,12 +39,12 @@ func PreferMemNetwork() bool {
|
|||||||
// whether the --use-test-memnet flag is set.
|
// whether the --use-test-memnet flag is set.
|
||||||
//
|
//
|
||||||
// Each call generates a new network.
|
// Each call generates a new network.
|
||||||
func GetNetwork(tb testing.TB) Network {
|
func GetNetwork(tb testing.TB) netx.Network {
|
||||||
var n Network
|
var n netx.Network
|
||||||
if PreferMemNetwork() {
|
if PreferMemNetwork() {
|
||||||
n = MemNetwork()
|
n = netx.MemNetwork()
|
||||||
} else {
|
} else {
|
||||||
n = RealNetwork()
|
n = netx.RealNetwork()
|
||||||
}
|
}
|
||||||
|
|
||||||
detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb)
|
detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb)
|
||||||
@ -68,102 +58,9 @@ func GetNetwork(tb testing.TB) Network {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
// RealNetwork returns a Network implementation that uses the real
|
|
||||||
// net package.
|
|
||||||
func RealNetwork() Network { return realNetwork{} }
|
|
||||||
|
|
||||||
// realNetwork implements [Network] using the real net package.
|
|
||||||
type realNetwork struct{}
|
|
||||||
|
|
||||||
func (realNetwork) Listen(network, address string) (net.Listener, error) {
|
|
||||||
return net.Listen(network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
var d net.Dialer
|
|
||||||
return d.DialContext(ctx, network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (realNetwork) NewLocalTCPListener() net.Listener {
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
|
||||||
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ln
|
|
||||||
}
|
|
||||||
|
|
||||||
// MemNetwork returns a Network implementation that uses an in-memory
|
|
||||||
// network for testing. It is only suitable for tests that do not
|
|
||||||
// require real network access.
|
|
||||||
func MemNetwork() Network { return &memNetwork{} }
|
|
||||||
|
|
||||||
// memNetwork implements [Network] using an in-memory network.
|
|
||||||
type memNetwork struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
lns map[string]*memnet.Listener // address -> listener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
|
|
||||||
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
||||||
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
|
|
||||||
}
|
|
||||||
ap, err := netip.ParseAddrPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if m.lns == nil {
|
|
||||||
m.lns = make(map[string]*memnet.Listener)
|
|
||||||
}
|
|
||||||
port := ap.Port()
|
|
||||||
for {
|
|
||||||
if port == 0 {
|
|
||||||
port = 33000
|
|
||||||
}
|
|
||||||
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
|
|
||||||
_, ok := m.lns[key]
|
|
||||||
if ok {
|
|
||||||
if ap.Port() != 0 {
|
|
||||||
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
|
|
||||||
}
|
|
||||||
port++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ln := memnet.Listen(key)
|
|
||||||
m.lns[key] = ln
|
|
||||||
return ln, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memNetwork) NewLocalTCPListener() net.Listener {
|
|
||||||
ln, err := m.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
|
|
||||||
}
|
|
||||||
return ln
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
||||||
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
|
|
||||||
}
|
|
||||||
m.mu.Lock()
|
|
||||||
ln, ok := m.lns[address]
|
|
||||||
m.mu.Unlock()
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
|
|
||||||
}
|
|
||||||
return ln.Dial(ctx, network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewHTTPServer starts and returns a new [httptest.Server].
|
// NewHTTPServer starts and returns a new [httptest.Server].
|
||||||
// The caller should call Close when finished, to shut it down.
|
// The caller should call Close when finished, to shut it down.
|
||||||
func NewHTTPServer(net Network, handler http.Handler) *httptest.Server {
|
func NewHTTPServer(net netx.Network, handler http.Handler) *httptest.Server {
|
||||||
ts := NewUnstartedHTTPServer(net, handler)
|
ts := NewUnstartedHTTPServer(net, handler)
|
||||||
ts.Start()
|
ts.Start()
|
||||||
return ts
|
return ts
|
||||||
@ -175,7 +72,7 @@ func NewHTTPServer(net Network, handler http.Handler) *httptest.Server {
|
|||||||
// StartTLS.
|
// StartTLS.
|
||||||
//
|
//
|
||||||
// The caller should call Close when finished, to shut it down.
|
// The caller should call Close when finished, to shut it down.
|
||||||
func NewUnstartedHTTPServer(nw Network, handler http.Handler) *httptest.Server {
|
func NewUnstartedHTTPServer(nw netx.Network, handler http.Handler) *httptest.Server {
|
||||||
s := &httptest.Server{
|
s := &httptest.Server{
|
||||||
Config: &http.Server{Handler: handler},
|
Config: &http.Server{Handler: handler},
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user