net/tstun: fix data races

To remove some multi-case selects, we intentionally allowed
sends on closed channels (cc23049cd2).

However, we also introduced concurrent sends and closes,
which is a data race.

This commit fixes the data race. The mutexes here are uncontended,
and thus very cheap.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2021-07-07 15:45:00 -07:00 committed by Josh Bleecher Snyder
parent 7d417586a8
commit 0ad92b89a6

View File

@ -71,6 +71,9 @@ type Wrapper struct {
// buffer stores the oldest unconsumed packet from tdev. // buffer stores the oldest unconsumed packet from tdev.
// It is made a static buffer in order to avoid allocations. // It is made a static buffer in order to avoid allocations.
buffer [maxBufferSize]byte buffer [maxBufferSize]byte
// bufferConsumedMu protects bufferConsumed from concurrent sends and closes.
// It does not prevent send-after-close, only data races.
bufferConsumedMu sync.Mutex
// bufferConsumed synchronizes access to buffer (shared by Read and poll). // bufferConsumed synchronizes access to buffer (shared by Read and poll).
// //
// Close closes bufferConsumed. There may be outstanding sends to bufferConsumed // Close closes bufferConsumed. There may be outstanding sends to bufferConsumed
@ -80,6 +83,9 @@ type Wrapper struct {
// closed signals poll (by closing) when the device is closed. // closed signals poll (by closing) when the device is closed.
closed chan struct{} closed chan struct{}
// outboundMu protects outbound from concurrent sends and closes.
// It does not prevent send-after-close, only data races.
outboundMu sync.Mutex
// outbound is the queue by which packets leave the TUN device. // outbound is the queue by which packets leave the TUN device.
// //
// The directions are relative to the network, not the device: // The directions are relative to the network, not the device:
@ -174,8 +180,12 @@ func (t *Wrapper) Close() error {
var err error var err error
t.closeOnce.Do(func() { t.closeOnce.Do(func() {
close(t.closed) close(t.closed)
t.bufferConsumedMu.Lock()
close(t.bufferConsumed) close(t.bufferConsumed)
t.bufferConsumedMu.Unlock()
t.outboundMu.Lock()
close(t.outbound) close(t.outbound)
t.outboundMu.Unlock()
err = t.tdev.Close() err = t.tdev.Close()
}) })
return err return err
@ -275,7 +285,6 @@ func allowSendOnClosedChannel() {
// This is needed because t.tdev.Read in general may block (it does on Windows), // This is needed because t.tdev.Read in general may block (it does on Windows),
// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. // so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly.
func (t *Wrapper) poll() { func (t *Wrapper) poll() {
defer allowSendOnClosedChannel() // for send to t.outbound
for range t.bufferConsumed { for range t.bufferConsumed {
var n int var n int
var err error var err error
@ -293,10 +302,28 @@ func (t *Wrapper) poll() {
} }
n, err = t.tdev.Read(t.buffer[:], PacketStartOffset) n, err = t.tdev.Read(t.buffer[:], PacketStartOffset)
} }
t.outbound <- tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err} t.sendOutbound(tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err})
} }
} }
// sendBufferConsumed does t.bufferConsumed <- struct{}{}.
// It protects against any panics or data races that that send could cause.
func (t *Wrapper) sendBufferConsumed() {
defer allowSendOnClosedChannel()
t.bufferConsumedMu.Lock()
defer t.bufferConsumedMu.Unlock()
t.bufferConsumed <- struct{}{}
}
// sendOutbound does t.outboundMu <- r.
// It protects against any panics or data races that that send could cause.
func (t *Wrapper) sendOutbound(r tunReadResult) {
defer allowSendOnClosedChannel()
t.outboundMu.Lock()
defer t.outboundMu.Unlock()
t.outbound <- r
}
var magicDNSIPPort = netaddr.MustParseIPPort("100.100.100.100:0") var magicDNSIPPort = netaddr.MustParseIPPort("100.100.100.100:0")
func (t *Wrapper) filterOut(p *packet.Parsed) filter.Response { func (t *Wrapper) filterOut(p *packet.Parsed) filter.Response {
@ -357,7 +384,6 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
if res.err != nil { if res.err != nil {
return 0, res.err return 0, res.err
} }
defer allowSendOnClosedChannel() // for send to t.bufferConsumed
pkt := res.data pkt := res.data
n := copy(buf[offset:], pkt) n := copy(buf[offset:], pkt)
// t.buffer has a fixed location in memory. // t.buffer has a fixed location in memory.
@ -366,7 +392,7 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
isInjectedPacket := &pkt[0] != &t.buffer[PacketStartOffset] isInjectedPacket := &pkt[0] != &t.buffer[PacketStartOffset]
if !isInjectedPacket { if !isInjectedPacket {
// We are done with t.buffer. Let poll re-use it. // We are done with t.buffer. Let poll re-use it.
t.bufferConsumed <- struct{}{} t.sendBufferConsumed()
} }
p := parsedPacketPool.Get().(*packet.Parsed) p := parsedPacketPool.Get().(*packet.Parsed)
@ -583,8 +609,7 @@ func (t *Wrapper) InjectOutbound(packet []byte) error {
if len(packet) == 0 { if len(packet) == 0 {
return nil return nil
} }
defer allowSendOnClosedChannel() // for send to t.outbound t.sendOutbound(tunReadResult{data: packet})
t.outbound <- tunReadResult{data: packet}
return nil return nil
} }