net/tstun: fix data races

To remove some multi-case selects, we intentionally allowed
sends on closed channels (cc23049cd2).

However, we also introduced concurrent sends and closes,
which is a data race.

This commit fixes the data race. The mutexes here are uncontended,
and thus very cheap.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2021-07-07 15:45:00 -07:00 committed by Josh Bleecher Snyder
parent 7d417586a8
commit 0ad92b89a6

View File

@ -71,6 +71,9 @@ type Wrapper struct {
// buffer stores the oldest unconsumed packet from tdev.
// It is made a static buffer in order to avoid allocations.
buffer [maxBufferSize]byte
// bufferConsumedMu protects bufferConsumed from concurrent sends and closes.
// It does not prevent send-after-close, only data races.
bufferConsumedMu sync.Mutex
// bufferConsumed synchronizes access to buffer (shared by Read and poll).
//
// Close closes bufferConsumed. There may be outstanding sends to bufferConsumed
@ -80,6 +83,9 @@ type Wrapper struct {
// closed signals poll (by closing) when the device is closed.
closed chan struct{}
// outboundMu protects outbound from concurrent sends and closes.
// It does not prevent send-after-close, only data races.
outboundMu sync.Mutex
// outbound is the queue by which packets leave the TUN device.
//
// The directions are relative to the network, not the device:
@ -174,8 +180,12 @@ func (t *Wrapper) Close() error {
var err error
t.closeOnce.Do(func() {
close(t.closed)
t.bufferConsumedMu.Lock()
close(t.bufferConsumed)
t.bufferConsumedMu.Unlock()
t.outboundMu.Lock()
close(t.outbound)
t.outboundMu.Unlock()
err = t.tdev.Close()
})
return err
@ -275,7 +285,6 @@ func allowSendOnClosedChannel() {
// This is needed because t.tdev.Read in general may block (it does on Windows),
// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly.
func (t *Wrapper) poll() {
defer allowSendOnClosedChannel() // for send to t.outbound
for range t.bufferConsumed {
var n int
var err error
@ -293,10 +302,28 @@ func (t *Wrapper) poll() {
}
n, err = t.tdev.Read(t.buffer[:], PacketStartOffset)
}
t.outbound <- tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err}
t.sendOutbound(tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err})
}
}
// sendBufferConsumed does t.bufferConsumed <- struct{}{}.
// It protects against any panics or data races that that send could cause.
func (t *Wrapper) sendBufferConsumed() {
defer allowSendOnClosedChannel()
t.bufferConsumedMu.Lock()
defer t.bufferConsumedMu.Unlock()
t.bufferConsumed <- struct{}{}
}
// sendOutbound does t.outboundMu <- r.
// It protects against any panics or data races that that send could cause.
func (t *Wrapper) sendOutbound(r tunReadResult) {
defer allowSendOnClosedChannel()
t.outboundMu.Lock()
defer t.outboundMu.Unlock()
t.outbound <- r
}
var magicDNSIPPort = netaddr.MustParseIPPort("100.100.100.100:0")
func (t *Wrapper) filterOut(p *packet.Parsed) filter.Response {
@ -357,7 +384,6 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
if res.err != nil {
return 0, res.err
}
defer allowSendOnClosedChannel() // for send to t.bufferConsumed
pkt := res.data
n := copy(buf[offset:], pkt)
// t.buffer has a fixed location in memory.
@ -366,7 +392,7 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
isInjectedPacket := &pkt[0] != &t.buffer[PacketStartOffset]
if !isInjectedPacket {
// We are done with t.buffer. Let poll re-use it.
t.bufferConsumed <- struct{}{}
t.sendBufferConsumed()
}
p := parsedPacketPool.Get().(*packet.Parsed)
@ -583,8 +609,7 @@ func (t *Wrapper) InjectOutbound(packet []byte) error {
if len(packet) == 0 {
return nil
}
defer allowSendOnClosedChannel() // for send to t.outbound
t.outbound <- tunReadResult{data: packet}
t.sendOutbound(tunReadResult{data: packet})
return nil
}