net/tstun,wgengine{/netstack/gro}: refactor and re-enable gVisor GRO for Linux (#13172)

In 2f27319baf we disabled GRO due to a
data race around concurrent calls to tstun.Wrapper.Write(). This commit
refactors GRO to be thread-safe, and re-enables it on Linux.

This refactor now carries a GRO type across tstun and netstack APIs
with a lifetime that is scoped to a single tstun.Wrapper.Write() call.

In 25f0a3fc8f we used build tags to
prevent importation of gVisor's GRO package on iOS as at the time we
believed it was contributing to additional memory usage on that
platform. It wasn't, so this commit simplifies and removes those
build tags.

Updates tailscale/corp#22353
Updates tailscale/corp#22125
Updates #6816

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2024-08-20 15:22:19 -07:00
committed by GitHub
parent 93dc2ded6e
commit df6014f1d7
12 changed files with 274 additions and 244 deletions

View File

@@ -0,0 +1,169 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package gro implements GRO for the receive (write) path into gVisor.
package gro
import (
"bytes"
"sync"
"github.com/tailscale/wireguard-go/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nsgro "gvisor.dev/gvisor/pkg/tcpip/stack/gro"
"tailscale.com/net/packet"
"tailscale.com/types/ipproto"
)
// RXChecksumOffload validates IPv4, TCP, and UDP header checksums in p,
// returning an equivalent *stack.PacketBuffer if they are valid, otherwise nil.
// The set of headers validated covers where gVisor would perform validation if
// !stack.PacketBuffer.RXChecksumValidated, i.e. it satisfies
// stack.CapabilityRXChecksumOffload. Other protocols with checksum fields,
// e.g. ICMP{v6}, are still validated by gVisor regardless of rx checksum
// offloading capabilities.
func RXChecksumOffload(p *packet.Parsed) *stack.PacketBuffer {
var (
pn tcpip.NetworkProtocolNumber
csumStart int
)
buf := p.Buffer()
switch p.IPVersion {
case 4:
if len(buf) < header.IPv4MinimumSize {
return nil
}
csumStart = int((buf[0] & 0x0F) * 4)
if csumStart < header.IPv4MinimumSize || csumStart > header.IPv4MaximumHeaderSize || len(buf) < csumStart {
return nil
}
if ^tun.Checksum(buf[:csumStart], 0) != 0 {
return nil
}
pn = header.IPv4ProtocolNumber
case 6:
if len(buf) < header.IPv6FixedHeaderSize {
return nil
}
csumStart = header.IPv6FixedHeaderSize
pn = header.IPv6ProtocolNumber
if p.IPProto != ipproto.ICMPv6 && p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP {
// buf could have extension headers before a UDP or TCP header, but
// packet.Parsed.IPProto will be set to the ext header type, so we
// have to look deeper. We are still responsible for validating the
// L4 checksum in this case. So, make use of gVisor's existing
// extension header parsing via parse.IPv6() in order to unpack the
// L4 csumStart index. This is not particularly efficient as we have
// to allocate a short-lived stack.PacketBuffer that cannot be
// re-used. parse.IPv6() "consumes" the IPv6 headers, so we can't
// inject this stack.PacketBuffer into the stack at a later point.
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(bytes.Clone(buf)),
})
defer packetBuf.DecRef()
// The rightmost bool returns false only if packetBuf is too short,
// which we've already accounted for above.
transportProto, _, _, _, _ := parse.IPv6(packetBuf)
if transportProto == header.TCPProtocolNumber || transportProto == header.UDPProtocolNumber {
csumLen := packetBuf.Data().Size()
if len(buf) < csumLen {
return nil
}
csumStart = len(buf) - csumLen
p.IPProto = ipproto.Proto(transportProto)
}
}
}
if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP {
lenForPseudo := len(buf) - csumStart
csum := tun.PseudoHeaderChecksum(
uint8(p.IPProto),
p.Src.Addr().AsSlice(),
p.Dst.Addr().AsSlice(),
uint16(lenForPseudo))
csum = tun.Checksum(buf[csumStart:], csum)
if ^csum != 0 {
return nil
}
}
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(bytes.Clone(buf)),
})
packetBuf.NetworkProtocolNumber = pn
// Setting this is not technically required. gVisor overrides where
// stack.CapabilityRXChecksumOffload is advertised from Capabilities().
// https://github.com/google/gvisor/blob/64c016c92987cc04dfd4c7b091ddd21bdad875f8/pkg/tcpip/stack/nic.go#L763
// This is also why we offload for all packets since we cannot signal this
// per-packet.
packetBuf.RXChecksumValidated = true
return packetBuf
}
var (
groPool sync.Pool
)
func init() {
groPool.New = func() any {
g := &GRO{}
g.gro.Init(true)
return g
}
}
// GRO coalesces incoming packets to increase throughput. It is NOT thread-safe.
type GRO struct {
gro nsgro.GRO
maybeEnqueued bool
}
// NewGRO returns a new instance of *GRO from a sync.Pool. It can be returned to
// the pool with GRO.Flush().
func NewGRO() *GRO {
return groPool.Get().(*GRO)
}
// SetDispatcher sets the underlying stack.NetworkDispatcher where packets are
// delivered.
func (g *GRO) SetDispatcher(d stack.NetworkDispatcher) {
g.gro.Dispatcher = d
}
// Enqueue enqueues the provided packet for GRO. It may immediately deliver
// it to the underlying stack.NetworkDispatcher depending on its contents. To
// explicitly flush previously enqueued packets see Flush().
func (g *GRO) Enqueue(p *packet.Parsed) {
if g.gro.Dispatcher == nil {
return
}
pkt := RXChecksumOffload(p)
if pkt == nil {
return
}
// TODO(jwhited): g.gro.Enqueue() duplicates a lot of p.Decode().
// We may want to push stack.PacketBuffer further up as a
// replacement for packet.Parsed, or inversely push packet.Parsed
// down into refactored GRO logic.
g.gro.Enqueue(pkt)
g.maybeEnqueued = true
pkt.DecRef()
}
// Flush flushes previously enqueued packets to the underlying
// stack.NetworkDispatcher, and returns GRO to a pool for later re-use. Callers
// MUST NOT use GRO once it has been Flush()'d.
func (g *GRO) Flush() {
if g.gro.Dispatcher != nil && g.maybeEnqueued {
g.gro.Flush()
}
g.gro.Dispatcher = nil
g.maybeEnqueued = false
groPool.Put(g)
}

