wgengine/magicsock: implement wireguard-go conn.VectorBind

This commit is contained in:
Jordan Whited 2022-08-08 13:58:08 -07:00
parent 18109c63b0
commit 11eb717012
4 changed files with 380 additions and 6 deletions

2
go.mod
View File

@ -2,6 +2,8 @@ module tailscale.com
go 1.19
replace golang.zx2c4.com/wireguard => /Users/jwhited/code/wireguard-go
require (
filippo.io/mkcert v1.4.3
github.com/akutz/memconn v0.1.0

4
go.sum
View File

@ -1250,6 +1250,7 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc=
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -1357,6 +1358,7 @@ golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220607020251-c690dde0001d h1:4SFsTMi4UahlKoloni7L4eYzhFRifURQLw+yv0QDCx8=
golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@ -1492,9 +1494,11 @@ golang.org/x/sys v0.0.0-20211002104244-808efd93c36d/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211013075003-97ac67df715c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d h1:Zu/JngovGLVi6t2J3nmAf3AoTDwuzw85YZ3b9o4yU7s=
golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=

View File

@ -9,6 +9,9 @@
"context"
"net"
"net/netip"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// PacketListener defines the ListenPacket method as implemented
@ -42,10 +45,46 @@ type packetListenerAdapter struct {
PacketListener
}
type packetConnWithBatch struct {
PacketConn
xpc4 *ipv4.PacketConn
xpc6 *ipv6.PacketConn
}
func (p packetConnWithBatch) WriteBatchIPv4(ms []ipv4.Message, flags int) (int, error) {
return p.xpc4.WriteBatch(ms, flags)
}
func (p packetConnWithBatch) ReadBatchIPv4(ms []ipv4.Message, flags int) (int, error) {
return p.xpc4.ReadBatch(ms, flags)
}
func (p packetConnWithBatch) WriteBatchIPv6(ms []ipv6.Message, flags int) (int, error) {
return p.xpc6.WriteBatch(ms, flags)
}
func (p packetConnWithBatch) ReadBatchIPv6(ms []ipv6.Message, flags int) (int, error) {
return p.xpc6.ReadBatch(ms, flags)
}
func (a packetListenerAdapter) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
pc, err := a.PacketListener.ListenPacket(ctx, network, address)
if err != nil {
return nil, err
}
return pc.(PacketConn), nil
return packetConnWithBatch{
PacketConn: pc.(PacketConn),
xpc4: ipv4.NewPacketConn(pc),
xpc6: ipv6.NewPacketConn(pc),
}, nil
}
type BatchWriter interface {
WriteBatchIPv4([]ipv4.Message, int) (int, error)
WriteBatchIPv6([]ipv6.Message, int) (int, error)
}
type BatchReader interface {
ReadBatchIPv4([]ipv4.Message, int) (int, error)
ReadBatchIPv6([]ipv6.Message, int) (int, error)
}

View File

