wgengine/magicsock: fix DERP reader hang regression during concurrent reads

Fixes #1282

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Author: Brad Fitzpatrick <bradfitz@tailscale.com>
Date:   2021-02-06 22:39:58 -08:00
Parent: e1f773ebba
Commit: 6b365b0239
2 changed files with 92 additions and 37 deletions
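Context for the diff below: the regression came from how ReceiveIPv4 decided whether its interrupted UDP read was caused by a DERP packet. The old code compared an atomically incremented arrival counter against a last-seen value (derpRecvCountLast), which could register several concurrent DERP arrivals as a single one, leaving other DERP readers blocked on derpRecvCh while ReceiveIPv4 went back to a blocking UDP read. The fix turns the counter into a true count of pending derpRecvCh sends and requires that every increment eventually be matched by a send the receiver will see, even on the cancellation path (hence the zero-value derpReadResult "skip" message). What follows is a minimal, self-contained sketch of that protocol, with invented names (demuxConn, readResult, pending, sendFromRelay) rather than the real magicsock types; it only illustrates the counter/read-deadline dance, not the actual implementation:

// Minimal sketch of the pending-send / read-deadline protocol (invented
// names; not the real magicsock types). A relay reader that wants to hand
// a packet to the UDP receive loop counts it as pending, interrupts the
// blocking UDP read, and then delivers it on a channel; the receive loop
// drains the channel whenever anything is pending.
package main

import (
	"context"
	"fmt"
	"net"
	"sync/atomic"
	"time"
)

// aLongTimeAgo is any time in the distant past; using it as a read deadline
// makes a blocked ReadFrom return immediately with a timeout error.
var aLongTimeAgo = time.Unix(233431200, 0)

type readResult struct {
	buf []byte // nil means "skip me": sent only to keep the pending count honest
}

type demuxConn struct {
	pconn   net.PacketConn
	recvCh  chan readResult
	pending int64 // number of recvCh sends the receiver has not yet consumed
	donec   chan struct{}
}

// sendFromRelay is the relay-reader side. It must either deliver res, observe
// conn shutdown, or deliver a zero value, so that pending never promises a
// message the receiver will not get.
func (c *demuxConn) sendFromRelay(ctx context.Context, res readResult) bool {
	atomic.AddInt64(&c.pending, 1)
	c.pconn.SetReadDeadline(aLongTimeAgo) // wake up the blocking UDP read
	select {
	case <-ctx.Done():
		select {
		case <-c.donec: // whole conn shutting down; receiver won't block on us
		case c.recvCh <- readResult{}: // "skip" message balances the increment
		}
		return false
	case c.recvCh <- res:
		return true
	}
}

// receive is the wireguard-facing side: drain relay messages whenever any are
// pending, otherwise block on the UDP socket.
func (c *demuxConn) receive(buf []byte) (int, error) {
	for {
		if atomic.LoadInt64(&c.pending) > 0 {
			res := <-c.recvCh
			if atomic.AddInt64(&c.pending, -1) == 0 {
				c.pconn.SetReadDeadline(time.Time{}) // restore blocking reads
			}
			if res.buf == nil {
				continue // sender bailed out after counting; nothing to return
			}
			return copy(buf, res.buf), nil
		}
		n, _, err := c.pconn.ReadFrom(buf)
		if err != nil {
			if atomic.LoadInt64(&c.pending) > 0 {
				continue // our own deadline interrupt, not a real socket error
			}
			return 0, err
		}
		return n, nil
	}
}

func main() {
	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	c := &demuxConn{pconn: pc, recvCh: make(chan readResult), donec: make(chan struct{})}
	go c.sendFromRelay(context.Background(), readResult{buf: []byte("via relay")})
	buf := make([]byte, 1500)
	n, err := c.receive(buf)
	fmt.Println(string(buf[:n]), err)
}

The key invariant in the sketch, as in the patch, is that the pending count covers exactly the channel sends the receiver has not yet consumed, so the receiver can always drain the channel without risking a hang.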

wgengine/magicsock/magicsock.go

@@ -12,6 +12,7 @@ import (
 	crand "crypto/rand"
 	"encoding/binary"
 	"errors"
+	"expvar"
 	"fmt"
 	"hash/fnv"
 	"math"
@@ -155,13 +156,10 @@ type Conn struct {
 	// derpRecvCh is used by ReceiveIPv4 to read DERP messages.
 	derpRecvCh chan derpReadResult
 
-	// derpRecvCountAtomic is atomically incremented by runDerpReader whenever
-	// a DERP message arrives. It's incremented before runDerpReader is interrupted.
+	// derpRecvCountAtomic is how many derpRecvCh sends are pending.
+	// It's incremented by runDerpReader whenever a DERP message
+	// arrives and decremented when they're read.
 	derpRecvCountAtomic int64
-	// derpRecvCountLast is used by ReceiveIPv4 to compare against
-	// its last read value of derpRecvCountAtomic to determine
-	// whether a DERP channel read should be done.
-	derpRecvCountLast int64 // owned by ReceiveIPv4
 
 	// ippEndpoint4 and ippEndpoint6 are owned by ReceiveIPv4 and
 	// ReceiveIPv6, respectively, to cache an IPPort->endpoint for
@@ -1358,6 +1356,8 @@ type derpReadResult struct {
 	// copyBuf is called to copy the data to dst. It returns how
 	// much data was copied, which will be n if dst is large
 	// enough. copyBuf can only be called once.
+	// If copyBuf is nil, that's a signal from the sender to ignore
+	// this message.
 	copyBuf func(dst []byte) int
 }
 
@@ -1458,6 +1458,11 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr netaddr.IPPort, d
 	}
 }
 
+var (
+	testCounterZeroDerpReadResultSend expvar.Int
+	testCounterZeroDerpReadResultRecv expvar.Int
+)
+
 // sendDerpReadResult sends res to c.derpRecvCh and reports whether it
 // was sent. (It reports false if ctx was done first.)
 //
@@ -1465,7 +1470,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr netaddr.IPPort, d
 // ReceiveIPv4's blocking UDP read.
 func (c *Conn) sendDerpReadResult(ctx context.Context, res derpReadResult) (sent bool) {
 	// Before we wake up ReceiveIPv4 with SetReadDeadline,
 	// note that a DERP packet has arrived. ReceiveIPv4
 	// will read this field to note that its UDP read
 	// error is due to us.
 	atomic.AddInt64(&c.derpRecvCountAtomic, 1)
@@ -1473,6 +1478,23 @@ func (c *Conn) sendDerpReadResult(ctx context.Context, res derpReadResult) (sent
 	c.pconn4.SetReadDeadline(aLongTimeAgo)
 	select {
 	case <-ctx.Done():
+		select {
+		case <-c.donec:
+			// The whole Conn shut down. The reader of
+			// c.derpRecvCh also selects on c.donec, so it's
+			// safe to abort now.
+		case c.derpRecvCh <- (derpReadResult{}):
+			// Just this DERP reader is closing (perhaps
+			// the user is logging out, or the DERP
+			// connection is too idle for sends). Since we
+			// already incremented c.derpRecvCountAtomic,
+			// we need to send on the channel (unless the
+			// conn is going down).
+			// The receiver treats a derpReadResult zero value
+			// message as a skip.
+			testCounterZeroDerpReadResultSend.Add(1)
+		}
 		return false
 	case c.derpRecvCh <- res:
 		return true
@@ -1568,20 +1590,20 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, error) {
 }
 
 func (c *Conn) derpPacketArrived() bool {
-	rc := atomic.LoadInt64(&c.derpRecvCountAtomic)
-	if rc != c.derpRecvCountLast {
-		c.derpRecvCountLast = rc
-		return true
-	}
-	return false
+	return atomic.LoadInt64(&c.derpRecvCountAtomic) > 0
 }
 
 // ReceiveIPv4 is called by wireguard-go to receive an IPv4 packet.
 // In Tailscale's case, that packet might also arrive via DERP. A DERP packet arrival
 // aborts the pconn4 read deadline to make it fail.
 func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
+	var pAddr net.Addr
 	for {
-		n, pAddr, err := c.pconn4.ReadFrom(b)
+		// Drain DERP queues before reading new UDP packets.
+		if c.derpPacketArrived() {
+			goto ReadDERP
+		}
+		n, pAddr, err = c.pconn4.ReadFrom(b)
 		if err != nil {
 			// If the pconn4 read failed, the likely reason is a DERP reader received
 			// a packet and interrupted us.
@@ -1589,18 +1611,21 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
 			// and for there to have also had a DERP packet arrive, but that's fine:
 			// we'll get the same error from ReadFrom later.
 			if c.derpPacketArrived() {
-				c.pconn4.SetReadDeadline(time.Time{}) // restore
-				n, ep, err = c.receiveIPv4DERP(b)
-				if err == errLoopAgain {
-					continue
-				}
-				return n, ep, err
+				goto ReadDERP
 			}
 			return 0, nil, err
 		}
 		if ep, ok := c.receiveIP(b[:n], pAddr.(*net.UDPAddr), &c.ippEndpoint4); ok {
 			return n, ep, nil
+		} else {
+			continue
 		}
+	ReadDERP:
+		n, ep, err = c.receiveIPv4DERP(b)
+		if err == errLoopAgain {
+			continue
+		}
+		return n, ep, err
 	}
 }
@@ -1668,6 +1693,13 @@ func (c *Conn) receiveIPv4DERP(b []byte) (n int, ep conn.Endpoint, err error) {
 	case dm = <-c.derpRecvCh:
 		// Below.
 	}
+	if atomic.AddInt64(&c.derpRecvCountAtomic, -1) == 0 {
+		c.pconn4.SetReadDeadline(time.Time{})
+	}
+	if dm.copyBuf == nil {
+		testCounterZeroDerpReadResultRecv.Add(1)
+		return 0, nil, errLoopAgain
+	}
 
 	var regionID int
 	n, regionID = dm.n, dm.regionID
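One detail in the test below deserves a note: each sender first calls sendDerpReadResult with a context that was canceled up front (doneCtx). Inside sendDerpReadResult, both select cases can then be ready at the same time (the canceled context, and the derpRecvCh send whenever ReceiveIPv4 is blocked on the channel), and Go's select picks uniformly at random among ready cases. That is where the "~50%" in the test comment comes from, and why the test asserts that the zero-value path (counted by testCounterZeroDerpReadResultSend) is taken sometimes but not for every send. A tiny, self-contained illustration of that select behavior follows; the channel and counter names here are only for demonstration:

// Sketch of why the test expects roughly a 50/50 split: with a pre-canceled
// context, both select cases below are ready at once, and Go chooses one of
// the ready cases uniformly at random on each iteration.
package main

import (
	"context"
	"fmt"
)

func main() {
	doneCtx, cancel := context.WithCancel(context.Background())
	cancel() // canceled up front, like doneCtx in the test below

	ch := make(chan int, 1)
	var viaDone, viaSend int
	for i := 0; i < 100000; i++ {
		select {
		case <-doneCtx.Done():
			viaDone++
		case ch <- i:
			viaSend++
			<-ch // drain so the buffered send is ready again next iteration
		}
	}
	// Both cases are always ready, so the split is roughly even.
	fmt.Printf("done: %d, send: %d\n", viaDone, viaSend)
}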

wgengine/magicsock/magicsock_test.go

@@ -11,7 +11,6 @@ import (
 	"crypto/tls"
 	"encoding/binary"
 	"encoding/json"
-	"flag"
 	"fmt"
 	"io/ioutil"
 	"net"
@@ -1435,8 +1434,6 @@ func newNonLegacyTestConn(t testing.TB) *Conn {
 	return conn
 }
 
-var testIssue1282 = flag.Bool("test-issue-1282", false, "run test for https://github.com/tailscale/tailscale/issues/1282 on all CPUs")
-
 // Tests concurrent DERP readers pushing DERP data into ReceiveIPv4
 // (which should blend all DERP reads into UDP reads).
 func TestDerpReceiveFromIPv4(t *testing.T) {
@@ -1450,42 +1447,54 @@ func TestDerpReceiveFromIPv4(t *testing.T) {
 	defer sendConn.Close()
 	nodeKey, _ := addTestEndpoint(conn, sendConn)
 
-	var sends int = 500
-	senders := runtime.NumCPU()
-	if !*testIssue1282 {
-		t.Logf("--test-issue-1282 was not specified; so doing single-threaded test (instead of NumCPU=%d) to work around https://github.com/tailscale/tailscale/issues/1282", senders)
-		senders = 1
+	var sends int = 250e3 // takes about a second
+	if testing.Short() {
+		sends /= 10
 	}
+	senders := runtime.NumCPU()
 	sends -= (sends % senders)
 	var wg sync.WaitGroup
 	defer wg.Wait()
 	t.Logf("doing %v sends over %d senders", sends, senders)
-	ctx := context.Background()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer conn.Close()
+	defer cancel()
+
+	doneCtx, cancelDoneCtx := context.WithCancel(context.Background())
+	cancelDoneCtx()
 
 	for i := 0; i < senders; i++ {
 		wg.Add(1)
 		regionID := i + 1
 		go func() {
 			defer wg.Done()
-			ch := make(chan bool, 1)
 			for i := 0; i < sends/senders; i++ {
-				if !conn.sendDerpReadResult(ctx, derpReadResult{
+				res := derpReadResult{
 					regionID: regionID,
 					n:        123,
 					src:      key.Public(nodeKey),
-					copyBuf: func(dst []byte) int {
-						ch <- true
-						return 123
-					},
-				}) {
+					copyBuf:  func(dst []byte) int { return 123 },
+				}
+				// First send with the closed context. ~50% of
+				// these should end up going through the
+				// send-a-zero-derpReadResult path, returning
+				// true, in which case we don't want to send again.
+				// We test later that we hit the other path.
+				if conn.sendDerpReadResult(doneCtx, res) {
+					continue
+				}
+
+				if !conn.sendDerpReadResult(ctx, res) {
 					t.Error("unexpected false")
 					return
 				}
-				<-ch
 			}
 		}()
 	}
 
+	zeroSendsStart := testCounterZeroDerpReadResultSend.Value()
+
 	buf := make([]byte, 1500)
 	for i := 0; i < sends; i++ {
 		n, ep, err := conn.ReceiveIPv4(buf)
@@ -1495,6 +1504,20 @@ func TestDerpReceiveFromIPv4(t *testing.T) {
 		_ = n
 		_ = ep
 	}
+	t.Logf("did %d ReceiveIPv4 calls", sends)
+
+	zeroSends, zeroRecv := testCounterZeroDerpReadResultSend.Value(), testCounterZeroDerpReadResultRecv.Value()
+	if zeroSends != zeroRecv {
+		t.Errorf("did %d zero sends != %d corresponding receives", zeroSends, zeroRecv)
+	}
+
+	zeroSendDelta := zeroSends - zeroSendsStart
+	if zeroSendDelta == 0 {
+		t.Errorf("didn't see any sends of derpReadResult zero value")
+	}
+	if zeroSendDelta == int64(sends) {
+		t.Errorf("saw %v sends of the derpReadResult zero value which was unexpectedly high (100%% of our %v sends)", zeroSendDelta, sends)
+	}
 }
 
 // addTestEndpoint sets conn's network map to a single peer expected