mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-10 10:03:43 +00:00
df6014f1d7
In 2f27319baf71681e221904d3a3ffe1badedc8e2e we disabled GRO due to a data race around concurrent calls to tstun.Wrapper.Write(). This commit refactors GRO to be thread-safe, and re-enables it on Linux. This refactor now carries a GRO type across tstun and netstack APIs with a lifetime that is scoped to a single tstun.Wrapper.Write() call. In 25f0a3fc8f6f9cf681bb5afda8e1762816c67a8b we used build tags to prevent importation of gVisor's GRO package on iOS as at the time we believed it was contributing to additional memory usage on that platform. It wasn't, so this commit simplifies and removes those build tags. Updates tailscale/corp#22353 Updates tailscale/corp#22125 Updates #6816 Signed-off-by: Jordan Whited <jordan@tailscale.com>
113 lines
3.1 KiB
Go
113 lines
3.1 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package gro
|
|
|
|
import (
|
|
"bytes"
|
|
"net/netip"
|
|
"testing"
|
|
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
"tailscale.com/net/packet"
|
|
)
|
|
|
|
func Test_RXChecksumOffload(t *testing.T) {
|
|
payloadLen := 100
|
|
|
|
tcpFields := &header.TCPFields{
|
|
SrcPort: 1,
|
|
DstPort: 1,
|
|
SeqNum: 1,
|
|
AckNum: 1,
|
|
DataOffset: 20,
|
|
Flags: header.TCPFlagAck | header.TCPFlagPsh,
|
|
WindowSize: 3000,
|
|
}
|
|
tcp4 := make([]byte, 20+20+payloadLen)
|
|
ipv4H := header.IPv4(tcp4)
|
|
ipv4H.Encode(&header.IPv4Fields{
|
|
SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.1").AsSlice()),
|
|
DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.2").AsSlice()),
|
|
Protocol: uint8(header.TCPProtocolNumber),
|
|
TTL: 64,
|
|
TotalLength: uint16(len(tcp4)),
|
|
})
|
|
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
|
|
tcpH := header.TCP(tcp4[20:])
|
|
tcpH.Encode(tcpFields)
|
|
pseudoCsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+payloadLen))
|
|
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
|
|
|
tcp6ExtHeader := make([]byte, 40+8+20+payloadLen)
|
|
ipv6H := header.IPv6(tcp6ExtHeader)
|
|
ipv6H.Encode(&header.IPv6Fields{
|
|
SrcAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::1").AsSlice()),
|
|
DstAddr: tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::2").AsSlice()),
|
|
TransportProtocol: 60, // really next header; destination options ext header
|
|
HopLimit: 64,
|
|
PayloadLength: uint16(8 + 20 + payloadLen),
|
|
})
|
|
tcp6ExtHeader[40] = uint8(header.TCPProtocolNumber) // next header
|
|
tcp6ExtHeader[41] = 0 // length of ext header in 8-octet units, exclusive of first 8 octets.
|
|
// 42-47 options and padding
|
|
tcpH = header.TCP(tcp6ExtHeader[48:])
|
|
tcpH.Encode(tcpFields)
|
|
pseudoCsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+payloadLen))
|
|
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
|
|
|
|
tcp4InvalidCsum := make([]byte, len(tcp4))
|
|
copy(tcp4InvalidCsum, tcp4)
|
|
at := 20 + 16
|
|
tcp4InvalidCsum[at] = ^tcp4InvalidCsum[at]
|
|
|
|
tcp6ExtHeaderInvalidCsum := make([]byte, len(tcp6ExtHeader))
|
|
copy(tcp6ExtHeaderInvalidCsum, tcp6ExtHeader)
|
|
at = 40 + 8 + 16
|
|
tcp6ExtHeaderInvalidCsum[at] = ^tcp6ExtHeaderInvalidCsum[at]
|
|
|
|
tests := []struct {
|
|
name string
|
|
input []byte
|
|
wantPB bool
|
|
}{
|
|
{
|
|
"tcp4 packet valid csum",
|
|
tcp4,
|
|
true,
|
|
},
|
|
{
|
|
"tcp6 with ext header valid csum",
|
|
tcp6ExtHeader,
|
|
true,
|
|
},
|
|
{
|
|
"tcp4 packet invalid csum",
|
|
tcp4InvalidCsum,
|
|
false,
|
|
},
|
|
{
|
|
"tcp6 with ext header invalid csum",
|
|
tcp6ExtHeaderInvalidCsum,
|
|
false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
p := &packet.Parsed{}
|
|
p.Decode(tt.input)
|
|
got := RXChecksumOffload(p)
|
|
if tt.wantPB != (got != nil) {
|
|
t.Fatalf("wantPB = %v != (got != nil): %v", tt.wantPB, got != nil)
|
|
}
|
|
if tt.wantPB {
|
|
gotBuf := got.ToBuffer()
|
|
if !bytes.Equal(tt.input, gotBuf.Flatten()) {
|
|
t.Fatal("output packet unequal to input")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|