@ -13,6 +13,7 @@
"encoding/binary"
"errors"
"fmt"
"golang.org/x/net/ipv6"
"hash/fnv"
"math"
"math/rand"
@ -28,6 +29,7 @@
"time"
"go4.org/mem"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/conn"
"tailscale.com/control/controlclient"
"tailscale.com/derp"
@ -1179,6 +1181,25 @@ func (c *Conn) Send(b []byte, ep conn.Endpoint) error {
return ep.(*endpoint).send(b)
}
func (c *Conn) SendV(buffs [][]byte, ep conn.Endpoint) error {
n := int64(len(buffs))
metricSendData.Add(n)
if n > 1 {
metricSendDataMultiPackets.Add(1)
} else {
metricSendDataSinglePacket.Add(1)
}
if c.networkDown() {
metricSendDataNetworkDown.Add(n)
return errNetworkDown
}
return ep.(*endpoint).sendv(buffs)
}
func (c *Conn) MaxVectorSize() int {
return maxVectorSize
}
var errConnClosed = errors.New("Conn closed")
var errDropDerpPacket = errors.New("too many DERP packets queued; dropping")
@ -1202,6 +1223,84 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) {
return
}
const (
maxVectorSize = 20
)
type ipv4SendBatch struct {
ua *net.UDPAddr
msgs []ipv4.Message
}
var ipv4SendBatchPool = &sync.Pool{
New: func() any {
ua := &net.UDPAddr{
IP: make([]byte, 4),
}
msgs := make([]ipv4.Message, maxVectorSize)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
}
return &ipv4SendBatch{
ua: ua,
msgs: msgs,
}
},
}
type ipv6SendBatch struct {
ua *net.UDPAddr
msgs []ipv6.Message
}
var ipv6SendBatchPool = &sync.Pool{
New: func() any {
ua := &net.UDPAddr{
IP: make([]byte, 4),
}
msgs := make([]ipv6.Message, maxVectorSize)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
}
return &ipv6SendBatch{
ua: ua,
msgs: msgs,
}
},
}
func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) {
switch {
case addr.Addr().Is4():
batch := ipv4SendBatchPool.Get().(*ipv4SendBatch)
as4 := addr.Addr().As4()
copy(batch.ua.IP, as4[:])
batch.ua.Port = int(addr.Port())
for i, buff := range buffs {
batch.msgs[i].Buffers[0] = buff
}
_, err := c.pconn4.WriteBatchIPv4(batch.msgs[:len(buffs)], 0)
ipv4SendBatchPool.Put(batch)
return err == nil, err
case addr.Addr().Is6():
batch := ipv6SendBatchPool.Get().(*ipv6SendBatch)
as16 := addr.Addr().As16()
copy(batch.ua.IP, as16[:])
batch.ua.Port = int(addr.Port())
for i, buff := range buffs {
batch.msgs[i].Buffers[0] = buff
}
_, err := c.pconn6.WriteBatchIPv6(batch.msgs[:len(buffs)], 0)
ipv6SendBatchPool.Put(batch)
return err == nil, err
default:
panic("bogus sendUDPBatch addr type")
}
return err == nil, err
}
// sendUDP sends UDP packet b to addr.
// See sendAddr's docs on the return value meanings.
func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) {
@ -1631,6 +1730,77 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan
}
}
type ipv4ReceiveBatch struct {
msgs []ipv4.Message
sizes []int
endpoints []conn.Endpoint
}
type ipv6ReceiveBatch struct {
msgs []ipv6.Message
sizes []int
endpoints []conn.Endpoint
}
var (
ipv4RB *ipv4ReceiveBatch
ipv6RB *ipv6ReceiveBatch
)
func init() {
ipv4Msgs := make([]ipv4.Message, maxVectorSize)
ipv6Msgs := make([]ipv4.Message, maxVectorSize)
for i := range ipv4Msgs {
ipv4Msgs[i].Buffers = make([][]byte, 1)
ipv6Msgs[i].Buffers = make([][]byte, 1)
}
ipv4RB = &ipv4ReceiveBatch{
msgs: ipv4Msgs,
sizes: make([]int, maxVectorSize),
endpoints: make([]conn.Endpoint, maxVectorSize),
}
ipv6RB = &ipv6ReceiveBatch{
msgs: ipv6Msgs,
sizes: make([]int, maxVectorSize),
endpoints: make([]conn.Endpoint, maxVectorSize),
}
}
func (c *Conn) receiveMultipleIPv6(buffs [][]byte) (sizes []int, eps []conn.Endpoint, err error) {
health.ReceiveIPv6.Enter()
defer health.ReceiveIPv6.Exit()
for {
batch := ipv6RB
for i := range buffs {
batch.msgs[i].Buffers[0] = buffs[i]
}
numMsgs, err := c.pconn6.ReadBatchIPv6(batch.msgs, 0)
if err != nil {
return nil, nil, err
}
for i := 0; i < numMsgs; i++ {
msg := &batch.msgs[i]
msg.Buffers[0] = msg.Buffers[0][:msg.N]
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0], ipp, &c.ippEndpoint6); ok {
metricRecvDataIPv6.Add(1)
if numMsgs > 1 {
metricRecvDataMultiPackets.Add(1)
} else {
metricRecvDataSinglePacket.Add(1)
}
batch.sizes[i] = msg.N
batch.endpoints[i] = ep
} else {
batch.sizes[i] = 0
}
}
if len(batch.sizes) > 0 {
return batch.sizes[:numMsgs], batch.endpoints[:numMsgs], nil
}
}
}
// receiveIPv6 receives a UDP IPv6 packet. It is called by wireguard-go.
func (c *Conn) receiveIPv6(b []byte) (int, conn.Endpoint, error) {
health.ReceiveIPv6.Enter()
@ -1647,6 +1817,41 @@ func (c *Conn) receiveIPv6(b []byte) (int, conn.Endpoint, error) {
}
}
func (c *Conn) receiveMultipleIPv4(buffs [][]byte) ([]int, []conn.Endpoint, error) {
health.ReceiveIPv4.Enter()
defer health.ReceiveIPv4.Exit()
for {
batch := ipv4RB
for i := range buffs {
batch.msgs[i].Buffers[0] = buffs[i]
}
numMsgs, err := c.pconn4.ReadBatchIPv4(batch.msgs, 0)
if err != nil {
return nil, nil, err
}
for i := 0; i < numMsgs; i++ {
msg := &batch.msgs[i]
msg.Buffers[0] = msg.Buffers[0][:msg.N]
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0], ipp, &c.ippEndpoint4); ok {
metricRecvDataIPv4.Add(1)
if numMsgs > 1 {
metricRecvDataMultiPackets.Add(1)
} else {
metricRecvDataSinglePacket.Add(1)
}
batch.sizes[i] = msg.N
batch.endpoints[i] = ep
} else {
batch.sizes[i] = 0
}
}
if len(batch.sizes) > 0 {
return batch.sizes[:numMsgs], batch.endpoints[:numMsgs], nil
}
}
}
// receiveIPv4 receives a UDP IPv4 packet. It is called by wireguard-go.
func (c *Conn) receiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
health.ReceiveIPv4.Enter()
@ -1699,6 +1904,11 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache)
return ep, true
}
func (c *connBind) receiveMultipleDERP(b [][]byte) (sizes []int, eps []conn.Endpoint, err error) {
n, ep, err := c.receiveDERP(b[0])
return []int{n}, []conn.Endpoint{ep}, err
}
// receiveDERP reads a packet from c.derpRecvCh into b and returns the associated endpoint.
// It is called by wireguard-go.
//
@ -2594,6 +2804,17 @@ type connBind struct {
closed bool
}
func (c *connBind) OpenV(_ uint16) ([]conn.ReceiveVFunc, uint16, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.closed {
return nil, 0, errors.New("magicsock: connBind already open")
}
c.closed = false
fns := []conn.ReceiveVFunc{c.receiveMultipleIPv4, c.receiveMultipleIPv6, c.receiveMultipleDERP}
return fns, c.LocalPort(), nil
}
// Open is called by WireGuard to create a UDP binding.
// The ignoredPort comes from wireguard-go, via the wgcfg config.
// We ignore that port value here, since we have the local port available easily.
@ -3012,6 +3233,36 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
}
}
func (c *RebindingUDPConn) ReadBatchIPv4(msgs []ipv4.Message, flags int) (int, error) {
for {
pconn := c.pconnAtomic.Load()
br, ok := pconn.(nettype.BatchReader)
if !ok {
panic("pconn is not a nettype.BatchReader")
}
n, err := br.ReadBatchIPv4(msgs, flags)
if err != nil && pconn != c.currentConn() {
continue
}
return n, err
}
}
func (c *RebindingUDPConn) ReadBatchIPv6(msgs []ipv6.Message, flags int) (int, error) {
for {
pconn := c.pconnAtomic.Load()
br, ok := pconn.(nettype.BatchReader)
if !ok {
panic("pconn is not a nettype.BatchReader")
}
n, err := br.ReadBatchIPv6(msgs, flags)
if err != nil && pconn != c.currentConn() {
continue
}
return n, err
}
}
// ReadFromNetaddr reads a packet from c into b.
// It returns the number of bytes copied and the return address.
// It is identical to c.ReadFrom, except that it returns a netip.AddrPort instead of a net.Addr.
@ -3106,6 +3357,42 @@ func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (in
}
}
func (c *RebindingUDPConn) WriteBatchIPv4(msgs []ipv4.Message, flags int) (int, error) {
for {
pconn := c.pconnAtomic.Load()
bw, ok := pconn.(nettype.BatchWriter)
if !ok {
return 0, errors.New("pconn is not a nettype.BatchWriter()")
}
n, err := bw.WriteBatchIPv4(msgs, flags)
if err != nil {
if pconn != c.currentConn() {
continue
}
}
return n, err
}
}
func (c *RebindingUDPConn) WriteBatchIPv6(msgs []ipv6.Message, flags int) (int, error) {
for {
pconn := c.pconnAtomic.Load()
bw, ok := pconn.(nettype.BatchWriter)
if !ok {
return 0, errors.New("pconn is not a nettype.BatchWriter()")
}
n, err := bw.WriteBatchIPv6(msgs, flags)
if err != nil {
if pconn != c.currentConn() {
continue
}
}
return n, err
}
}
func newBlockForeverConn() *blockForeverConn {
c := new(blockForeverConn)
c.cond = sync.NewCond(&c.mu)
@ -3138,6 +3425,11 @@ func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (in
return len(p), nil
}
func (c *blockForeverConn) WriteBatch(p []ipv4.Message, flags int) (int, error) {
// Silently drop writes.
return len(p), nil
}
func (c *blockForeverConn) LocalAddr() net.Addr {
// Return a *net.UDPAddr because lots of code assumes that it will.
return new(net.UDPAddr)
@ -3577,6 +3869,39 @@ func (de *endpoint) cliPing(res *ipnstate.PingResult, cb func(*ipnstate.PingResu
de.noteActiveLocked()
}
func (de *endpoint) sendv(buffs [][]byte) error {
now := mono.Now()
de.mu.Lock()
udpAddr, derpAddr := de.addrForSendLocked(now)
if de.canP2P() && (!udpAddr.IsValid() || now.After(de.trustBestAddrUntil)) {
de.sendPingsLocked(now, true)
}
de.noteActiveLocked()
de.mu.Unlock()
if !udpAddr.IsValid() && !derpAddr.IsValid() {
return errors.New("no UDP or DERP addr")
}
var err error
if udpAddr.IsValid() {
_, err = de.c.sendUDPBatch(udpAddr, buffs)
}
if derpAddr.IsValid() {
allOk := true
for _, buff := range buffs {
ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff)
if !ok {
allOk = false
}
}
if allOk {
return nil
}
}
return err
}
func (de *endpoint) send(b []byte) error {
now := mono.Now()
@ -4165,11 +4490,15 @@ func (s derpAddrFamSelector) PreferIPv6() bool {
metricSendDERPError = clientmetric.NewCounter("magicsock_send_derp_error")
// Data packets (non-disco)
metricSendData = clientmetric.NewCounter("magicsock_send_data")
metricSendDataNetworkDown = clientmetric.NewCounter("magicsock_send_data_network_down")
metricRecvDataDERP = clientmetric.NewCounter("magicsock_recv_data_derp")
metricRecvDataIPv4 = clientmetric.NewCounter("magicsock_recv_data_ipv4")
metricRecvDataIPv6 = clientmetric.NewCounter("magicsock_recv_data_ipv6")
metricSendData = clientmetric.NewCounter("magicsock_send_data")
metricSendDataNetworkDown = clientmetric.NewCounter("magicsock_send_data_network_down")
metricSendDataMultiPackets = clientmetric.NewCounter("magicsock_send_data_multiple_packets")
metricSendDataSinglePacket = clientmetric.NewCounter("magicsock_send_data_single_packet")
metricRecvDataDERP = clientmetric.NewCounter("magicsock_recv_data_derp")
metricRecvDataIPv4 = clientmetric.NewCounter("magicsock_recv_data_ipv4")
metricRecvDataIPv6 = clientmetric.NewCounter("magicsock_recv_data_ipv6")
metricRecvDataMultiPackets = clientmetric.NewCounter("magicsock_recv_data_multiple_packets")
metricRecvDataSinglePacket = clientmetric.NewCounter("magicsock_recv_data_single_packet")
// Disco packets
metricSendDiscoUDP = clientmetric.NewCounter("magicsock_disco_send_udp")