From 11eb71701294a2e2d731d38ee9a9d8214e91a0b2 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 8 Aug 2022 13:58:08 -0700 Subject: [PATCH] wgengine/magicsock: implement wireguard-go conn.VectorBind --- go.mod | 2 + go.sum | 4 + types/nettype/nettype.go | 41 +++- wgengine/magicsock/magicsock.go | 339 +++++++++++++++++++++++++++++++- 4 files changed, 380 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 5257e85aa..d10bd7275 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module tailscale.com go 1.19 +replace golang.zx2c4.com/wireguard => /Users/jwhited/code/wireguard-go + require ( filippo.io/mkcert v1.4.3 github.com/akutz/memconn v0.1.0 diff --git a/go.sum b/go.sum index 5938ecbc0..675aacfa1 100644 --- a/go.sum +++ b/go.sum @@ -1250,6 +1250,7 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc= golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1357,6 +1358,7 @@ golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220607020251-c690dde0001d h1:4SFsTMi4UahlKoloni7L4eYzhFRifURQLw+yv0QDCx8= golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -1492,9 +1494,11 @@ golang.org/x/sys v0.0.0-20211002104244-808efd93c36d/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211013075003-97ac67df715c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d h1:Zu/JngovGLVi6t2J3nmAf3AoTDwuzw85YZ3b9o4yU7s= golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= diff --git a/types/nettype/nettype.go b/types/nettype/nettype.go index a1506e115..82a53d372 100644 --- a/types/nettype/nettype.go +++ b/types/nettype/nettype.go @@ -9,6 +9,9 @@ "context" "net" "net/netip" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) // PacketListener defines the ListenPacket method as implemented @@ -42,10 +45,46 @@ type packetListenerAdapter struct { PacketListener } +type packetConnWithBatch struct { + PacketConn + xpc4 *ipv4.PacketConn + xpc6 *ipv6.PacketConn +} + +func (p packetConnWithBatch) WriteBatchIPv4(ms []ipv4.Message, flags int) (int, error) { + return p.xpc4.WriteBatch(ms, flags) +} + +func (p packetConnWithBatch) ReadBatchIPv4(ms []ipv4.Message, flags int) (int, error) { + return p.xpc4.ReadBatch(ms, flags) +} + +func (p packetConnWithBatch) WriteBatchIPv6(ms []ipv6.Message, flags int) (int, error) { + return p.xpc6.WriteBatch(ms, flags) +} + +func (p packetConnWithBatch) ReadBatchIPv6(ms []ipv6.Message, flags int) (int, error) { + return p.xpc6.ReadBatch(ms, flags) +} + func (a packetListenerAdapter) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) { pc, err := a.PacketListener.ListenPacket(ctx, network, address) if err != nil { return nil, err } - return pc.(PacketConn), nil + return packetConnWithBatch{ + PacketConn: pc.(PacketConn), + xpc4: ipv4.NewPacketConn(pc), + xpc6: ipv6.NewPacketConn(pc), + }, nil +} + +type BatchWriter interface { + WriteBatchIPv4([]ipv4.Message, int) (int, error) + WriteBatchIPv6([]ipv6.Message, int) (int, error) +} + +type BatchReader interface { + ReadBatchIPv4([]ipv4.Message, int) (int, error) + ReadBatchIPv6([]ipv6.Message, int) (int, error) } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 8e71570ee..c16bcf641 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -13,6 +13,7 @@ "encoding/binary" "errors" "fmt" + "golang.org/x/net/ipv6" "hash/fnv" "math" "math/rand" @@ -28,6 +29,7 @@ "time" "go4.org/mem" + "golang.org/x/net/ipv4" "golang.zx2c4.com/wireguard/conn" "tailscale.com/control/controlclient" "tailscale.com/derp" @@ -1179,6 +1181,25 @@ func (c *Conn) Send(b []byte, ep conn.Endpoint) error { return ep.(*endpoint).send(b) } +func (c *Conn) SendV(buffs [][]byte, ep conn.Endpoint) error { + n := int64(len(buffs)) + metricSendData.Add(n) + if n > 1 { + metricSendDataMultiPackets.Add(1) + } else { + metricSendDataSinglePacket.Add(1) + } + if c.networkDown() { + metricSendDataNetworkDown.Add(n) + return errNetworkDown + } + return ep.(*endpoint).sendv(buffs) +} + +func (c *Conn) MaxVectorSize() int { + return maxVectorSize +} + var errConnClosed = errors.New("Conn closed") var errDropDerpPacket = errors.New("too many DERP packets queued; dropping") @@ -1202,6 +1223,84 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) { return } +const ( + maxVectorSize = 20 +) + +type ipv4SendBatch struct { + ua *net.UDPAddr + msgs []ipv4.Message +} + +var ipv4SendBatchPool = &sync.Pool{ + New: func() any { + ua := &net.UDPAddr{ + IP: make([]byte, 4), + } + msgs := make([]ipv4.Message, maxVectorSize) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Addr = ua + } + return &ipv4SendBatch{ + ua: ua, + msgs: msgs, + } + }, +} + +type ipv6SendBatch struct { + ua *net.UDPAddr + msgs []ipv6.Message +} + +var ipv6SendBatchPool = &sync.Pool{ + New: func() any { + ua := &net.UDPAddr{ + IP: make([]byte, 4), + } + msgs := make([]ipv6.Message, maxVectorSize) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Addr = ua + } + return &ipv6SendBatch{ + ua: ua, + msgs: msgs, + } + }, +} + +func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) { + switch { + case addr.Addr().Is4(): + batch := ipv4SendBatchPool.Get().(*ipv4SendBatch) + as4 := addr.Addr().As4() + copy(batch.ua.IP, as4[:]) + batch.ua.Port = int(addr.Port()) + for i, buff := range buffs { + batch.msgs[i].Buffers[0] = buff + } + _, err := c.pconn4.WriteBatchIPv4(batch.msgs[:len(buffs)], 0) + ipv4SendBatchPool.Put(batch) + return err == nil, err + case addr.Addr().Is6(): + batch := ipv6SendBatchPool.Get().(*ipv6SendBatch) + as16 := addr.Addr().As16() + copy(batch.ua.IP, as16[:]) + batch.ua.Port = int(addr.Port()) + for i, buff := range buffs { + batch.msgs[i].Buffers[0] = buff + } + _, err := c.pconn6.WriteBatchIPv6(batch.msgs[:len(buffs)], 0) + ipv6SendBatchPool.Put(batch) + return err == nil, err + default: + panic("bogus sendUDPBatch addr type") + } + return err == nil, err +} + // sendUDP sends UDP packet b to addr. // See sendAddr's docs on the return value meanings. func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) { @@ -1631,6 +1730,77 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan } } +type ipv4ReceiveBatch struct { + msgs []ipv4.Message + sizes []int + endpoints []conn.Endpoint +} + +type ipv6ReceiveBatch struct { + msgs []ipv6.Message + sizes []int + endpoints []conn.Endpoint +} + +var ( + ipv4RB *ipv4ReceiveBatch + ipv6RB *ipv6ReceiveBatch +) + +func init() { + ipv4Msgs := make([]ipv4.Message, maxVectorSize) + ipv6Msgs := make([]ipv4.Message, maxVectorSize) + for i := range ipv4Msgs { + ipv4Msgs[i].Buffers = make([][]byte, 1) + ipv6Msgs[i].Buffers = make([][]byte, 1) + } + ipv4RB = &ipv4ReceiveBatch{ + msgs: ipv4Msgs, + sizes: make([]int, maxVectorSize), + endpoints: make([]conn.Endpoint, maxVectorSize), + } + ipv6RB = &ipv6ReceiveBatch{ + msgs: ipv6Msgs, + sizes: make([]int, maxVectorSize), + endpoints: make([]conn.Endpoint, maxVectorSize), + } +} + +func (c *Conn) receiveMultipleIPv6(buffs [][]byte) (sizes []int, eps []conn.Endpoint, err error) { + health.ReceiveIPv6.Enter() + defer health.ReceiveIPv6.Exit() + for { + batch := ipv6RB + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + } + numMsgs, err := c.pconn6.ReadBatchIPv6(batch.msgs, 0) + if err != nil { + return nil, nil, err + } + for i := 0; i < numMsgs; i++ { + msg := &batch.msgs[i] + msg.Buffers[0] = msg.Buffers[0][:msg.N] + ipp := msg.Addr.(*net.UDPAddr).AddrPort() + if ep, ok := c.receiveIP(msg.Buffers[0], ipp, &c.ippEndpoint6); ok { + metricRecvDataIPv6.Add(1) + if numMsgs > 1 { + metricRecvDataMultiPackets.Add(1) + } else { + metricRecvDataSinglePacket.Add(1) + } + batch.sizes[i] = msg.N + batch.endpoints[i] = ep + } else { + batch.sizes[i] = 0 + } + } + if len(batch.sizes) > 0 { + return batch.sizes[:numMsgs], batch.endpoints[:numMsgs], nil + } + } +} + // receiveIPv6 receives a UDP IPv6 packet. It is called by wireguard-go. func (c *Conn) receiveIPv6(b []byte) (int, conn.Endpoint, error) { health.ReceiveIPv6.Enter() @@ -1647,6 +1817,41 @@ func (c *Conn) receiveIPv6(b []byte) (int, conn.Endpoint, error) { } } +func (c *Conn) receiveMultipleIPv4(buffs [][]byte) ([]int, []conn.Endpoint, error) { + health.ReceiveIPv4.Enter() + defer health.ReceiveIPv4.Exit() + for { + batch := ipv4RB + for i := range buffs { + batch.msgs[i].Buffers[0] = buffs[i] + } + numMsgs, err := c.pconn4.ReadBatchIPv4(batch.msgs, 0) + if err != nil { + return nil, nil, err + } + for i := 0; i < numMsgs; i++ { + msg := &batch.msgs[i] + msg.Buffers[0] = msg.Buffers[0][:msg.N] + ipp := msg.Addr.(*net.UDPAddr).AddrPort() + if ep, ok := c.receiveIP(msg.Buffers[0], ipp, &c.ippEndpoint4); ok { + metricRecvDataIPv4.Add(1) + if numMsgs > 1 { + metricRecvDataMultiPackets.Add(1) + } else { + metricRecvDataSinglePacket.Add(1) + } + batch.sizes[i] = msg.N + batch.endpoints[i] = ep + } else { + batch.sizes[i] = 0 + } + } + if len(batch.sizes) > 0 { + return batch.sizes[:numMsgs], batch.endpoints[:numMsgs], nil + } + } +} + // receiveIPv4 receives a UDP IPv4 packet. It is called by wireguard-go. func (c *Conn) receiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { health.ReceiveIPv4.Enter() @@ -1699,6 +1904,11 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) return ep, true } +func (c *connBind) receiveMultipleDERP(b [][]byte) (sizes []int, eps []conn.Endpoint, err error) { + n, ep, err := c.receiveDERP(b[0]) + return []int{n}, []conn.Endpoint{ep}, err +} + // receiveDERP reads a packet from c.derpRecvCh into b and returns the associated endpoint. // It is called by wireguard-go. // @@ -2594,6 +2804,17 @@ type connBind struct { closed bool } +func (c *connBind) OpenV(_ uint16) ([]conn.ReceiveVFunc, uint16, error) { + c.mu.Lock() + defer c.mu.Unlock() + if !c.closed { + return nil, 0, errors.New("magicsock: connBind already open") + } + c.closed = false + fns := []conn.ReceiveVFunc{c.receiveMultipleIPv4, c.receiveMultipleIPv6, c.receiveMultipleDERP} + return fns, c.LocalPort(), nil +} + // Open is called by WireGuard to create a UDP binding. // The ignoredPort comes from wireguard-go, via the wgcfg config. // We ignore that port value here, since we have the local port available easily. @@ -3012,6 +3233,36 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { } } +func (c *RebindingUDPConn) ReadBatchIPv4(msgs []ipv4.Message, flags int) (int, error) { + for { + pconn := c.pconnAtomic.Load() + br, ok := pconn.(nettype.BatchReader) + if !ok { + panic("pconn is not a nettype.BatchReader") + } + n, err := br.ReadBatchIPv4(msgs, flags) + if err != nil && pconn != c.currentConn() { + continue + } + return n, err + } +} + +func (c *RebindingUDPConn) ReadBatchIPv6(msgs []ipv6.Message, flags int) (int, error) { + for { + pconn := c.pconnAtomic.Load() + br, ok := pconn.(nettype.BatchReader) + if !ok { + panic("pconn is not a nettype.BatchReader") + } + n, err := br.ReadBatchIPv6(msgs, flags) + if err != nil && pconn != c.currentConn() { + continue + } + return n, err + } +} + // ReadFromNetaddr reads a packet from c into b. // It returns the number of bytes copied and the return address. // It is identical to c.ReadFrom, except that it returns a netip.AddrPort instead of a net.Addr. @@ -3106,6 +3357,42 @@ func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (in } } +func (c *RebindingUDPConn) WriteBatchIPv4(msgs []ipv4.Message, flags int) (int, error) { + for { + pconn := c.pconnAtomic.Load() + bw, ok := pconn.(nettype.BatchWriter) + if !ok { + return 0, errors.New("pconn is not a nettype.BatchWriter()") + } + + n, err := bw.WriteBatchIPv4(msgs, flags) + if err != nil { + if pconn != c.currentConn() { + continue + } + } + return n, err + } +} + +func (c *RebindingUDPConn) WriteBatchIPv6(msgs []ipv6.Message, flags int) (int, error) { + for { + pconn := c.pconnAtomic.Load() + bw, ok := pconn.(nettype.BatchWriter) + if !ok { + return 0, errors.New("pconn is not a nettype.BatchWriter()") + } + + n, err := bw.WriteBatchIPv6(msgs, flags) + if err != nil { + if pconn != c.currentConn() { + continue + } + } + return n, err + } +} + func newBlockForeverConn() *blockForeverConn { c := new(blockForeverConn) c.cond = sync.NewCond(&c.mu) @@ -3138,6 +3425,11 @@ func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (in return len(p), nil } +func (c *blockForeverConn) WriteBatch(p []ipv4.Message, flags int) (int, error) { + // Silently drop writes. + return len(p), nil +} + func (c *blockForeverConn) LocalAddr() net.Addr { // Return a *net.UDPAddr because lots of code assumes that it will. return new(net.UDPAddr) @@ -3577,6 +3869,39 @@ func (de *endpoint) cliPing(res *ipnstate.PingResult, cb func(*ipnstate.PingResu de.noteActiveLocked() } +func (de *endpoint) sendv(buffs [][]byte) error { + now := mono.Now() + + de.mu.Lock() + udpAddr, derpAddr := de.addrForSendLocked(now) + if de.canP2P() && (!udpAddr.IsValid() || now.After(de.trustBestAddrUntil)) { + de.sendPingsLocked(now, true) + } + de.noteActiveLocked() + de.mu.Unlock() + + if !udpAddr.IsValid() && !derpAddr.IsValid() { + return errors.New("no UDP or DERP addr") + } + var err error + if udpAddr.IsValid() { + _, err = de.c.sendUDPBatch(udpAddr, buffs) + } + if derpAddr.IsValid() { + allOk := true + for _, buff := range buffs { + ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff) + if !ok { + allOk = false + } + } + if allOk { + return nil + } + } + return err +} + func (de *endpoint) send(b []byte) error { now := mono.Now() @@ -4165,11 +4490,15 @@ func (s derpAddrFamSelector) PreferIPv6() bool { metricSendDERPError = clientmetric.NewCounter("magicsock_send_derp_error") // Data packets (non-disco) - metricSendData = clientmetric.NewCounter("magicsock_send_data") - metricSendDataNetworkDown = clientmetric.NewCounter("magicsock_send_data_network_down") - metricRecvDataDERP = clientmetric.NewCounter("magicsock_recv_data_derp") - metricRecvDataIPv4 = clientmetric.NewCounter("magicsock_recv_data_ipv4") - metricRecvDataIPv6 = clientmetric.NewCounter("magicsock_recv_data_ipv6") + metricSendData = clientmetric.NewCounter("magicsock_send_data") + metricSendDataNetworkDown = clientmetric.NewCounter("magicsock_send_data_network_down") + metricSendDataMultiPackets = clientmetric.NewCounter("magicsock_send_data_multiple_packets") + metricSendDataSinglePacket = clientmetric.NewCounter("magicsock_send_data_single_packet") + metricRecvDataDERP = clientmetric.NewCounter("magicsock_recv_data_derp") + metricRecvDataIPv4 = clientmetric.NewCounter("magicsock_recv_data_ipv4") + metricRecvDataIPv6 = clientmetric.NewCounter("magicsock_recv_data_ipv6") + metricRecvDataMultiPackets = clientmetric.NewCounter("magicsock_recv_data_multiple_packets") + metricRecvDataSinglePacket = clientmetric.NewCounter("magicsock_recv_data_single_packet") // Disco packets metricSendDiscoUDP = clientmetric.NewCounter("magicsock_disco_send_udp")