wgengine/netstack: add a per-client limit for in-flight TCP forwards

This is a fun one. Right now, when a client is connecting through a
subnet router, here's roughly what happens:

1. The client initiates a connection to an IP address behind a subnet
   router, and sends a TCP SYN
2. The subnet router gets the SYN packet from netstack, and after
   running through acceptTCP, starts DialContext-ing the destination IP,
   without accepting the connection¹
3. The client retransmits the SYN packet a few times while the dial is
   in progress, until either...
4. The subnet router successfully establishes a connection to the
   destination IP and sends the SYN-ACK back to the client, or...
5. The subnet router times out and sends a RST to the client.
6. If the connection was successful, the client ACKs the SYN-ACK it
   received, and traffic starts flowing.

As a result, the notification code in forwardTCP never notices when a
client aborts a connection attempt that is still being dialed; it keeps
waiting until either the outbound connection is established or the
OS-level connection timeout is reached and the dial aborts.
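
To make the ordering concrete, here's a minimal Go sketch of the
dial-before-accept flow. completeHandshake and abortHandshake are
illustrative stand-ins for netstack's ForwarderRequest plumbing, not
real API:

    package sketch

    import (
        "context"
        "net"
    )

    // Illustrative stand-ins for netstack's ForwarderRequest plumbing.
    var (
        completeHandshake = func() {} // would send the SYN-ACK to the client (step 4)
        abortHandshake    = func() {} // would send a RST to the client (step 5)
    )

    func forwardSketch(ctx context.Context, dialAddr string) (net.Conn, error) {
        // While this dial runs, the client keeps retransmitting its
        // SYN (step 3); we haven't accepted its connection yet.
        var d net.Dialer
        server, err := d.DialContext(ctx, "tcp", dialAddr)
        if err != nil {
            abortHandshake()
            return nil, err
        }
        completeHandshake()
        return server, nil
    }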

To mitigate this, add a per-client limit on how many TCP forwarding
connections can be in-flight at once; past that limit, a client sees
the same behaviour as with the global limit: new connection attempts
are aborted instead of being left to wait. This prevents a single
misbehaving client from blocking all other clients of a subnet router,
since it can no longer starve the global limiter.
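
Conceptually, the per-client accounting is just a counter keyed by the
client's Tailscale IP: try to increment it before forwarding, and
decrement it when the attempt leaves the in-flight state. A minimal
sketch, assuming a fixed limit (the real change below wires this into
netstack's TCP protocol handler and computes the limit per platform):

    package sketch

    import (
        "net/netip"
        "sync"
    )

    type perClientLimiter struct {
        mu       sync.Mutex
        inFlight map[netip.Addr]int // in-flight forwards per client IP
        limit    int
    }

    // tryAcquire reports whether client may start another in-flight
    // forward, incrementing its count if so.
    func (l *perClientLimiter) tryAcquire(client netip.Addr) bool {
        l.mu.Lock()
        defer l.mu.Unlock()
        if l.inFlight[client] >= l.limit {
            return false // abort the new attempt rather than queue it
        }
        if l.inFlight == nil {
            l.inFlight = make(map[netip.Addr]int)
        }
        l.inFlight[client]++
        return true
    }

    // release decrements client's count, deleting the entry at zero so
    // the map doesn't grow with one-shot clients.
    func (l *perClientLimiter) release(client netip.Addr) {
        l.mu.Lock()
        defer l.mu.Unlock()
        if n := l.inFlight[client] - 1; n <= 0 {
            delete(l.inFlight, client)
        } else {
            l.inFlight[client] = n
        }
    }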

Also, bump the global limit again to a higher value.

¹ We can't accept the client's connection before establishing the
outbound connection to the remote server, since otherwise we'd be
accepting the connection and then immediately closing it whenever the
dial fails, which breaks a bunch of stuff; see #5503 for more details.

Updates tailscale/corp#12184

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I76e7008ddd497303d75d473f534e32309c8a5144

@@ -54,6 +54,7 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/types/nettype"
"tailscale.com/util/clientmetric"
"tailscale.com/version"
"tailscale.com/version/distro"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
@@ -62,6 +63,60 @@ import (
const debugPackets = false
// If non-zero, these override the values returned from the corresponding
// functions, below.
var (
maxInFlightConnectionAttemptsForTest int
maxInFlightConnectionAttemptsPerClientForTest int
)
// maxInFlightConnectionAttempts returns the global number of in-flight
// connection attempts that we allow for a single netstack Impl. Any new
// forwarded TCP connections that are opened after the limit has been hit are
// rejected until the number of in-flight connections drops below the limit
// again.
//
// Each in-flight connection attempt is a new goroutine and an open TCP
// connection, so we want to ensure that we don't allow an unbounded number of
// connections.
func maxInFlightConnectionAttempts() int {
if n := maxInFlightConnectionAttemptsForTest; n > 0 {
return n
}
if version.IsMobile() {
return 1024 // previous global value
}
switch version.OS() {
case "linux":
// On the assumption that most subnet routers deployed in
// production are running on Linux, we return a higher value.
//
// TODO(andrew-d): tune this based on the amount of system
// memory instead of a fixed limit.
return 8192
default:
// On all other platforms, return a reasonably high value that
// most users won't hit.
return 2048
}
}
// maxInFlightConnectionAttemptsPerClient is the same as
// maxInFlightConnectionAttempts, but applies on a per-client basis
// (i.e. keyed by the remote Tailscale IP).
func maxInFlightConnectionAttemptsPerClient() int {
if n := maxInFlightConnectionAttemptsPerClientForTest; n > 0 {
return n
}
// For now, allow each individual client at most 2/3rds of the global
// limit. On all platforms except mobile, this won't be a visible
// change for users since this limit was added at the same time as we
// bumped the global limit, above.
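// Illustrative math: with the Linux global limit of 8192, integer
// division yields 8192 * 2 / 3 == 5461 per client; on mobile,
// 1024 * 2 / 3 == 682.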
return maxInFlightConnectionAttempts() * 2 / 3
}
var debugNetstack = envknob.RegisterBool("TS_DEBUG_NETSTACK")
var (
@@ -145,12 +200,30 @@ type Impl struct {
// updates.
atomicIsLocalIPFunc syncs.AtomicValue[func(netip.Addr) bool]
// forwardDialFunc, if non-nil, is the net.Dialer.DialContext-style
// function that is used to make outgoing connections when forwarding a
// TCP connection to another host (e.g. in subnet router mode).
//
// This is currently only used in tests.
forwardDialFunc func(context.Context, string, string) (net.Conn, error)
// forwardInFlightPerClientDropped is a metric that tracks how many
// in-flight TCP forward requests were dropped due to the per-client
// limit.
forwardInFlightPerClientDropped expvar.Int
mu sync.Mutex
// connsOpenBySubnetIP keeps track of number of connections open
// for each subnet IP temporarily registered on netstack for active
// TCP connections, so they can be unregistered when connections are
// closed.
connsOpenBySubnetIP map[netip.Addr]int
// connsInFlightByClient keeps track of the number of in-flight
// connections by the client ("Tailscale") IP. This is used to apply a
// per-client limit on in-flight connections that's smaller than the
// global limit, preventing a misbehaving client from starving the
// global limit.
connsInFlightByClient map[netip.Addr]int
}
const nicID = 1
@@ -232,17 +305,18 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
},
})
ns := &Impl{
- logf: logf,
- ipstack: ipstack,
- linkEP: linkEP,
- tundev: tundev,
- e: e,
- pm: pm,
- mc: mc,
- dialer: dialer,
- connsOpenBySubnetIP: make(map[netip.Addr]int),
- dns: dns,
- tailFSForLocal: tailFSForLocal,
+ logf: logf,
+ ipstack: ipstack,
+ linkEP: linkEP,
+ tundev: tundev,
+ e: e,
+ pm: pm,
+ mc: mc,
+ dialer: dialer,
+ connsOpenBySubnetIP: make(map[netip.Addr]int),
+ connsInFlightByClient: make(map[netip.Addr]int),
+ dns: dns,
+ tailFSForLocal: tailFSForLocal,
}
ns.ctx, ns.ctxCancel = context.WithCancel(context.Background())
ns.atomicIsLocalIPFunc.Store(tsaddr.FalseContainsIPFunc())
@@ -283,10 +357,10 @@ func init() {
})
}
- // wrapProtoHandler returns protocol handler h wrapped in a version
- // that dynamically reconfigures ns's subnet addresses as needed for
- // outbound traffic.
- func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, stack.PacketBufferPtr) bool) func(stack.TransportEndpointID, stack.PacketBufferPtr) bool {
+ type protocolHandlerFunc func(stack.TransportEndpointID, stack.PacketBufferPtr) bool
+ // wrapUDPProtocolHandler wraps the protocol handler we pass to netstack for UDP.
+ func (ns *Impl) wrapUDPProtocolHandler(h protocolHandlerFunc) protocolHandlerFunc {
return func(tei stack.TransportEndpointID, pb stack.PacketBufferPtr) bool {
addr := tei.LocalAddress
ip, ok := netip.AddrFromSlice(addr.AsSlice())
@@ -294,6 +368,9 @@ func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, stack.PacketB
ns.logf("netstack: could not parse local address for incoming connection")
return false
}
+ // Dynamically reconfigure ns's subnet addresses as needed for
+ // outbound traffic.
ip = ip.Unmap()
if !ns.isLocalIP(ip) {
ns.addSubnetAddress(ip)
@@ -302,6 +379,94 @@ func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, stack.PacketB
}
}
var (
metricPerClientForwardLimit = clientmetric.NewCounter("netstack_tcp_forward_dropped_attempts_per_client")
)
// wrapTCPProtocolHandler wraps the protocol handler we pass to netstack for TCP.
func (ns *Impl) wrapTCPProtocolHandler(h protocolHandlerFunc) protocolHandlerFunc {
// 'handled' is whether the packet should be accepted by netstack; if
// true, then the TCP connection is accepted by the transport layer and
// passes through our acceptTCP handler/etc. If false, then the packet
// is dropped and the TCP connection is rejected (typically with an
// ICMP Port Unreachable or ICMP Protocol Unreachable message).
return func(tei stack.TransportEndpointID, pb stack.PacketBufferPtr) (handled bool) {
localIP, ok := netip.AddrFromSlice(tei.LocalAddress.AsSlice())
if !ok {
ns.logf("netstack: could not parse local address for incoming connection")
return false
}
localIP = localIP.Unmap()
remoteIP, ok := netip.AddrFromSlice(tei.RemoteAddress.AsSlice())
if !ok {
ns.logf("netstack: could not parse remote address for incoming connection")
return false
}
// If we have too many in-flight connections for this client, abort
// early and don't open a new one.
//
// NOTE: the counter is decremented in
// decrementInFlightTCPForward, called from the acceptTCP
// function, below.
ns.mu.Lock()
inFlight := ns.connsInFlightByClient[remoteIP]
tooManyInFlight := inFlight >= maxInFlightConnectionAttemptsPerClient()
if !tooManyInFlight {
ns.connsInFlightByClient[remoteIP]++
}
ns.mu.Unlock()
if debugNetstack() {
ns.logf("[v2] netstack: in-flight connections for client %v: %d", remoteIP, inFlight)
}
if tooManyInFlight {
ns.logf("netstack: ignoring a new TCP connection from %v to %v because the client already has %d in-flight connections", localIP, remoteIP, inFlight)
metricPerClientForwardLimit.Add(1)
ns.forwardInFlightPerClientDropped.Add(1)
return false // unhandled
}
// On return, if this packet isn't handled by the inner handler
// we're wrapping (`h`), we need to decrement the per-client
// in-flight count. This can happen if the underlying
// forwarder's limit has been reached, at which point it will
// return false to indicate that it's not handling the packet,
// and it will not run acceptTCP. If we don't decrement here,
// then we would eventually increment the per-client counter up
// to the limit and never decrement because we'd never hit the
// codepath in acceptTCP, below.
defer func() {
if !handled {
ns.mu.Lock()
ns.connsInFlightByClient[remoteIP]--
ns.mu.Unlock()
}
}()
// Dynamically reconfigure ns's subnet addresses as needed for
// outbound traffic.
if !ns.isLocalIP(localIP) {
ns.addSubnetAddress(localIP)
}
return h(tei, pb)
}
}
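// decrementInFlightTCPForward decrements the in-flight count for the
// given client IP, removing the map entry entirely once it reaches
// zero.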
func (ns *Impl) decrementInFlightTCPForward(remoteAddr netip.Addr) {
ns.mu.Lock()
defer ns.mu.Unlock()
was := ns.connsInFlightByClient[remoteAddr]
newVal := was - 1
if newVal == 0 {
delete(ns.connsInFlightByClient, remoteAddr) // free up space in the map
} else {
ns.connsInFlightByClient[remoteAddr] = newVal
}
}
// Start sets up all the handlers so netstack can start working. Implements
// wgengine.FakeImpl.
func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error {
@@ -311,11 +476,10 @@ func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error {
ns.lb = lb
// size = 0 means use default buffer size
const tcpReceiveBufferSize = 0
- const maxInFlightConnectionAttempts = 1024
- tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP)
+ tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts(), ns.acceptTCP)
udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
- ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapProtoHandler(tcpFwd.HandlePacket))
- ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapProtoHandler(udpFwd.HandlePacket))
+ ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapTCPProtocolHandler(tcpFwd.HandlePacket))
+ ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapUDPProtocolHandler(udpFwd.HandlePacket))
go ns.inject()
return nil
}
@@ -881,6 +1045,17 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
r.Complete(true) // sends a RST
return
}
// After we've returned from this function or have otherwise reached a
// non-pending state, decrement the per-client in-flight count so
// future TCP connections aren't dropped.
inFlightCompleted := false
defer func() {
if !inFlightCompleted {
ns.decrementInFlightTCPForward(clientRemoteIP)
}
}()
clientRemotePort := reqDetails.RemotePort
clientRemoteAddrPort := netip.AddrPortFrom(clientRemoteIP, clientRemotePort)
@@ -934,6 +1109,14 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
// peers.
ep.SocketOptions().SetKeepAlive(true)
// This function is called when we're ready to use the
// underlying connection, and thus it's no longer in an
// "in-flight" state; decrement our per-client count right now,
// and tell the defer in acceptTCP that it doesn't need to do
// so upon return.
ns.decrementInFlightTCPForward(clientRemoteIP)
inFlightCompleted = true
// The ForwarderRequest.CreateEndpoint above asynchronously
// starts the TCP handshake. Note that the gonet.TCPConn
// methods c.RemoteAddr() and c.LocalAddr() will return nil
@@ -1035,8 +1218,14 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
}()
// Attempt to dial the outbound connection before we accept the inbound one.
- var stdDialer net.Dialer
- server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr)
+ var dialFunc func(context.Context, string, string) (net.Conn, error)
+ if ns.forwardDialFunc != nil {
+ dialFunc = ns.forwardDialFunc
+ } else {
+ var stdDialer net.Dialer
+ dialFunc = stdDialer.DialContext
+ }
+ server, err := dialFunc(ctx, "tcp", dialAddrStr)
if err != nil {
ns.logf("netstack: could not connect to local server at %s: %v", dialAddr.String(), err)
return
@@ -1456,5 +1645,45 @@ func (ns *Impl) ExpVar() expvar.Var {
}))
}
// Export gauges that show the current TCP forwarding limits.
m.Set("gauge_tcp_forward_in_flight_limit", expvar.Func(func() any {
return maxInFlightConnectionAttempts()
}))
m.Set("gauge_tcp_forward_in_flight_per_client_limit", expvar.Func(func() any {
return maxInFlightConnectionAttemptsPerClient()
}))
// This metric tracks the number of TCP forwarding connections that
// are currently "in-flight", i.e. still waiting to complete.
m.Set("gauge_tcp_forward_in_flight", expvar.Func(func() any {
ns.mu.Lock()
defer ns.mu.Unlock()
var sum int64
for _, n := range ns.connsInFlightByClient {
sum += int64(n)
}
return sum
}))
m.Set("counter_tcp_forward_max_in_flight_per_client_drop", &ns.forwardInFlightPerClientDropped)
// This metric tracks how many clients (if any) are currently at the
// per-client limit on in-flight TCP forwarding requests.
m.Set("gauge_tcp_forward_in_flight_per_client_limit_reached", expvar.Func(func() any {
ns.mu.Lock()
defer ns.mu.Unlock()
limit := maxInFlightConnectionAttemptsPerClient()
var count int64
for _, n := range ns.connsInFlightByClient {
if n == limit {
count++
}
}
return count
}))
return m
}
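
As a testability aside: the forwardDialFunc hook added above lets a
test pin forwards in the in-flight state by installing a dialer that
never completes. A rough sketch of such a helper (hypothetical; the
commit's accompanying tests aren't shown in this diff):

    package sketch

    import (
        "context"
        "errors"
        "net"
    )

    // blockingDialer returns a DialContext-shaped function whose dials
    // block until release is called. Hypothetical test helper.
    func blockingDialer() (dial func(context.Context, string, string) (net.Conn, error), release func()) {
        ch := make(chan struct{})
        dial = func(ctx context.Context, network, address string) (net.Conn, error) {
            select {
            case <-ch:
                return nil, errors.New("dial aborted by test")
            case <-ctx.Done():
                return nil, ctx.Err()
            }
        }
        release = func() { close(ch) }
        return dial, release
    }

A test could assign the dial function to ns.forwardDialFunc, lower the
limit via maxInFlightConnectionAttemptsPerClientForTest, open that many
connections from one client, and then assert that the next attempt is
refused and that counter_tcp_forward_max_in_flight_per_client_drop
increments.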