net/{batching,packet},wgengine/magicsock: export batchingConn (#16848)

For eventual use by net/udprelay.Server.

Updates tailscale/corp#31164

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2025-08-13 13:13:11 -07:00
committed by GitHub
parent f22c7657e5
commit 16bc0a5558
25 changed files with 328 additions and 268 deletions

48
net/batching/conn.go Normal file
View File

@@ -0,0 +1,48 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package batching implements a socket optimized for increased throughput.
package batching
import (
"net/netip"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"tailscale.com/net/packet"
"tailscale.com/types/nettype"
)
var (
// This acts as a compile-time check for our usage of ipv6.Message in
// [Conn] for both IPv6 and IPv4 operations.
_ ipv6.Message = ipv4.Message{}
)
// Conn is a nettype.PacketConn that provides batched i/o using
// platform-specific optimizations, e.g. {recv,send}mmsg & UDP GSO/GRO.
//
// Conn originated from (and is still used by) magicsock where its API was
// strongly influenced by [wireguard-go/conn.Bind] constraints, namely
// wireguard-go's ownership of packet memory.
type Conn interface {
nettype.PacketConn
// ReadBatch reads messages from [Conn] into msgs. It returns the number of
// messages the caller should evaluate for nonzero len, as a zero len
// message may fall on either side of a nonzero.
//
// Each [ipv6.Message.OOB] must be sized to at least MinControlMessageSize().
// len(msgs) must be at least MinReadBatchMsgsLen().
ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
// WriteBatchTo writes buffs to addr.
//
// If geneve.VNI.IsSet(), then geneve is encoded into the space preceding
// offset, and offset must equal [packet.GeneveFixedHeaderLength]. If
// !geneve.VNI.IsSet() then the space preceding offset is ignored.
//
// len(buffs) must be <= batchSize supplied in TryUpgradeToConn().
//
// WriteBatchTo may return a [neterror.ErrUDPGSODisabled] error if UDP GSO
// was disabled as a result of a send error.
WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error
}

View File

@@ -0,0 +1,21 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux
package batching
import (
"tailscale.com/types/nettype"
)
// TryUpgradeToConn is no-op on all platforms except linux.
func TryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn {
return pconn
}
var controlMessageSize = 0
func MinControlMessageSize() int {
return controlMessageSize
}

462
net/batching/conn_linux.go Normal file
View File

@@ -0,0 +1,462 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package batching
import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"tailscale.com/hostinfo"
"tailscale.com/net/neterror"
"tailscale.com/net/packet"
"tailscale.com/types/nettype"
)
// xnetBatchReaderWriter defines the batching i/o methods of
// golang.org/x/net/ipv4.PacketConn (and ipv6.PacketConn).
// TODO(jwhited): This should eventually be replaced with the standard library
// implementation of https://github.com/golang/go/issues/45886
type xnetBatchReaderWriter interface {
xnetBatchReader
xnetBatchWriter
}
type xnetBatchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type xnetBatchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
var (
// [linuxBatchingConn] implements [Conn].
_ Conn = &linuxBatchingConn{}
)
// linuxBatchingConn is a UDP socket that provides batched i/o. It implements
// [Conn].
type linuxBatchingConn struct {
pc *net.UDPConn
xpc xnetBatchReaderWriter
rxOffload bool // supports UDP GRO or similar
txOffload atomic.Bool // supports UDP GSO or similar
setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing
getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing
sendBatchPool sync.Pool
}
func (c *linuxBatchingConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
if c.rxOffload {
// UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may
// receive a "monster datagram" from any read call. The ReadFrom() API
// does not support passing the GSO size and is unsafe to use in such a
// case. Other platforms may vary in behavior, but we go with the most
// conservative approach to prevent this from becoming a footgun in the
// future.
return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
}
return c.pc.ReadFromUDPAddrPort(p)
}
func (c *linuxBatchingConn) SetDeadline(t time.Time) error {
return c.pc.SetDeadline(t)
}
func (c *linuxBatchingConn) SetReadDeadline(t time.Time) error {
return c.pc.SetReadDeadline(t)
}
func (c *linuxBatchingConn) SetWriteDeadline(t time.Time) error {
return c.pc.SetWriteDeadline(t)
}
const (
// This was initially established for Linux, but may split out to
// GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the
// kernel's TX path, and UDP_GRO_CNT_MAX for RX.
udpSegmentMaxDatagrams = 64
)
const (
// Exceeding these values results in EMSGSIZE.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
)
// coalesceMessages iterates 'buffs', setting and coalescing them in 'msgs'
// where possible while maintaining datagram order.
//
// All msgs have their Addr field set to addr.
//
// All msgs[i].Buffers[0] are preceded by a Geneve header (geneve) if geneve.VNI.IsSet().
func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, geneve packet.GeneveHeader, buffs [][]byte, msgs []ipv6.Message, offset int) int {
var (
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of buffs
)
maxPayloadLen := maxIPv4PayloadLen
if addr.IP.To4() == nil {
maxPayloadLen = maxIPv6PayloadLen
}
vniIsSet := geneve.VNI.IsSet()
for i, buff := range buffs {
if vniIsSet {
geneve.Encode(buff)
} else {
buff = buff[offset:]
}
if i > 0 {
msgLen := len(buff)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...)
copy(msgs[base].Buffers[0][baseLenBefore:], buff)
if i == len(buffs)-1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
}
if dgramCnt > 1 {
c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buff)
msgs[base].OOB = msgs[base].OOB[:0]
msgs[base].Buffers[0] = buff
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type sendBatch struct {
msgs []ipv6.Message
ua *net.UDPAddr
}
func (c *linuxBatchingConn) getSendBatch() *sendBatch {
batch := c.sendBatchPool.Get().(*sendBatch)
return batch
}
func (c *linuxBatchingConn) putSendBatch(batch *sendBatch) {
for i := range batch.msgs {
batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB}
}
c.sendBatchPool.Put(batch)
}
func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error {
batch := c.getSendBatch()
defer c.putSendBatch(batch)
if addr.Addr().Is6() {
as16 := addr.Addr().As16()
copy(batch.ua.IP, as16[:])
batch.ua.IP = batch.ua.IP[:16]
} else {
as4 := addr.Addr().As4()
copy(batch.ua.IP, as4[:])
batch.ua.IP = batch.ua.IP[:4]
}
batch.ua.Port = int(addr.Port())
var (
n int
retried bool
)
retry:
if c.txOffload.Load() {
n = c.coalesceMessages(batch.ua, geneve, buffs, batch.msgs, offset)
} else {
vniIsSet := geneve.VNI.IsSet()
if vniIsSet {
offset -= packet.GeneveFixedHeaderLength
}
for i := range buffs {
if vniIsSet {
geneve.Encode(buffs[i])
}
batch.msgs[i].Buffers[0] = buffs[i][offset:]
batch.msgs[i].Addr = batch.ua
batch.msgs[i].OOB = batch.msgs[i].OOB[:0]
}
n = len(buffs)
}
err := c.writeBatch(batch.msgs[:n])
if err != nil && c.txOffload.Load() && neterror.ShouldDisableUDPGSO(err) {
c.txOffload.Store(false)
retried = true
goto retry
}
if retried {
return neterror.ErrUDPGSODisabled{OnLaddr: c.pc.LocalAddr().String(), RetryErr: err}
}
return err
}
func (c *linuxBatchingConn) SyscallConn() (syscall.RawConn, error) {
return c.pc.SyscallConn()
}
func (c *linuxBatchingConn) writeBatch(msgs []ipv6.Message) error {
var head int
for {
n, err := c.xpc.WriteBatch(msgs[head:], 0)
if err != nil || n == len(msgs[head:]) {
// Returning the number of packets written would require
// unraveling individual msg len and gso size during a coalesced
// write. The top of the call stack disregards partial success,
// so keep this simple for now.
return err
}
head += n
}
}
// splitCoalescedMessages splits coalesced messages from the tail of dst
// beginning at index 'firstMsgAt' into the head of the same slice. It reports
// the number of elements to evaluate in msgs for nonzero len (msgs[i].N). An
// error is returned if a socket control message cannot be parsed or a split
// operation would overflow msgs.
func (c *linuxBatchingConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
return n, nil
}
func (c *linuxBatchingConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) {
if !c.rxOffload || len(msgs) < 2 {
return c.xpc.ReadBatch(msgs, flags)
}
// Read into the tail of msgs, split into the head.
readAt := len(msgs) - 2
numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0)
if err != nil || numRead == 0 {
return 0, err
}
return c.splitCoalescedMessages(msgs, readAt)
}
func (c *linuxBatchingConn) LocalAddr() net.Addr {
return c.pc.LocalAddr().(*net.UDPAddr)
}
func (c *linuxBatchingConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
return c.pc.WriteToUDPAddrPort(b, addr)
}
func (c *linuxBatchingConn) Close() error {
return c.pc.Close()
}
// tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn,
// and returns two booleans indicating TX and RX UDP offload support.
func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) {
if c, ok := pconn.(*net.UDPConn); ok {
rc, err := c.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := syscall.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
hasTX = errSyscall == nil
errSyscall = syscall.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
hasRX = errSyscall == nil
})
if err != nil {
return false, false
}
}
return hasTX, hasRX
}
// getGSOSizeFromControl returns the GSO size found in control. If no GSO size
// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0.
// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but
// its contents cannot be parsed as a socket control message.
func getGSOSizeFromControl(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 {
return int(binary.NativeEndian.Uint16(data[:2])), nil
}
}
return 0, nil
}
// setGSOSizeInControl sets a socket control message in control containing
// gsoSize. If len(control) < controlMessageSize control's len will be set to 0.
func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
*control = (*control)[:0]
if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
return
}
if cap(*control) < controlMessageSize {
return
}
*control = (*control)[:cap(*control)]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(2))
binary.NativeEndian.PutUint16((*control)[unix.SizeofCmsghdr:], gsoSize)
*control = (*control)[:unix.CmsgSpace(2)]
}
// TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades
// pconn to a [Conn] if appropriate. A batch size of MinReadBatchMsgsLen() is
// suggested for the best performance.
func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
if runtime.GOOS != "linux" {
// Exclude Android.
return pconn
}
if network != "udp4" && network != "udp6" {
return pconn
}
if strings.HasPrefix(hostinfo.GetOSVersion(), "2.") {
// recvmmsg/sendmmsg were added in 2.6.33, but we support down to
// 2.6.32 for old NAS devices. See https://github.com/tailscale/tailscale/issues/6807.
// As a cheap heuristic: if the Linux kernel starts with "2", just
// consider it too old for mmsg. Nobody who cares about performance runs
// such ancient kernels. UDP offload was added much later, so no
// upgrades are available.
return pconn
}
uc, ok := pconn.(*net.UDPConn)
if !ok {
return pconn
}
b := &linuxBatchingConn{
pc: uc,
getGSOSizeFromControl: getGSOSizeFromControl,
setGSOSizeInControl: setGSOSizeInControl,
sendBatchPool: sync.Pool{
New: func() any {
ua := &net.UDPAddr{
IP: make([]byte, 16),
}
msgs := make([]ipv6.Message, batchSize)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
msgs[i].OOB = make([]byte, controlMessageSize)
}
return &sendBatch{
ua: ua,
msgs: msgs,
}
},
},
}
switch network {
case "udp4":
b.xpc = ipv4.NewPacketConn(uc)
case "udp6":
b.xpc = ipv6.NewPacketConn(uc)
default:
panic("bogus network")
}
var txOffload bool
txOffload, b.rxOffload = tryEnableUDPOffload(uc)
b.txOffload.Store(txOffload)
return b
}
var controlMessageSize = -1 // bomb if used for allocation before init
func init() {
// controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control
// message. These contain a single uint16 of data.
controlMessageSize = unix.CmsgSpace(2)
}
// MinControlMessageSize returns the minimum control message size required to
// support read batching via [Conn.ReadBatch].
func MinControlMessageSize() int {
return controlMessageSize
}
func MinReadBatchMsgsLen() int {
return 128
}