View File

@@ -0,0 +1,112 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package gro
import (
"bytes"
"net/netip"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"tailscale.com/net/packet"
)
func Test_RXChecksumOffload(t *testing.T) {
payloadLen := 100
tcpFields := &header.TCPFields{
SrcPort: 1,
DstPort: 1,
SeqNum: 1,
AckNum: 1,
DataOffset: 20,
Flags: header.TCPFlagAck | header.TCPFlagPsh,
WindowSize: 3000,
}
tcp4 := make([]byte, 20+20+payloadLen)
ipv4H := header.IPv4(tcp4)
ipv4H.Encode(&header.IPv4Fields{
SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.1").AsSlice()),
DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.2").AsSlice()),
Protocol: uint8(header.TCPProtocolNumber),
TTL: 64,
TotalLength: uint16(len(tcp4)),
})
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
tcpH := header.TCP(tcp4[20:])
tcpH.Encode(tcpFields)
pseudoCsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+payloadLen))
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
tcp6ExtHeader := make([]byte, 40+8+20+payloadLen)
ipv6H := header.IPv6(tcp6ExtHeader)
ipv6H.Encode(&header.IPv6Fields{
SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::1").AsSlice()),
DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::2").AsSlice()),
TransportProtocol: 60, // really next header; destination options ext header
HopLimit: 64,
PayloadLength: uint16(8 + 20 + payloadLen),
})
tcp6ExtHeader[40] = uint8(header.TCPProtocolNumber) // next header
tcp6ExtHeader[41] = 0 // length of ext header in 8-octet units, exclusive of first 8 octets.
// 42-47 options and padding
tcpH = header.TCP(tcp6ExtHeader[48:])
tcpH.Encode(tcpFields)
pseudoCsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+payloadLen))
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
tcp4InvalidCsum := make([]byte, len(tcp4))
copy(tcp4InvalidCsum, tcp4)
at := 20 + 16
tcp4InvalidCsum[at] = ^tcp4InvalidCsum[at]
tcp6ExtHeaderInvalidCsum := make([]byte, len(tcp6ExtHeader))
copy(tcp6ExtHeaderInvalidCsum, tcp6ExtHeader)
at = 40 + 8 + 16
tcp6ExtHeaderInvalidCsum[at] = ^tcp6ExtHeaderInvalidCsum[at]
tests := []struct {
name string
input []byte
wantPB bool
}{
{
"tcp4 packet valid csum",
tcp4,
true,
},
{
"tcp6 with ext header valid csum",
tcp6ExtHeader,
true,
},
{
"tcp4 packet invalid csum",
tcp4InvalidCsum,
false,
},
{
"tcp6 with ext header invalid csum",
tcp6ExtHeaderInvalidCsum,
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &packet.Parsed{}
p.Decode(tt.input)
got := RXChecksumOffload(p)
if tt.wantPB != (got != nil) {
t.Fatalf("wantPB = %v != (got != nil): %v", tt.wantPB, got != nil)
}
if tt.wantPB {
gotBuf := got.ToBuffer()
if !bytes.Equal(tt.input, gotBuf.Flatten()) {
t.Fatal("output packet unequal to input")
}
}
})
}
}