wgengine/...: split into multiple receive functions

Upstream wireguard-go has changed its receive model.
NewDevice now accepts a conn.Bind interface.

The conn.Bind is stateless; magicsock.Conns are stateful.
To work around this, we add a connBind type that supports
cheap teardown and bring-up, backed by a Conn.

The new conn.Bind allows us to specify a set of receive functions,
rather than having to shoehorn everything into ReceiveIPv4 and ReceiveIPv6.
This lets us plumbing DERP messages directly into wireguard-go,
instead of having to mux them via ReceiveIPv4.

One consequence of the new conn.Bind layer is that
closing the wireguard-go device is now indistinguishable
from the routine bring-up and tear-down normally experienced
by a conn.Bind. We thus have to explicitly close the magicsock.Conn
when the close the wireguard-go device.

One downside of this change is that we are reliant on wireguard-go
to call receiveDERP to process DERP messages. This is fine for now,
but is perhaps something we should fix in the future.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder
2021-03-24 09:41:57 -07:00
committed by Josh Bleecher Snyder
parent 2074dfa5e0
commit b3ceca1dd7
7 changed files with 191 additions and 271 deletions

View File

@@ -22,7 +22,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"unsafe"
@@ -170,11 +169,7 @@ func newMagicStack(t testing.TB, logf logger.Logf, l nettype.PacketListener, der
tsTun.SetFilter(filter.NewAllowAllForTest(logf))
wgLogger := wglog.NewLogger(logf)
opts := &device.DeviceOptions{
CreateEndpoint: conn.CreateEndpoint,
CreateBind: conn.CreateBind,
}
dev := device.NewDevice(tsTun, wgLogger.DeviceLogger, opts)
dev := device.NewDevice(tsTun, conn.Bind(), wgLogger.DeviceLogger, new(device.DeviceOptions))
dev.Up()
// Wait for magicsock to connect up to DERP.
@@ -363,7 +358,7 @@ func TestNewConn(t *testing.T) {
go func() {
var pkt [64 << 10]byte
for {
_, _, err := conn.ReceiveIPv4(pkt[:])
_, _, err := conn.receiveIPv4(pkt[:])
if err != nil {
return
}
@@ -521,11 +516,7 @@ func TestDeviceStartStop(t *testing.T) {
tun := tuntest.NewChannelTUN()
wgLogger := wglog.NewLogger(t.Logf)
opts := &device.DeviceOptions{
CreateEndpoint: conn.CreateEndpoint,
CreateBind: conn.CreateBind,
}
dev := device.NewDevice(tun.TUN(), wgLogger.DeviceLogger, opts)
dev := device.NewDevice(tun.TUN(), conn.Bind(), wgLogger.DeviceLogger, new(device.DeviceOptions))
dev.Up()
dev.Close()
}
@@ -1382,14 +1373,10 @@ func stringifyConfig(cfg wgcfg.Config) string {
func Test32bitAlignment(t *testing.T) {
var de discoEndpoint
var c Conn
if off := unsafe.Offsetof(de.lastRecvUnixAtomic); off%8 != 0 {
t.Fatalf("discoEndpoint.lastRecvUnixAtomic is not 8-byte aligned")
}
if off := unsafe.Offsetof(c.derpRecvCountAtomic); off%8 != 0 {
t.Fatalf("Conn.derpRecvCountAtomic is not 8-byte aligned")
}
if !de.isFirstRecvActivityInAwhile() { // verify this doesn't panic on 32-bit
t.Error("expected true")
@@ -1397,7 +1384,6 @@ func Test32bitAlignment(t *testing.T) {
if de.isFirstRecvActivityInAwhile() {
t.Error("expected false on second call")
}
atomic.AddInt64(&c.derpRecvCountAtomic, 1)
}
// newNonLegacyTestConn returns a new Conn with DisableLegacyNetworking set true.
@@ -1418,92 +1404,6 @@ func newNonLegacyTestConn(t testing.TB) *Conn {
return conn
}
// Tests concurrent DERP readers pushing DERP data into ReceiveIPv4
// (which should blend all DERP reads into UDP reads).
func TestDerpReceiveFromIPv4(t *testing.T) {
conn := newNonLegacyTestConn(t)
defer conn.Close()
sendConn, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer sendConn.Close()
nodeKey, _ := addTestEndpoint(t, conn, sendConn)
var sends int = 250e3 // takes about a second
if testing.Short() {
sends /= 10
}
senders := runtime.NumCPU()
sends -= (sends % senders)
var wg sync.WaitGroup
defer wg.Wait()
t.Logf("doing %v sends over %d senders", sends, senders)
ctx, cancel := context.WithCancel(context.Background())
defer conn.Close()
defer cancel()
doneCtx, cancelDoneCtx := context.WithCancel(context.Background())
cancelDoneCtx()
for i := 0; i < senders; i++ {
wg.Add(1)
regionID := i + 1
go func() {
defer wg.Done()
for i := 0; i < sends/senders; i++ {
res := derpReadResult{
regionID: regionID,
n: 123,
src: key.Public(nodeKey),
copyBuf: func(dst []byte) int { return 123 },
}
// First send with the closed context. ~50% of
// these should end up going through the
// send-a-zero-derpReadResult path, returning
// true, in which case we don't want to send again.
// We test later that we hit the other path.
if conn.sendDerpReadResult(doneCtx, res) {
continue
}
if !conn.sendDerpReadResult(ctx, res) {
t.Error("unexpected false")
return
}
}
}()
}
zeroSendsStart := testCounterZeroDerpReadResultSend.Value()
buf := make([]byte, 1500)
for i := 0; i < sends; i++ {
n, ep, err := conn.ReceiveIPv4(buf)
if err != nil {
t.Fatal(err)
}
_ = n
_ = ep
}
t.Logf("did %d ReceiveIPv4 calls", sends)
zeroSends, zeroRecv := testCounterZeroDerpReadResultSend.Value(), testCounterZeroDerpReadResultRecv.Value()
if zeroSends != zeroRecv {
t.Errorf("did %d zero sends != %d corresponding receives", zeroSends, zeroRecv)
}
zeroSendDelta := zeroSends - zeroSendsStart
if zeroSendDelta == 0 {
t.Errorf("didn't see any sends of derpReadResult zero value")
}
if zeroSendDelta == int64(sends) {
t.Errorf("saw %v sends of the derpReadResult zero value which was unexpectedly high (100%% of our %v sends)", zeroSendDelta, sends)
}
}
// addTestEndpoint sets conn's network map to a single peer expected
// to receive packets from sendConn (or DERP), and returns that peer's
// nodekey and discokey.
@@ -1523,7 +1423,7 @@ func addTestEndpoint(tb testing.TB, conn *Conn, sendConn net.PacketConn) (tailcf
},
})
conn.SetPrivateKey(wgkey.Private{0: 1})
_, err := conn.CreateEndpoint(string(nodeKey[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345")
_, err := conn.ParseEndpoint(string(nodeKey[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345")
if err != nil {
tb.Fatal(err)
}
@@ -1554,7 +1454,7 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) {
if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil {
tb.Fatalf("WriteTo: %v", err)
}
n, ep, err := conn.ReceiveIPv4(buf)
n, ep, err := conn.receiveIPv4(buf)
if err != nil {
tb.Fatal(err)
}
@@ -1697,7 +1597,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) {
},
},
})
_, err := conn.CreateEndpoint(string(nodeKey1[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345")
_, err := conn.ParseEndpoint(string(nodeKey1[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345")
if err != nil {
t.Fatal(err)
}
@@ -1755,7 +1655,7 @@ func TestRebindStress(t *testing.T) {
go func() {
buf := make([]byte, 1500)
for {
_, _, err := conn.ReceiveIPv4(buf)
_, _, err := conn.receiveIPv4(buf)
if ctx.Err() != nil {
errc <- nil
return