2023-01-30 10:53:58 -08:00
|
|
|
// Copyright (c) Tailscale Inc & AUTHORS
|
|
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
|
|
|
|
// Package memnet implements an in-memory network.
|
|
|
|
// It is useful for dialing and listening on in-memory addresses
|
|
|
|
// in tests and other situations where you don't want to use the
|
|
|
|
// network.
|
|
|
|
package memnet
|
2025-04-08 08:32:27 -07:00
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"net"
|
|
|
|
"net/netip"
|
|
|
|
"sync"
|
|
|
|
|
|
|
|
"tailscale.com/net/netx"
|
|
|
|
)
|
|
|
|
|
|
|
|
var _ netx.Network = (*Network)(nil)
|
|
|
|
|
|
|
|
// Network implements [Network] using an in-memory network, usually
|
|
|
|
// used for testing.
|
|
|
|
//
|
|
|
|
// As of 2025-04-08, it only supports TCP.
|
|
|
|
//
|
|
|
|
// Its zero value is a valid [netx.Network] implementation.
|
|
|
|
type Network struct {
|
|
|
|
mu sync.Mutex
|
|
|
|
lns map[string]*Listener // address -> listener
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Network) Listen(network, address string) (net.Listener, error) {
|
|
|
|
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
|
|
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
|
|
|
|
}
|
|
|
|
ap, err := netip.ParseAddrPort(address)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
m.mu.Lock()
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
|
|
|
if m.lns == nil {
|
|
|
|
m.lns = make(map[string]*Listener)
|
|
|
|
}
|
|
|
|
port := ap.Port()
|
|
|
|
for {
|
|
|
|
if port == 0 {
|
|
|
|
port = 33000
|
|
|
|
}
|
|
|
|
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
|
|
|
|
_, ok := m.lns[key]
|
|
|
|
if ok {
|
|
|
|
if ap.Port() != 0 {
|
|
|
|
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
|
|
|
|
}
|
|
|
|
port++
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
ln := Listen(key)
|
|
|
|
m.lns[key] = ln
|
|
|
|
return ln, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Network) NewLocalTCPListener() net.Listener {
|
|
|
|
ln, err := m.Listen("tcp", "127.0.0.1:0")
|
|
|
|
if err != nil {
|
|
|
|
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
|
|
|
|
}
|
|
|
|
return ln
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Network) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
|
|
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
|
|
|
|
}
|
|
|
|
m.mu.Lock()
|
|
|
|
ln, ok := m.lns[address]
|
|
|
|
m.mu.Unlock()
|
|
|
|
if !ok {
|
|
|
|
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
|
|
|
|
}
|
|
|
|
return ln.Dial(ctx, network, address)
|
|
|
|
}
|