wgengine/netstack: fix bug with duplicate SYN packets in client limit

This fixes a bug that was introduced in #11258 where the handling of the
per-client limit didn't properly account for the fact that the gVisor
TCP forwarder will return 'true' to indicate that it's handled a
duplicate SYN packet, but not launch the handler goroutine.

In such a case, we neither decremented our per-client in-flight counter
in the wrapper function, nor did we do so in the handler function, leading
to our per-client in-flight table slowly filling up without bound.

Fix this by doing the same duplicate-tracking logic that the TCP
forwarder does so we can detect such cases and appropriately decrement
our in-flight counter.

Updates tailscale/corp#12184

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: Ib6011a71d382a10d68c0802593f34b8153d06892
This commit is contained in:
Andrew Dunham 2024-02-28 23:21:31 -05:00
parent ad33e47270
commit 7429e8912a
2 changed files with 62 additions and 13 deletions

View File

@ -224,6 +224,19 @@ type Impl struct {
// global limit, preventing a misbehaving client from starving the
// global limit.
connsInFlightByClient map[netip.Addr]int
// packetsInFlight tracks whether we're already handling a packet for
// the given endpoint ID; clients can send repeated SYN packets while
// trying to establish a connection (and while we're dialing the
// upstream address). If we don't deduplicate based on the endpoint,
// each SYN retransmit results in us incrementing
// connsInFlightByClient and never decrementing it, because the
// underlying TCP forwarder returns 'true' to indicate that the packet
// is handled but never actually launches our acceptTCP function.
//
// This mimics the 'inFlight' map in the TCP forwarder; it's
// unfortunate that we have to track this all twice, but thankfully the
// map only holds pending (in-flight) packets, and it's reasonably cheap.
packetsInFlight map[stack.TransportEndpointID]struct{}
}
const nicID = 1
@ -315,6 +328,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
dialer: dialer,
connsOpenBySubnetIP: make(map[netip.Addr]int),
connsInFlightByClient: make(map[netip.Addr]int),
packetsInFlight: make(map[stack.TransportEndpointID]struct{}),
dns: dns,
tailFSForLocal: tailFSForLocal,
}
@ -410,13 +424,28 @@ func (ns *Impl) wrapTCPProtocolHandler(h protocolHandlerFunc) protocolHandlerFun
// NOTE: the counter is decremented in
// decrementInFlightTCPForward, called from the acceptTCP
// function, below.
ns.mu.Lock()
if _, ok := ns.packetsInFlight[tei]; ok {
// We're already handling this packet; just bail early
// (this is also what would happen in the TCP
// forwarder).
ns.mu.Unlock()
return true
}
// Check the per-client limit.
inFlight := ns.connsInFlightByClient[remoteIP]
tooManyInFlight := inFlight >= maxInFlightConnectionAttemptsPerClient()
if !tooManyInFlight {
ns.connsInFlightByClient[remoteIP]++
}
// We're handling this packet now; see the comment on the
// packetsInFlight field for more details.
ns.packetsInFlight[tei] = struct{}{}
ns.mu.Unlock()
if debugNetstack() {
ns.logf("[v2] netstack: in-flight connections for client %v: %d", remoteIP, inFlight)
}
@ -429,18 +458,23 @@ func (ns *Impl) wrapTCPProtocolHandler(h protocolHandlerFunc) protocolHandlerFun
// On return, if this packet isn't handled by the inner handler
// we're wrapping (`h`), we need to decrement the per-client
// in-flight count. This can happen if the underlying
// forwarder's limit has been reached, at which point it will
// return false to indicate that it's not handling the packet,
// and it will not run acceptTCP. If we don't decrement here,
// then we would eventually increment the per-client counter up
// to the limit and never decrement because we'd never hit the
// codepath in acceptTCP, below.
// in-flight count and remove the ID from our tracking map.
// This can happen if the underlying forwarder's limit has been
// reached, at which point it will return false to indicate
// that it's not handling the packet, and it will not run
// acceptTCP. If we don't decrement here, then we would
// eventually increment the per-client counter up to the limit
// and never decrement because we'd never hit the codepath in
// acceptTCP, below, or just drop all packets from the same
// endpoint due to the packetsInFlight check.
defer func() {
if !handled {
ns.mu.Lock()
delete(ns.packetsInFlight, tei)
ns.connsInFlightByClient[remoteIP]--
new := ns.connsInFlightByClient[remoteIP]
ns.mu.Unlock()
ns.logf("netstack: decrementing connsInFlightByClient[%v] because the packet was not handled; new value is %d", remoteIP, new)
}
}()
@ -454,10 +488,13 @@ func (ns *Impl) wrapTCPProtocolHandler(h protocolHandlerFunc) protocolHandlerFun
}
}
func (ns *Impl) decrementInFlightTCPForward(remoteAddr netip.Addr) {
func (ns *Impl) decrementInFlightTCPForward(tei stack.TransportEndpointID, remoteAddr netip.Addr) {
ns.mu.Lock()
defer ns.mu.Unlock()
// Remove this endpoint from the tracking map so future SYNs from the same endpoint will be handled.
delete(ns.packetsInFlight, tei)
was := ns.connsInFlightByClient[remoteAddr]
newVal := was - 1
if newVal == 0 {
@ -1047,12 +1084,14 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
}
// After we've returned from this function or have otherwise reached a
// non-pending state, decrement the per-client in-flight count so
// future TCP connections aren't dropped.
// non-pending state, decrement the per-client in-flight count and
// remove this endpoint from our packet tracking map so future TCP
// connections aren't dropped.
inFlightCompleted := false
tei := r.ID()
defer func() {
if !inFlightCompleted {
ns.decrementInFlightTCPForward(clientRemoteIP)
ns.decrementInFlightTCPForward(tei, clientRemoteIP)
}
}()
@ -1114,7 +1153,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
// "in-flight" state; decrement our per-client limit right now,
// and tell the defer in acceptTCP that it doesn't need to do
// so upon return.
ns.decrementInFlightTCPForward(clientRemoteIP)
ns.decrementInFlightTCPForward(tei, clientRemoteIP)
inFlightCompleted = true
// The ForwarderRequest.CreateEndpoint above asynchronously

View File

@ -580,6 +580,13 @@ func TestTCPForwardLimits(t *testing.T) {
t.Logf("got connection in progress")
}
// Inject another packet, which will be deduplicated and thus not
// increment our counter.
parsed.Decode(pkt)
if resp := impl.injectInbound(&parsed, impl.tundev); resp != filter.DropSilently {
t.Errorf("got filter outcome %v, want filter.DropSilently", resp)
}
// Verify that we now have a single in-flight address in our map.
impl.mu.Lock()
inFlight := maps.Clone(impl.connsInFlightByClient)
@ -633,8 +640,11 @@ func TestTCPForwardLimits_PerClient(t *testing.T) {
destAddr := netip.MustParseAddr("192.0.2.1")
// Helpers
var port uint16 = 1234
mustInjectPacket := func() {
pkt := tcp4syn(t, client, destAddr, 1234, 4567)
pkt := tcp4syn(t, client, destAddr, port, 4567)
port++ // to avoid deduplication based on endpoint
var parsed packet.Parsed
parsed.Decode(pkt)