wgengine/netstack: add a per-client limit for in-flight TCP forwards
This is a fun one. Right now, when a client is connecting through a subnet
router, here's roughly what happens:

1. The client initiates a connection to an IP address behind a subnet router,
   and sends a TCP SYN
2. The subnet router gets the SYN packet from netstack, and after running
   through acceptTCP, starts DialContext-ing the destination IP, without
   accepting the connection¹
3. The client retransmits the SYN packet a few times while the dial is in
   progress, until either...
4. The subnet router successfully establishes a connection to the destination
   IP and sends the SYN-ACK back to the client, or...
5. The subnet router times out and sends a RST to the client.
6. If the connection was successful, the client ACKs the SYN-ACK it received,
   and traffic starts flowing

As a result, the notification code in forwardTCP never notices when a new
connection attempt is aborted, and it will wait until either the connection is
established, or until the OS-level connection timeout is reached and it aborts.

To mitigate this, add a per-client limit on how many in-flight TCP forwarding
connections can be in progress; after this, clients will see behaviour similar
to the global limit, where new connection attempts are aborted instead of
waiting. This prevents a single misbehaving client from blocking all other
clients of a subnet router by ensuring that it doesn't starve the global
limiter.

Also, bump the global limit again to a higher value.

¹ We can't accept the connection before establishing a connection to the
remote server, since otherwise we'd be opening the connection and then
immediately closing it, which breaks a bunch of stuff; see #5503 for more
details.

Updates tailscale/corp#12184

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I76e7008ddd497303d75d473f534e32309c8a5144
parent 352c1ac96c
commit c5abbcd4b4
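For illustration only, here is a minimal, self-contained sketch of the mitigation the message above describes: each client (keyed by its Tailscale IP) gets a bounded count of in-flight connection attempts, a new attempt is rejected when that count is at the limit, and the slot is released once the attempt finishes. The names in this sketch (inFlightGate, tryAcquire, release) are hypothetical and do not appear in the diff; the real change wires an equivalent check into netstack's TCP protocol handler (wrapTCPProtocolHandler) and decrements the counter from acceptTCP, as shown below.

package sketch

import (
	"net/netip"
	"sync"
)

// inFlightGate is a hypothetical per-client limiter: each client may have at
// most limit connection attempts in flight at once.
type inFlightGate struct {
	limit    int
	mu       sync.Mutex
	byClient map[netip.Addr]int // in-flight attempts, keyed by client IP
}

// tryAcquire reports whether a new connection attempt from client may
// proceed; if so, it records the attempt as in-flight.
func (g *inFlightGate) tryAcquire(client netip.Addr) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.byClient[client] >= g.limit {
		return false // too many in-flight attempts; reject this one
	}
	if g.byClient == nil {
		g.byClient = make(map[netip.Addr]int)
	}
	g.byClient[client]++
	return true
}

// release marks one in-flight attempt for client as finished, whether the
// outbound dial succeeded, failed, or was never started.
func (g *inFlightGate) release(client netip.Addr) {
	g.mu.Lock()
	defer g.mu.Unlock()
	n := g.byClient[client] - 1
	if n <= 0 {
		delete(g.byClient, client) // free up space in the map
	} else {
		g.byClient[client] = n
	}
}

In the actual implementation, the per-client limit is derived from the global limit (two-thirds of it), and dropped attempts are counted in both a clientmetric and an expvar counter.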
wgengine/netstack/netstack.go:

@@ -54,6 +54,7 @@ import (
	"tailscale.com/types/netmap"
	"tailscale.com/types/nettype"
	"tailscale.com/util/clientmetric"
	"tailscale.com/version"
	"tailscale.com/version/distro"
	"tailscale.com/wgengine"
	"tailscale.com/wgengine/filter"
@@ -62,6 +63,60 @@ import (

const debugPackets = false

// If non-zero, these override the values returned from the corresponding
// functions, below.
var (
	maxInFlightConnectionAttemptsForTest          int
	maxInFlightConnectionAttemptsPerClientForTest int
)

// maxInFlightConnectionAttempts returns the global number of in-flight
// connection attempts that we allow for a single netstack Impl. Any new
// forwarded TCP connections that are opened after the limit has been hit are
// rejected until the number of in-flight connections drops below the limit
// again.
//
// Each in-flight connection attempt is a new goroutine and an open TCP
// connection, so we want to ensure that we don't allow an unbounded number of
// connections.
func maxInFlightConnectionAttempts() int {
	if n := maxInFlightConnectionAttemptsForTest; n > 0 {
		return n
	}

	if version.IsMobile() {
		return 1024 // previous global value
	}
	switch version.OS() {
	case "linux":
		// On the assumption that most subnet routers deployed in
		// production are running on Linux, we return a higher value.
		//
		// TODO(andrew-d): tune this based on the amount of system
		// memory instead of a fixed limit.
		return 8192
	default:
		// On all other platforms, return a reasonably high value that
		// most users won't hit.
		return 2048
	}
}

// maxInFlightConnectionAttemptsPerClient is the same as
// maxInFlightConnectionAttempts, but applies on a per-client basis
// (i.e. keyed by the remote Tailscale IP).
func maxInFlightConnectionAttemptsPerClient() int {
	if n := maxInFlightConnectionAttemptsPerClientForTest; n > 0 {
		return n
	}

	// For now, allow each individual client at most 2/3rds of the global
	// limit. On all platforms except mobile, this won't be a visible
	// change for users since this limit was added at the same time as we
	// bumped the global limit, above.
	return maxInFlightConnectionAttempts() * 2 / 3
}

var debugNetstack = envknob.RegisterBool("TS_DEBUG_NETSTACK")

var (
@@ -145,12 +200,30 @@ type Impl struct {
	// updates.
	atomicIsLocalIPFunc syncs.AtomicValue[func(netip.Addr) bool]

	// forwardDialFunc, if non-nil, is the net.Dialer.DialContext-style
	// function that is used to make outgoing connections when forwarding a
	// TCP connection to another host (e.g. in subnet router mode).
	//
	// This is currently only used in tests.
	forwardDialFunc func(context.Context, string, string) (net.Conn, error)

	// forwardInFlightPerClientDropped is a metric that tracks how many
	// in-flight TCP forward requests were dropped due to the per-client
	// limit.
	forwardInFlightPerClientDropped expvar.Int

	mu sync.Mutex
	// connsOpenBySubnetIP keeps track of number of connections open
	// for each subnet IP temporarily registered on netstack for active
	// TCP connections, so they can be unregistered when connections are
	// closed.
	connsOpenBySubnetIP map[netip.Addr]int
	// connsInFlightByClient keeps track of the number of in-flight
	// connections by the client ("Tailscale") IP. This is used to apply a
	// per-client limit on in-flight connections that's smaller than the
	// global limit, preventing a misbehaving client from starving the
	// global limit.
	connsInFlightByClient map[netip.Addr]int
}

const nicID = 1
@@ -232,17 +305,18 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
		},
	})
	ns := &Impl{
-		logf:                logf,
-		ipstack:             ipstack,
-		linkEP:              linkEP,
-		tundev:              tundev,
-		e:                   e,
-		pm:                  pm,
-		mc:                  mc,
-		dialer:              dialer,
-		connsOpenBySubnetIP: make(map[netip.Addr]int),
-		dns:                 dns,
-		tailFSForLocal:      tailFSForLocal,
+		logf:                  logf,
+		ipstack:               ipstack,
+		linkEP:                linkEP,
+		tundev:                tundev,
+		e:                     e,
+		pm:                    pm,
+		mc:                    mc,
+		dialer:                dialer,
+		connsOpenBySubnetIP:   make(map[netip.Addr]int),
+		connsInFlightByClient: make(map[netip.Addr]int),
+		dns:                   dns,
+		tailFSForLocal:        tailFSForLocal,
	}
	ns.ctx, ns.ctxCancel = context.WithCancel(context.Background())
	ns.atomicIsLocalIPFunc.Store(tsaddr.FalseContainsIPFunc())
@@ -283,10 +357,10 @@ func init() {
	})
}

-// wrapProtoHandler returns protocol handler h wrapped in a version
-// that dynamically reconfigures ns's subnet addresses as needed for
-// outbound traffic.
-func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, stack.PacketBufferPtr) bool) func(stack.TransportEndpointID, stack.PacketBufferPtr) bool {
+type protocolHandlerFunc func(stack.TransportEndpointID, stack.PacketBufferPtr) bool
+
+// wrapUDPProtocolHandler wraps the protocol handler we pass to netstack for UDP.
+func (ns *Impl) wrapUDPProtocolHandler(h protocolHandlerFunc) protocolHandlerFunc {
	return func(tei stack.TransportEndpointID, pb stack.PacketBufferPtr) bool {
		addr := tei.LocalAddress
		ip, ok := netip.AddrFromSlice(addr.AsSlice())
@@ -294,6 +368,9 @@ func (ns *Impl) wrapProtoHandler(stack.TransportEndpointID, stack.PacketB
			ns.logf("netstack: could not parse local address for incoming connection")
			return false
		}

		// Dynamically reconfigure ns's subnet addresses as needed for
		// outbound traffic.
		ip = ip.Unmap()
		if !ns.isLocalIP(ip) {
			ns.addSubnetAddress(ip)
@@ -302,6 +379,94 @@ func (ns *Impl) wrapProtoHandler(stack.TransportEndpointID, stack.PacketB
	}
}

var (
	metricPerClientForwardLimit = clientmetric.NewCounter("netstack_tcp_forward_dropped_attempts_per_client")
)

// wrapTCPProtocolHandler wraps the protocol handler we pass to netstack for TCP.
func (ns *Impl) wrapTCPProtocolHandler(h protocolHandlerFunc) protocolHandlerFunc {
	// 'handled' is whether the packet should be accepted by netstack; if
	// true, then the TCP connection is accepted by the transport layer and
	// passes through our acceptTCP handler/etc. If false, then the packet
	// is dropped and the TCP connection is rejected (typically with an
	// ICMP Port Unreachable or ICMP Protocol Unreachable message).
	return func(tei stack.TransportEndpointID, pb stack.PacketBufferPtr) (handled bool) {
		localIP, ok := netip.AddrFromSlice(tei.LocalAddress.AsSlice())
		if !ok {
			ns.logf("netstack: could not parse local address for incoming connection")
			return false
		}
		localIP = localIP.Unmap()

		remoteIP, ok := netip.AddrFromSlice(tei.RemoteAddress.AsSlice())
		if !ok {
			ns.logf("netstack: could not parse remote address for incoming connection")
			return false
		}

		// If we have too many in-flight connections for this client, abort
		// early and don't open a new one.
		//
		// NOTE: the counter is decremented in
		// decrementInFlightTCPForward, called from the acceptTCP
		// function, below.
		ns.mu.Lock()
		inFlight := ns.connsInFlightByClient[remoteIP]
		tooManyInFlight := inFlight >= maxInFlightConnectionAttemptsPerClient()
		if !tooManyInFlight {
			ns.connsInFlightByClient[remoteIP]++
		}
		ns.mu.Unlock()
		if debugNetstack() {
			ns.logf("[v2] netstack: in-flight connections for client %v: %d", remoteIP, inFlight)
		}
		if tooManyInFlight {
			ns.logf("netstack: ignoring a new TCP connection from %v to %v because the client already has %d in-flight connections", localIP, remoteIP, inFlight)
			metricPerClientForwardLimit.Add(1)
			ns.forwardInFlightPerClientDropped.Add(1)
			return false // unhandled
		}

		// On return, if this packet isn't handled by the inner handler
		// we're wrapping (`h`), we need to decrement the per-client
		// in-flight count. This can happen if the underlying
		// forwarder's limit has been reached, at which point it will
		// return false to indicate that it's not handling the packet,
		// and it will not run acceptTCP. If we don't decrement here,
		// then we would eventually increment the per-client counter up
		// to the limit and never decrement because we'd never hit the
		// codepath in acceptTCP, below.
		defer func() {
			if !handled {
				ns.mu.Lock()
				ns.connsInFlightByClient[remoteIP]--
				ns.mu.Unlock()
			}
		}()

		// Dynamically reconfigure ns's subnet addresses as needed for
		// outbound traffic.
		if !ns.isLocalIP(localIP) {
			ns.addSubnetAddress(localIP)
		}

		return h(tei, pb)
	}
}

func (ns *Impl) decrementInFlightTCPForward(remoteAddr netip.Addr) {
	ns.mu.Lock()
	defer ns.mu.Unlock()

	was := ns.connsInFlightByClient[remoteAddr]
	newVal := was - 1
	if newVal == 0 {
		delete(ns.connsInFlightByClient, remoteAddr) // free up space in the map
	} else {
		ns.connsInFlightByClient[remoteAddr] = newVal
	}
}

// Start sets up all the handlers so netstack can start working. Implements
// wgengine.FakeImpl.
func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error {
@@ -311,11 +476,10 @@ func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error {
	ns.lb = lb
	// size = 0 means use default buffer size
	const tcpReceiveBufferSize = 0
-	const maxInFlightConnectionAttempts = 1024
-	tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP)
+	tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts(), ns.acceptTCP)
	udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
-	ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapProtoHandler(tcpFwd.HandlePacket))
-	ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapProtoHandler(udpFwd.HandlePacket))
+	ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapTCPProtocolHandler(tcpFwd.HandlePacket))
+	ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapUDPProtocolHandler(udpFwd.HandlePacket))
	go ns.inject()
	return nil
}
@@ -881,6 +1045,17 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
		r.Complete(true) // sends a RST
		return
	}

	// After we've returned from this function or have otherwise reached a
	// non-pending state, decrement the per-client in-flight count so
	// future TCP connections aren't dropped.
	inFlightCompleted := false
	defer func() {
		if !inFlightCompleted {
			ns.decrementInFlightTCPForward(clientRemoteIP)
		}
	}()

	clientRemotePort := reqDetails.RemotePort
	clientRemoteAddrPort := netip.AddrPortFrom(clientRemoteIP, clientRemotePort)
@@ -934,6 +1109,14 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
		// peers.
		ep.SocketOptions().SetKeepAlive(true)

		// This function is called when we're ready to use the
		// underlying connection, and thus it's no longer in an
		// "in-flight" state; decrement our per-client limit right now,
		// and tell the defer in acceptTCP that it doesn't need to do
		// so upon return.
		ns.decrementInFlightTCPForward(clientRemoteIP)
		inFlightCompleted = true

		// The ForwarderRequest.CreateEndpoint above asynchronously
		// starts the TCP handshake. Note that the gonet.TCPConn
		// methods c.RemoteAddr() and c.LocalAddr() will return nil
@@ -1035,8 +1218,14 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
	}()

	// Attempt to dial the outbound connection before we accept the inbound one.
-	var stdDialer net.Dialer
-	server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr)
+	var dialFunc func(context.Context, string, string) (net.Conn, error)
+	if ns.forwardDialFunc != nil {
+		dialFunc = ns.forwardDialFunc
+	} else {
+		var stdDialer net.Dialer
+		dialFunc = stdDialer.DialContext
+	}
+	server, err := dialFunc(ctx, "tcp", dialAddrStr)
	if err != nil {
		ns.logf("netstack: could not connect to local server at %s: %v", dialAddr.String(), err)
		return
@@ -1456,5 +1645,45 @@ func (ns *Impl) ExpVar() expvar.Var {
		}))
	}

	// Export gauges that show the current TCP forwarding limits.
	m.Set("gauge_tcp_forward_in_flight_limit", expvar.Func(func() any {
		return maxInFlightConnectionAttempts()
	}))
	m.Set("gauge_tcp_forward_in_flight_per_client_limit", expvar.Func(func() any {
		return maxInFlightConnectionAttemptsPerClient()
	}))

	// This metric tracks the number of TCP forwarding connections that
	// are currently in-flight, i.e. still waiting to complete.
	m.Set("gauge_tcp_forward_in_flight", expvar.Func(func() any {
		ns.mu.Lock()
		defer ns.mu.Unlock()

		var sum int64
		for _, n := range ns.connsInFlightByClient {
			sum += int64(n)
		}
		return sum
	}))

	m.Set("counter_tcp_forward_max_in_flight_per_client_drop", &ns.forwardInFlightPerClientDropped)

	// This metric tracks how many clients (if any) have reached the
	// per-client limit on in-flight TCP forwarding requests.
	m.Set("gauge_tcp_forward_in_flight_per_client_limit_reached", expvar.Func(func() any {
		ns.mu.Lock()
		defer ns.mu.Unlock()

		limit := maxInFlightConnectionAttemptsPerClient()

		var count int64
		for _, n := range ns.connsInFlightByClient {
			if n == limit {
				count++
			}
		}
		return count
	}))

	return m
}
wgengine/netstack/netstack_test.go:

@@ -4,14 +4,22 @@
package netstack

import (
	"context"
	"fmt"
	"maps"
	"net"
	"net/netip"
	"runtime"
	"testing"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"tailscale.com/envknob"
	"tailscale.com/ipn"
	"tailscale.com/ipn/ipnlocal"
	"tailscale.com/ipn/store/mem"
	"tailscale.com/metrics"
	"tailscale.com/net/packet"
	"tailscale.com/net/tsaddr"
	"tailscale.com/net/tsdial"
@@ -455,3 +463,234 @@ func TestShouldProcessInbound(t *testing.T) {
		})
	}
}

func tcp4syn(tb testing.TB, src, dst netip.Addr, sport, dport uint16) []byte {
	ip := header.IPv4(make([]byte, header.IPv4MinimumSize+header.TCPMinimumSize))
	ip.Encode(&header.IPv4Fields{
		Protocol:    uint8(header.TCPProtocolNumber),
		TotalLength: header.IPv4MinimumSize + header.TCPMinimumSize,
		TTL:         64,
		SrcAddr:     tcpip.AddrFrom4Slice(src.AsSlice()),
		DstAddr:     tcpip.AddrFrom4Slice(dst.AsSlice()),
	})
	ip.SetChecksum(^ip.CalculateChecksum())
	if !ip.IsChecksumValid() {
		tb.Fatal("test broken; packet has incorrect IP checksum")
	}

	tcp := header.TCP(ip[header.IPv4MinimumSize:])
	tcp.Encode(&header.TCPFields{
		SrcPort:    sport,
		DstPort:    dport,
		SeqNum:     0,
		DataOffset: header.TCPMinimumSize,
		Flags:      header.TCPFlagSyn,
		WindowSize: 65535,
		Checksum:   0,
	})
	xsum := header.PseudoHeaderChecksum(
		header.TCPProtocolNumber,
		tcpip.AddrFrom4Slice(src.AsSlice()),
		tcpip.AddrFrom4Slice(dst.AsSlice()),
		uint16(header.TCPMinimumSize),
	)
	tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
	if !tcp.IsChecksumValid(tcpip.AddrFrom4Slice(src.AsSlice()), tcpip.AddrFrom4Slice(dst.AsSlice()), 0, 0) {
		tb.Fatal("test broken; packet has incorrect TCP checksum")
	}

	return ip
}

// makeHangDialer returns a dialer that notifies the returned channel when a
// connection is dialed and then hangs until the test finishes.
func makeHangDialer(tb testing.TB) (func(context.Context, string, string) (net.Conn, error), chan struct{}) {
	done := make(chan struct{})
	tb.Cleanup(func() {
		close(done)
	})

	gotConn := make(chan struct{}, 1)
	fn := func(ctx context.Context, network, address string) (net.Conn, error) {
		// Signal that we have a new connection
		tb.Logf("hangDialer: called with network=%q address=%q", network, address)
		select {
		case gotConn <- struct{}{}:
		default:
		}

		// Hang until the test is done.
		select {
		case <-ctx.Done():
			tb.Logf("context done")
		case <-done:
			tb.Logf("function completed")
		}
		return nil, fmt.Errorf("canceled")
	}
	return fn, gotConn
}

// TestTCPForwardLimits verifies that the limits on the TCP forwarder work in a
// success case (i.e. when we don't hit the limit).
func TestTCPForwardLimits(t *testing.T) {
	envknob.Setenv("TS_DEBUG_NETSTACK", "true")
	impl := makeNetstack(t, func(impl *Impl) {
		impl.ProcessSubnets = true
	})

	dialFn, gotConn := makeHangDialer(t)
	impl.forwardDialFunc = dialFn

	prefs := ipn.NewPrefs()
	prefs.AdvertiseRoutes = []netip.Prefix{
		// This is the TEST-NET-1 IP block for use in documentation,
		// and should never actually be routable.
		netip.MustParsePrefix("192.0.2.0/24"),
	}
	impl.lb.Start(ipn.Options{
		LegacyMigrationPrefs: prefs,
	})
	impl.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress)

	// Inject an "outbound" packet that's going to an IP address that times
	// out. We need to re-parse from a byte slice so that the internal
	// buffer in the packet.Parsed type is filled out.
	client := netip.MustParseAddr("100.101.102.103")
	destAddr := netip.MustParseAddr("192.0.2.1")
	pkt := tcp4syn(t, client, destAddr, 1234, 4567)
	var parsed packet.Parsed
	parsed.Decode(pkt)

	// When injecting this packet, we want the outcome to be "drop
	// silently", which indicates that netstack is processing the
	// packet and not delivering it to the host system.
	if resp := impl.injectInbound(&parsed, impl.tundev); resp != filter.DropSilently {
		t.Errorf("got filter outcome %v, want filter.DropSilently", resp)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Wait until we have an in-flight outgoing connection.
	select {
	case <-ctx.Done():
		t.Fatalf("timed out waiting for connection")
	case <-gotConn:
		t.Logf("got connection in progress")
	}

	// Verify that we now have a single in-flight address in our map.
	impl.mu.Lock()
	inFlight := maps.Clone(impl.connsInFlightByClient)
	impl.mu.Unlock()

	if got, ok := inFlight[client]; !ok || got != 1 {
		t.Errorf("expected 1 in-flight connection for %v, got: %v", client, inFlight)
	}

	// Get the expvar statistics and verify that we're exporting the
	// correct metric.
	metrics := impl.ExpVar().(*metrics.Set)

	const metricName = "gauge_tcp_forward_in_flight"
	if v := metrics.Get(metricName).String(); v != "1" {
		t.Errorf("got metric %q=%s, want 1", metricName, v)
	}
}

// TestTCPForwardLimits_PerClient verifies that the per-client limit for TCP
// forwarding works.
func TestTCPForwardLimits_PerClient(t *testing.T) {
	envknob.Setenv("TS_DEBUG_NETSTACK", "true")

	// Set our test override limits during this test.
	tstest.Replace(t, &maxInFlightConnectionAttemptsForTest, 2)
	tstest.Replace(t, &maxInFlightConnectionAttemptsPerClientForTest, 1)

	impl := makeNetstack(t, func(impl *Impl) {
		impl.ProcessSubnets = true
	})

	dialFn, gotConn := makeHangDialer(t)
	impl.forwardDialFunc = dialFn

	prefs := ipn.NewPrefs()
	prefs.AdvertiseRoutes = []netip.Prefix{
		// This is the TEST-NET-1 IP block for use in documentation,
		// and should never actually be routable.
		netip.MustParsePrefix("192.0.2.0/24"),
	}
	impl.lb.Start(ipn.Options{
		LegacyMigrationPrefs: prefs,
	})
	impl.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress)

	// Inject an "outbound" packet that's going to an IP address that times
	// out. We need to re-parse from a byte slice so that the internal
	// buffer in the packet.Parsed type is filled out.
	client := netip.MustParseAddr("100.101.102.103")
	destAddr := netip.MustParseAddr("192.0.2.1")

	// Helpers
	mustInjectPacket := func() {
		pkt := tcp4syn(t, client, destAddr, 1234, 4567)
		var parsed packet.Parsed
		parsed.Decode(pkt)

		// When injecting this packet, we want the outcome to be "drop
		// silently", which indicates that netstack is processing the
		// packet and not delivering it to the host system.
		if resp := impl.injectInbound(&parsed, impl.tundev); resp != filter.DropSilently {
			t.Fatalf("got filter outcome %v, want filter.DropSilently", resp)
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	waitPacket := func() {
		select {
		case <-ctx.Done():
			t.Fatalf("timed out waiting for connection")
		case <-gotConn:
			t.Logf("got connection in progress")
		}
	}

	// Inject the packet to start the TCP forward and wait until we have an
	// in-flight outgoing connection.
	mustInjectPacket()
	waitPacket()

	// Verify that we now have a single in-flight address in our map.
	impl.mu.Lock()
	inFlight := maps.Clone(impl.connsInFlightByClient)
	impl.mu.Unlock()

	if got, ok := inFlight[client]; !ok || got != 1 {
		t.Errorf("expected 1 in-flight connection for %v, got: %v", client, inFlight)
	}

	metrics := impl.ExpVar().(*metrics.Set)

	// One client should have reached the limit at this point.
	if v := metrics.Get("gauge_tcp_forward_in_flight_per_client_limit_reached").String(); v != "1" {
		t.Errorf("got limit reached expvar metric=%s, want 1", v)
	}

	// Inject another packet, and verify that we've incremented our
	// "dropped" metrics since this will have been dropped.
	mustInjectPacket()

	// expvar metric
	const metricName = "counter_tcp_forward_max_in_flight_per_client_drop"
	if v := metrics.Get(metricName).String(); v != "1" {
		t.Errorf("got expvar metric %q=%s, want 1", metricName, v)
	}

	// client metric
	if v := metricPerClientForwardLimit.Value(); v != 1 {
		t.Errorf("got clientmetric limit metric=%d, want 1", v)
	}
}