net/{netx,memnet},all: add netx.DialFunc, move memnet Network impl

This adds netx.DialFunc, unifying a type we have a bazillion other
places, giving it now a nice short name that's clickable in
editors, etc.

That highlighted that my earlier move (03b47a55c7) of stuff from
nettest into netx moved too much: it also dragged along the memnet
impl, meaning all users of netx.DialFunc who just wanted netx for the
type definition were instead also pulling in all of memnet.

So move the memnet implementation netx.Network into memnet, a package
we already had.

Then use netx.DialFunc in a bunch of places. I'm sure I missed some.
And plenty remain in other repos, to be updated later.

Updates tailscale/corp#27636

Change-Id: I7296cd4591218e8624e214f8c70dab05fb884e95
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2025-04-08 08:32:27 -07:00
committed by Brad Fitzpatrick
parent b95df54b06
commit fb96137d79
23 changed files with 135 additions and 113 deletions

View File

@@ -31,6 +31,7 @@ import (
"tailscale.com/net/dnscache"
"tailscale.com/net/neterror"
"tailscale.com/net/netmon"
"tailscale.com/net/netx"
"tailscale.com/net/sockstats"
"tailscale.com/net/tsdial"
"tailscale.com/types/dnstype"
@@ -739,7 +740,7 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
return out, nil
}
func (f *forwarder) getDialerType() dnscache.DialContextFunc {
func (f *forwarder) getDialerType() netx.DialFunc {
if f.controlKnobs != nil && f.controlKnobs.UserDialUseRoutes.Load() {
// It is safe to use UserDial as it dials external servers without going through Tailscale
// and closes connections on interface change in the same way as SystemDial does,

View File

@@ -19,6 +19,7 @@ import (
"time"
"tailscale.com/envknob"
"tailscale.com/net/netx"
"tailscale.com/types/logger"
"tailscale.com/util/cloudenv"
"tailscale.com/util/singleflight"
@@ -355,10 +356,8 @@ func (r *Resolver) addIPCache(host string, ip, ip6 netip.Addr, allIPs []netip.Ad
}
}
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
func Dialer(fwd netx.DialFunc, dnsCache *Resolver) netx.DialFunc {
d := &dialer{
fwd: fwd,
dnsCache: dnsCache,
@@ -369,7 +368,7 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
// dialer is the config and accumulated state for a dial func returned by Dialer.
type dialer struct {
fwd DialContextFunc
fwd netx.DialFunc
dnsCache *Resolver
mu sync.Mutex
@@ -653,7 +652,7 @@ func v6addrs(aa []netip.Addr) (ret []netip.Addr) {
// TLSDialer is like Dialer but returns a func suitable for using with net/http.Transport.DialTLSContext.
// It returns a *tls.Conn type on success.
// On TLS cert validation failure, it can invoke a backup DNS resolution strategy.
func TLSDialer(fwd DialContextFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) DialContextFunc {
func TLSDialer(fwd netx.DialFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) netx.DialFunc {
tcpDialer := Dialer(fwd, dnsCache)
return func(ctx context.Context, network, address string) (net.Conn, error) {
host, _, err := net.SplitHostPort(address)

View File

@@ -6,3 +6,82 @@
// in tests and other situations where you don't want to use the
// network.
package memnet
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"tailscale.com/net/netx"
)
var _ netx.Network = (*Network)(nil)
// Network implements [Network] using an in-memory network, usually
// used for testing.
//
// As of 2025-04-08, it only supports TCP.
//
// Its zero value is a valid [netx.Network] implementation.
type Network struct {
mu sync.Mutex
lns map[string]*Listener // address -> listener
}
func (m *Network) Listen(network, address string) (net.Listener, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
}
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.lns == nil {
m.lns = make(map[string]*Listener)
}
port := ap.Port()
for {
if port == 0 {
port = 33000
}
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
_, ok := m.lns[key]
if ok {
if ap.Port() != 0 {
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
}
port++
continue
}
ln := Listen(key)
m.lns[key] = ln
return ln, nil
}
}
func (m *Network) NewLocalTCPListener() net.Listener {
ln, err := m.Listen("tcp", "127.0.0.1:0")
if err != nil {
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
}
return ln
}
func (m *Network) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
}
m.mu.Lock()
ln, ok := m.lns[address]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
}
return ln.Dial(ctx, network, address)
}

View File

@@ -1,23 +1,25 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package netx contains the Network type to abstract over either a real
// network or a virtual network for testing.
// Package netx contains types to describe and abstract over how dialing and
// listening are performed.
package netx
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"tailscale.com/net/memnet"
)
// DialFunc is a function that dials a network address.
//
// It's the type implemented by net.Dialer.DialContext or required
// by net/http.Transport.DialContext, etc.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
// Network describes a network that can listen and dial. The two common
// implementations are [RealNetwork], using the net package to use the real
// network, or [MemNetwork], using an in-memory network (typically for testing)
// network, or [memnet.Network], using an in-memory network (typically for testing)
type Network interface {
NewLocalTCPListener() net.Listener
Listen(network, address string) (net.Listener, error)
@@ -44,77 +46,8 @@ func (realNetwork) NewLocalTCPListener() net.Listener {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
panic(fmt.Sprintf("failed to listen on either IPv4 or IPv6 localhost port: %v", err))
}
}
return ln
}
// MemNetwork returns a Network implementation that uses an in-memory
// network for testing. It is only suitable for tests that do not
// require real network access.
//
// As of 2025-04-08, it only supports TCP.
func MemNetwork() Network { return &memNetwork{} }
// memNetwork implements [Network] using an in-memory network.
type memNetwork struct {
mu sync.Mutex
lns map[string]*memnet.Listener // address -> listener
}
func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
}
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.lns == nil {
m.lns = make(map[string]*memnet.Listener)
}
port := ap.Port()
for {
if port == 0 {
port = 33000
}
key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
_, ok := m.lns[key]
if ok {
if ap.Port() != 0 {
return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
}
port++
continue
}
ln := memnet.Listen(key)
m.lns[key] = ln
return ln, nil
}
}
func (m *memNetwork) NewLocalTCPListener() net.Listener {
ln, err := m.Listen("tcp", "127.0.0.1:0")
if err != nil {
panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
}
return ln
}
func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
}
m.mu.Lock()
ln, ok := m.lns[address]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
}
return ln.Dial(ctx, network, address)
}

View File

@@ -23,6 +23,7 @@ import (
"tailscale.com/net/netknob"
"tailscale.com/net/netmon"
"tailscale.com/net/netns"
"tailscale.com/net/netx"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
@@ -71,7 +72,7 @@ type Dialer struct {
netnsDialerOnce sync.Once
netnsDialer netns.Dialer
sysDialForTest func(_ context.Context, network, addr string) (net.Conn, error) // or nil
sysDialForTest netx.DialFunc // or nil
routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface
@@ -364,7 +365,7 @@ func (d *Dialer) logf(format string, args ...any) {
// SetSystemDialerForTest sets an alternate function to use for SystemDial
// instead of netns.Dialer. This is intended for use with nettest.MemoryNetwork.
func (d *Dialer) SetSystemDialerForTest(fn func(ctx context.Context, network, addr string) (net.Conn, error)) {
func (d *Dialer) SetSystemDialerForTest(fn netx.DialFunc) {
testenv.AssertInTest()
d.sysDialForTest = fn
}