mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-30 12:42:31 +00:00

We want to be able to use the netx.Network (and RealNetwork implemementation) outside of tests, without linking "testing". So split out the non-test stuff of nettest into its own package. We tend to use "foox" as the convention for things we wish were in the standard library's foo package, so "netx" seems consistent. Updates tailscale/corp#27636 Change-Id: I1911d361f4fbdf189837bf629a20f2ebfa863c44 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
121 lines
3.4 KiB
Go
121 lines
3.4 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
// Package netx contains the Network type to abstract over either a real
|
|
// network or a virtual network for testing.
|
|
package netx
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"sync"
|
|
|
|
"tailscale.com/net/memnet"
|
|
)
|
|
|
|
// Network describes a network that can listen and dial. The two common
|
|
// implementations are [RealNetwork], using the net package to use the real
|
|
// network, or [MemNetwork], using an in-memory network (typically for testing)
|
|
type Network interface {
|
|
NewLocalTCPListener() net.Listener
|
|
Listen(network, address string) (net.Listener, error)
|
|
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
|
}
|
|
|
|
// RealNetwork returns a Network implementation that uses the real
|
|
// net package.
|
|
func RealNetwork() Network { return realNetwork{} }
|
|
|
|
// realNetwork implements [Network] using the real net package.
|
|
type realNetwork struct{}
|
|
|
|
func (realNetwork) Listen(network, address string) (net.Listener, error) {
|
|
return net.Listen(network, address)
|
|
}
|
|
|
|
func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
|
var d net.Dialer
|
|
return d.DialContext(ctx, network, address)
|
|
}
|
|
|
|
func (realNetwork) NewLocalTCPListener() net.Listener {
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
|
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
|
|
}
|
|
}
|
|
return ln
|
|
}
|
|
|
|
// MemNetwork returns a Network implementation that uses an in-memory
|
|
// network for testing. It is only suitable for tests that do not
|
|
// require real network access.
|
|
//
|
|
// As of 2025-04-08, it only supports TCP.
|
|
func MemNetwork() Network { return &memNetwork{} }
|
|
|
|
// memNetwork implements [Network] using an in-memory network.
|
|
type memNetwork struct {
|
|
mu sync.Mutex
|
|
lns map[string]*memnet.Listener // address -> listener
|
|
}
|
|
|
|
func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
|
|
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
|
|
}
|
|
ap, err := netip.ParseAddrPort(address)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.lns == nil {
|
|
m.lns = make(map[string]*memnet.Listener)
|
|
}
|
|
port := ap.Port()
|
|
for {
|
|
if port == 0 {
|
|
port = 33000
|
|
}
|
|
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
|
|
_, ok := m.lns[key]
|
|
if ok {
|
|
if ap.Port() != 0 {
|
|
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
|
|
}
|
|
port++
|
|
continue
|
|
}
|
|
ln := memnet.Listen(key)
|
|
m.lns[key] = ln
|
|
return ln, nil
|
|
}
|
|
}
|
|
|
|
func (m *memNetwork) NewLocalTCPListener() net.Listener {
|
|
ln, err := m.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
|
|
}
|
|
return ln
|
|
}
|
|
|
|
func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
|
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
|
|
}
|
|
m.mu.Lock()
|
|
ln, ok := m.lns[address]
|
|
m.mu.Unlock()
|
|
if !ok {
|
|
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
|
|
}
|
|
return ln.Dial(ctx, network, address)
|
|
}
|