net/udprelay: use batching.Conn (#16866)

This significantly improves throughput of a peer relay server on Linux.

Server.packetReadLoop no longer passes sockets down the stack. Instead,
packet handling methods return a netip.AddrPort and []byte, which
packetReadLoop gathers together for eventual batched writes on the
appropriate socket(s).

Updates tailscale/corp#31164

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2025-08-19 14:44:39 -07:00
committed by GitHub
parent 5c560d7489
commit d4b7200129
6 changed files with 153 additions and 63 deletions

View File

@@ -32,7 +32,6 @@ type Conn interface {
// message may fall on either side of a nonzero.
//
// Each [ipv6.Message.OOB] must be sized to at least MinControlMessageSize().
// len(msgs) must be at least MinReadBatchMsgsLen().
ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
// WriteBatchTo writes buffs to addr.
//

View File

@@ -19,3 +19,5 @@ var controlMessageSize = 0
func MinControlMessageSize() int {
return controlMessageSize
}
const IdealBatchSize = 1

View File

@@ -384,7 +384,7 @@ func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
}
// TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades
// pconn to a [Conn] if appropriate. A batch size of MinReadBatchMsgsLen() is
// pconn to a [Conn] if appropriate. A batch size of [IdealBatchSize] is
// suggested for the best performance.
func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
if runtime.GOOS != "linux" {
@@ -457,6 +457,4 @@ func MinControlMessageSize() int {
return controlMessageSize
}
func MinReadBatchMsgsLen() int {
return 128
}
const IdealBatchSize = 128

View File

@@ -310,7 +310,7 @@ func TestMinReadBatchMsgsLen(t *testing.T) {
// So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is
// shaped for wireguard-go to control packet memory, these values should be
// aligned.
if MinReadBatchMsgsLen() != conn.IdealBatchSize {
t.Fatalf("MinReadBatchMsgsLen():%d != conn.IdealBatchSize(): %d", MinReadBatchMsgsLen(), conn.IdealBatchSize)
if IdealBatchSize != conn.IdealBatchSize {
t.Fatalf("IdealBatchSize: %d != conn.IdealBatchSize(): %d", IdealBatchSize, conn.IdealBatchSize)
}
}

View File

@@ -20,8 +20,11 @@ import (
"time"
"go4.org/mem"
"golang.org/x/net/ipv6"
"tailscale.com/client/local"
"tailscale.com/disco"
"tailscale.com/net/batching"
"tailscale.com/net/netaddr"
"tailscale.com/net/netcheck"
"tailscale.com/net/netmon"
"tailscale.com/net/packet"
@@ -57,10 +60,10 @@ type Server struct {
bindLifetime time.Duration
steadyStateLifetime time.Duration
bus *eventbus.Bus
uc4 *net.UDPConn // always non-nil
uc4Port uint16 // always nonzero
uc6 *net.UDPConn // may be nil if IPv6 bind fails during initialization
uc6Port uint16 // may be zero if IPv6 bind fails during initialization
uc4 batching.Conn // always non-nil
uc4Port uint16 // always nonzero
uc6 batching.Conn // may be nil if IPv6 bind fails during initialization
uc6Port uint16 // may be zero if IPv6 bind fails during initialization
closeOnce sync.Once
wg sync.WaitGroup
closeCh chan struct{}
@@ -96,9 +99,9 @@ type serverEndpoint struct {
allocatedAt time.Time
}
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, conn *net.UDPConn, serverDisco key.DiscoPublic) {
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
if senderIndex != 0 && senderIndex != 1 {
return
return nil, netip.AddrPort{}
}
otherSender := 0
@@ -121,15 +124,15 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
if err != nil {
// silently drop
return
return nil, netip.AddrPort{}
}
if discoMsg.Generation == 0 {
// Generation must be nonzero, silently drop
return
return nil, netip.AddrPort{}
}
if e.handshakeGeneration[senderIndex] == discoMsg.Generation {
// we've seen this generation before, silently drop
return
return nil, netip.AddrPort{}
}
e.handshakeGeneration[senderIndex] = discoMsg.Generation
e.handshakeAddrPorts[senderIndex] = from
@@ -144,19 +147,18 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
gh.VNI.Set(e.vni)
err = gh.Encode(reply)
if err != nil {
return
return nil, netip.AddrPort{}
}
reply = append(reply, disco.Magic...)
reply = serverDisco.AppendTo(reply)
box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
reply = append(reply, box...)
conn.WriteMsgUDPAddrPort(reply, nil, from)
return
return reply, from
case *disco.BindUDPRelayEndpointAnswer:
err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
if err != nil {
// silently drop
return
return nil, netip.AddrPort{}
}
generation := e.handshakeGeneration[senderIndex]
if generation == 0 || // we have no active handshake
@@ -164,23 +166,23 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake
!bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake
// silently drop
return
return nil, netip.AddrPort{}
}
// Handshake complete. Update the binding for this sender.
e.boundAddrPorts[senderIndex] = from
e.lastSeen[senderIndex] = time.Now() // record last seen as bound time
return
return nil, netip.AddrPort{}
default:
// unexpected message types, silently drop
return
return nil, netip.AddrPort{}
}
}
func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, conn *net.UDPConn, serverDisco key.DiscoPublic) {
func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
senderRaw, isDiscoMsg := disco.Source(b)
if !isDiscoMsg {
// Not a Disco message
return
return nil, netip.AddrPort{}
}
sender := key.DiscoPublicFromRaw32(mem.B(senderRaw))
senderIndex := -1
@@ -191,63 +193,51 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by
senderIndex = 1
default:
// unknown Disco public key
return
return nil, netip.AddrPort{}
}
const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:])
if !ok {
// unable to decrypt the Disco payload
return
return nil, netip.AddrPort{}
}
discoMsg, err := disco.Parse(discoPayload)
if err != nil {
// unable to parse the Disco payload
return
return nil, netip.AddrPort{}
}
e.handleDiscoControlMsg(from, senderIndex, discoMsg, conn, serverDisco)
return e.handleDiscoControlMsg(from, senderIndex, discoMsg, serverDisco)
}
func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, rxSocket, otherAFSocket *net.UDPConn, serverDisco key.DiscoPublic) {
func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
if !gh.Control {
if !e.isBound() {
// not a control packet, but serverEndpoint isn't bound
return
return nil, netip.AddrPort{}
}
var to netip.AddrPort
switch {
case from == e.boundAddrPorts[0]:
e.lastSeen[0] = time.Now()
to = e.boundAddrPorts[1]
return b, e.boundAddrPorts[1]
case from == e.boundAddrPorts[1]:
e.lastSeen[1] = time.Now()
to = e.boundAddrPorts[0]
return b, e.boundAddrPorts[0]
default:
// unrecognized source
return
return nil, netip.AddrPort{}
}
// Relay the packet towards the other party via the socket associated
// with the destination's address family. If source and destination
// address families are matching we tx on the same socket the packet
// was received (rxSocket), otherwise we use the "other" socket
// (otherAFSocket). [Server] makes no use of dual-stack sockets.
if from.Addr().Is4() == to.Addr().Is4() {
rxSocket.WriteMsgUDPAddrPort(b, nil, to)
} else if otherAFSocket != nil {
otherAFSocket.WriteMsgUDPAddrPort(b, nil, to)
}
return
}
if gh.Protocol != packet.GeneveProtocolDisco {
// control packet, but not Disco
return
return nil, netip.AddrPort{}
}
msg := b[packet.GeneveFixedHeaderLength:]
e.handleSealedDiscoControlMsg(from, msg, rxSocket, serverDisco)
return e.handleSealedDiscoControlMsg(from, msg, serverDisco)
}
func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
@@ -338,10 +328,10 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
}
s.wg.Add(1)
go s.packetReadLoop(s.uc4, s.uc6)
go s.packetReadLoop(s.uc4, s.uc6, true)
if s.uc6 != nil {
s.wg.Add(1)
go s.packetReadLoop(s.uc6, s.uc4)
go s.packetReadLoop(s.uc6, s.uc4, false)
}
s.wg.Add(1)
go s.endpointGCLoop()
@@ -425,6 +415,41 @@ func (s *Server) addrDiscoveryLoop() {
}
}
// This is a compile-time assertion that [singlePacketConn] implements the
// [batching.Conn] interface.
var _ batching.Conn = (*singlePacketConn)(nil)
// singlePacketConn implements [batching.Conn] with single packet syscall
// operations.
type singlePacketConn struct {
*net.UDPConn
}
func (c *singlePacketConn) ReadBatch(msgs []ipv6.Message, _ int) (int, error) {
n, ap, err := c.UDPConn.ReadFromUDPAddrPort(msgs[0].Buffers[0])
if err != nil {
return 0, err
}
msgs[0].N = n
msgs[0].Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(ap))
return 1, nil
}
func (c *singlePacketConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error {
for _, buff := range buffs {
if geneve.VNI.IsSet() {
geneve.Encode(buff)
} else {
buff = buff[offset:]
}
_, err := c.UDPConn.WriteToUDPAddrPort(buff, addr)
if err != nil {
return err
}
}
return nil
}
// listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if
// we manage to bind the IPv4 socket.
//
@@ -433,7 +458,10 @@ func (s *Server) addrDiscoveryLoop() {
// across IPv4 and IPv6 if the requested port is zero.
//
// TODO: make these "re-bindable" in similar fashion to magicsock as a means to
// deal with EDR software closing them. http://go/corp/30118
// deal with EDR software closing them. http://go/corp/30118. We could re-use
// [magicsock.RebindingConn], which would also remove the need for
// [singlePacketConn], as [magicsock.RebindingConn] also handles fallback to
// single packet syscall operations.
func (s *Server) listenOn(port int) error {
for _, network := range []string{"udp4", "udp6"} {
uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
@@ -462,11 +490,16 @@ func (s *Server) listenOn(port int) error {
}
return err
}
pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize)
bc, ok := pc.(batching.Conn)
if !ok {
bc = &singlePacketConn{uc}
}
if network == "udp4" {
s.uc4 = uc
s.uc4 = bc
s.uc4Port = uint16(portUint)
} else {
s.uc6 = uc
s.uc6 = bc
s.uc6Port = uint16(portUint)
}
}
@@ -526,18 +559,18 @@ func (s *Server) endpointGCLoop() {
}
}
func (s *Server) handlePacket(from netip.AddrPort, b []byte, rxSocket, otherAFSocket *net.UDPConn) {
func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to netip.AddrPort) {
if stun.Is(b) && b[1] == 0x01 {
// A b[1] value of 0x01 (STUN method binding) is sufficiently
// non-overlapping with the Geneve header where the LSB is always 0
// (part of 6 "reserved" bits).
s.netChecker.ReceiveSTUNPacket(b, from)
return
return nil, netip.AddrPort{}
}
gh := packet.GeneveHeader{}
err := gh.Decode(b)
if err != nil {
return
return nil, netip.AddrPort{}
}
// TODO: consider performance implications of holding s.mu for the remainder
// of this method, which does a bunch of disco/crypto work depending. Keep
@@ -547,13 +580,13 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, rxSocket, otherAFSo
e, ok := s.byVNI[gh.VNI.Get()]
if !ok {
// unknown VNI
return
return nil, netip.AddrPort{}
}
e.handlePacket(from, gh, b, rxSocket, otherAFSocket, s.discoPublic)
return e.handlePacket(from, gh, b, s.discoPublic)
}
func (s *Server) packetReadLoop(readFromSocket, otherSocket *net.UDPConn) {
func (s *Server) packetReadLoop(readFromSocket, otherSocket batching.Conn, readFromSocketIsIPv4 bool) {
defer func() {
// We intentionally close the [Server] if we encounter a socket read
// error below, at least until socket "re-binding" is implemented as
@@ -564,15 +597,73 @@ func (s *Server) packetReadLoop(readFromSocket, otherSocket *net.UDPConn) {
s.wg.Done()
s.Close()
}()
b := make([]byte, 1<<16-1)
msgs := make([]ipv6.Message, batching.IdealBatchSize)
for i := range msgs {
msgs[i].OOB = make([]byte, batching.MinControlMessageSize())
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Buffers[0] = make([]byte, 1<<16-1)
}
writeBuffsByDest := make(map[netip.AddrPort][][]byte, batching.IdealBatchSize)
for {
for i := range msgs {
msgs[i] = ipv6.Message{Buffers: msgs[i].Buffers, OOB: msgs[i].OOB[:cap(msgs[i].OOB)]}
}
// TODO: extract laddr from IP_PKTINFO for use in reply
n, from, err := readFromSocket.ReadFromUDPAddrPort(b)
// ReadBatch will split coalesced datagrams before returning, which
// WriteBatchTo will re-coalesce further down. We _could_ be more
// efficient and not split datagrams that belong to the same VNI if they
// are non-control/handshake packets. We pay the memmove/memcopy
// performance penalty for now in the interest of simple single packet
// handlers.
n, err := readFromSocket.ReadBatch(msgs, 0)
if err != nil {
s.logf("error reading from socket(%v): %v", readFromSocket.LocalAddr(), err)
return
}
s.handlePacket(from, b[:n], readFromSocket, otherSocket)
for _, msg := range msgs[:n] {
if msg.N == 0 {
continue
}
buf := msg.Buffers[0][:msg.N]
from := msg.Addr.(*net.UDPAddr).AddrPort()
write, to := s.handlePacket(from, buf)
if !to.IsValid() {
continue
}
if from.Addr().Is4() == to.Addr().Is4() || otherSocket != nil {
buffs, ok := writeBuffsByDest[to]
if !ok {
buffs = make([][]byte, 0, batching.IdealBatchSize)
}
buffs = append(buffs, write)
writeBuffsByDest[to] = buffs
} else {
// This is unexpected. We should never produce a packet to write
// to the "other" socket if the other socket is nil/unbound.
// [server.handlePacket] has to see a packet from a particular
// address family at least once in order for it to return a
// packet to write towards a dest for the same address family.
s.logf("[unexpected] packet from: %v produced packet to: %v while otherSocket is nil", from, to)
}
}
for dest, buffs := range writeBuffsByDest {
// Write the packet batches via the socket associated with the
// destination's address family. If source and destination address
// families are matching we tx on the same socket the packet was
// received, otherwise we use the "other" socket. [Server] makes no
// use of dual-stack sockets.
if dest.Addr().Is4() == readFromSocketIsIPv4 {
readFromSocket.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0)
} else {
otherSocket.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0)
}
delete(writeBuffsByDest, dest)
}
}
}