View File

@@ -0,0 +1,316 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package batching
import (
"encoding/binary"
"net"
"testing"
"github.com/tailscale/wireguard-go/conn"
"golang.org/x/net/ipv6"
"tailscale.com/net/packet"
)
func setGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func getGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) {
c := &linuxBatchingConn{
setGSOSizeInControl: setGSOSize,
getGSOSizeFromControl: getGSOSize,
}
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1024)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := c.splitCoalescedMessages(tt.msgs, 2)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}
func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
c := &linuxBatchingConn{
setGSOSizeInControl: setGSOSize,
getGSOSizeFromControl: getGSOSize,
}
withGeneveSpace := func(len, cap int) []byte {
return make([]byte, len+packet.GeneveFixedHeaderLength, cap+packet.GeneveFixedHeaderLength)
}
geneve := packet.GeneveHeader{
Protocol: packet.GeneveProtocolWireGuard,
}
geneve.VNI.Set(1)
cases := []struct {
name string
buffs [][]byte
geneve packet.GeneveHeader
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
withGeneveSpace(1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "one message no coalesce vni.isSet",
buffs: [][]byte{
withGeneveSpace(1, 1),
},
geneve: geneve,
wantLens: []int{1 + packet.GeneveFixedHeaderLength},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
withGeneveSpace(1, 2),
withGeneveSpace(1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages equal len coalesce vni.isSet",
buffs: [][]byte{
withGeneveSpace(1, 2+packet.GeneveFixedHeaderLength),
withGeneveSpace(1, 1),
},
geneve: geneve,
wantLens: []int{2 + (2 * packet.GeneveFixedHeaderLength)},
wantGSO: []int{1 + packet.GeneveFixedHeaderLength},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
withGeneveSpace(2, 3),
withGeneveSpace(1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "two messages unequal len coalesce vni.isSet",
buffs: [][]byte{
withGeneveSpace(2, 3+packet.GeneveFixedHeaderLength),
withGeneveSpace(1, 1),
},
geneve: geneve,
wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength)},
wantGSO: []int{2 + packet.GeneveFixedHeaderLength},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
withGeneveSpace(2, 3),
withGeneveSpace(1, 1),
withGeneveSpace(2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages second unequal len coalesce vni.isSet",
buffs: [][]byte{
withGeneveSpace(2, 3+(2*packet.GeneveFixedHeaderLength)),
withGeneveSpace(1, 1),
withGeneveSpace(2, 2),
},
geneve: geneve,
wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength},
wantGSO: []int{2 + packet.GeneveFixedHeaderLength, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
withGeneveSpace(2, 4),
withGeneveSpace(2, 2),
withGeneveSpace(2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce vni.isSet",
buffs: [][]byte{
withGeneveSpace(2, 4+packet.GeneveFixedHeaderLength),
withGeneveSpace(2, 2),
withGeneveSpace(2, 2),
},
geneve: geneve,
wantLens: []int{4 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength},
wantGSO: []int{2 + packet.GeneveFixedHeaderLength, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := c.coalesceMessages(addr, tt.geneve, tt.buffs, msgs, packet.GeneveFixedHeaderLength)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := range got {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := getGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}
func TestMinReadBatchMsgsLen(t *testing.T) {
// So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is
// shaped for wireguard-go to control packet memory, these values should be
// aligned.
if MinReadBatchMsgsLen() != conn.IdealBatchSize {
t.Fatalf("MinReadBatchMsgsLen():%d != conn.IdealBatchSize(): %d", MinReadBatchMsgsLen(), conn.IdealBatchSize)
}
}

View File

@@ -24,6 +24,33 @@ const (
GeneveProtocolWireGuard uint16 = 0x7A12
)
// VirtualNetworkID is a Geneve header (RFC8926) 3-byte virtual network
// identifier. Its methods are NOT thread-safe.
type VirtualNetworkID struct {
_vni uint32
}
const (
vniSetMask uint32 = 0xFF000000
vniGetMask uint32 = ^vniSetMask
)
// IsSet returns true if Set() had been called previously, otherwise false.
func (v *VirtualNetworkID) IsSet() bool {
return v._vni&vniSetMask != 0
}
// Set sets the provided VNI. If VNI exceeds the 3-byte storage it will be
// clamped.
func (v *VirtualNetworkID) Set(vni uint32) {
v._vni = vni | vniSetMask
}
// Get returns the VNI value.
func (v *VirtualNetworkID) Get() uint32 {
return v._vni & vniGetMask
}
// GeneveHeader represents the fixed size Geneve header from RFC8926.
// TLVs/options are not implemented/supported.
//
@@ -51,7 +78,7 @@ type GeneveHeader struct {
// decisions or MAY be used as a mechanism to distinguish between
// overlapping address spaces contained in the encapsulated packet when load
// balancing across CPUs.
VNI uint32
VNI VirtualNetworkID
// O (1 bit): Control packet. This packet contains a control message.
// Control messages are sent between tunnel endpoints. Tunnel endpoints MUST
@@ -65,12 +92,18 @@ type GeneveHeader struct {
Control bool
}
// Encode encodes GeneveHeader into b. If len(b) < GeneveFixedHeaderLength an
// io.ErrShortBuffer error is returned.
var ErrGeneveVNIUnset = errors.New("VNI is unset")
// Encode encodes GeneveHeader into b. If len(b) < [GeneveFixedHeaderLength] an
// [io.ErrShortBuffer] error is returned. If !h.VNI.IsSet() then an
// [ErrGeneveVNIUnset] error is returned.
func (h *GeneveHeader) Encode(b []byte) error {
if len(b) < GeneveFixedHeaderLength {
return io.ErrShortBuffer
}
if !h.VNI.IsSet() {
return ErrGeneveVNIUnset
}
if h.Version > 3 {
return errors.New("version must be <= 3")
}
@@ -81,15 +114,12 @@ func (h *GeneveHeader) Encode(b []byte) error {
b[1] |= 0x80
}
binary.BigEndian.PutUint16(b[2:], h.Protocol)
if h.VNI > 1<<24-1 {
return errors.New("VNI must be <= 2^24-1")
}
binary.BigEndian.PutUint32(b[4:], h.VNI<<8)
binary.BigEndian.PutUint32(b[4:], h.VNI.Get()<<8)
return nil
}
// Decode decodes GeneveHeader from b. If len(b) < GeneveFixedHeaderLength an
// io.ErrShortBuffer error is returned.
// Decode decodes GeneveHeader from b. If len(b) < [GeneveFixedHeaderLength] an
// [io.ErrShortBuffer] error is returned.
func (h *GeneveHeader) Decode(b []byte) error {
if len(b) < GeneveFixedHeaderLength {
return io.ErrShortBuffer
@@ -99,6 +129,6 @@ func (h *GeneveHeader) Decode(b []byte) error {
h.Control = true
}
h.Protocol = binary.BigEndian.Uint16(b[2:])
h.VNI = binary.BigEndian.Uint32(b[4:]) >> 8
h.VNI.Set(binary.BigEndian.Uint32(b[4:]) >> 8)
return nil
}

View File

@@ -4,18 +4,21 @@
package packet
import (
"math"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/types/ptr"
)
func TestGeneveHeader(t *testing.T) {
in := GeneveHeader{
Version: 3,
Protocol: GeneveProtocolDisco,
VNI: 1<<24 - 1,
Control: true,
}
in.VNI.Set(1<<24 - 1)
b := make([]byte, GeneveFixedHeaderLength)
err := in.Encode(b)
if err != nil {
@@ -26,7 +29,56 @@ func TestGeneveHeader(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(out, in); diff != "" {
if diff := cmp.Diff(out, in, cmpopts.EquateComparable(VirtualNetworkID{})); diff != "" {
t.Fatalf("wrong results (-got +want)\n%s", diff)
}
}
func TestVirtualNetworkID(t *testing.T) {
tests := []struct {
name string
set *uint32
want uint32
}{
{
"don't Set",
nil,
0,
},
{
"Set 0",
ptr.To(uint32(0)),
0,
},
{
"Set 1",
ptr.To(uint32(1)),
1,
},
{
"Set math.MaxUint32",
ptr.To(uint32(math.MaxUint32)),
1<<24 - 1,
},
{
"Set max 3-byte value",
ptr.To(uint32(1<<24 - 1)),
1<<24 - 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := VirtualNetworkID{}
if tt.set != nil {
v.Set(*tt.set)
}
if v.IsSet() != (tt.set != nil) {
t.Fatalf("IsSet: %v != wantIsSet: %v", v.IsSet(), tt.set != nil)
}
if v.Get() != tt.want {
t.Fatalf("Get(): %v != want: %v", v.Get(), tt.want)
}
})
}
}

View File

@@ -140,7 +140,8 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
rand.Read(e.challenge[senderIndex][:])
copy(m.Challenge[:], e.challenge[senderIndex][:])
reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
gh := packet.GeneveHeader{Control: true, Protocol: packet.GeneveProtocolDisco}
gh.VNI.Set(e.vni)
err = gh.Encode(reply)
if err != nil {
return
@@ -543,7 +544,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, rxSocket, otherAFSo
// it simple (and slow) for now.
s.mu.Lock()
defer s.mu.Unlock()
e, ok := s.byVNI[gh.VNI]
e, ok := s.byVNI[gh.VNI.Get()]
if !ok {
// unknown VNI
return

View File

@@ -62,7 +62,8 @@ func (c *testClient) read(t *testing.T) []byte {
func (c *testClient) writeDataPkt(t *testing.T, b []byte) {
pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b))
gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard}
gh := packet.GeneveHeader{Control: false, Protocol: packet.GeneveProtocolWireGuard}
gh.VNI.Set(c.vni)
err := gh.Encode(pkt)
if err != nil {
t.Fatal(err)
@@ -84,7 +85,7 @@ func (c *testClient) readDataPkt(t *testing.T) []byte {
if gh.Control {
t.Fatal("unexpected control")
}
if gh.VNI != c.vni {
if gh.VNI.Get() != c.vni {
t.Fatal("unexpected vni")
}
return b[packet.GeneveFixedHeaderLength:]
@@ -92,7 +93,8 @@ func (c *testClient) readDataPkt(t *testing.T) []byte {
func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) {
pkt := make([]byte, packet.GeneveFixedHeaderLength, 512)
gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco}
gh := packet.GeneveHeader{Control: true, Protocol: packet.GeneveProtocolDisco}
gh.VNI.Set(c.vni)
err := gh.Encode(pkt)
if err != nil {
t.Fatal(err)
@@ -117,7 +119,7 @@ func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message {
if !gh.Control {
t.Fatal("unexpected non-control")
}
if gh.VNI != c.vni {
if gh.VNI.Get() != c.vni {
t.Fatal("unexpected vni")
}
b = b[packet.GeneveFixedHeaderLength:]