diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 655c4ad46..5bebf16ca 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -71,6 +71,9 @@ type Wrapper struct { // buffer stores the oldest unconsumed packet from tdev. // It is made a static buffer in order to avoid allocations. buffer [maxBufferSize]byte + // bufferConsumedMu protects bufferConsumed from concurrent sends and closes. + // It does not prevent send-after-close, only data races. + bufferConsumedMu sync.Mutex // bufferConsumed synchronizes access to buffer (shared by Read and poll). // // Close closes bufferConsumed. There may be outstanding sends to bufferConsumed @@ -80,6 +83,9 @@ type Wrapper struct { // closed signals poll (by closing) when the device is closed. closed chan struct{} + // outboundMu protects outbound from concurrent sends and closes. + // It does not prevent send-after-close, only data races. + outboundMu sync.Mutex // outbound is the queue by which packets leave the TUN device. // // The directions are relative to the network, not the device: @@ -174,8 +180,12 @@ func (t *Wrapper) Close() error { var err error t.closeOnce.Do(func() { close(t.closed) + t.bufferConsumedMu.Lock() close(t.bufferConsumed) + t.bufferConsumedMu.Unlock() + t.outboundMu.Lock() close(t.outbound) + t.outboundMu.Unlock() err = t.tdev.Close() }) return err @@ -275,7 +285,6 @@ func allowSendOnClosedChannel() { // This is needed because t.tdev.Read in general may block (it does on Windows), // so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. func (t *Wrapper) poll() { - defer allowSendOnClosedChannel() // for send to t.outbound for range t.bufferConsumed { var n int var err error @@ -293,10 +302,28 @@ func (t *Wrapper) poll() { } n, err = t.tdev.Read(t.buffer[:], PacketStartOffset) } - t.outbound <- tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err} + t.sendOutbound(tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err}) } } +// sendBufferConsumed does t.bufferConsumed <- struct{}{}. +// It protects against any panics or data races that that send could cause. +func (t *Wrapper) sendBufferConsumed() { + defer allowSendOnClosedChannel() + t.bufferConsumedMu.Lock() + defer t.bufferConsumedMu.Unlock() + t.bufferConsumed <- struct{}{} +} + +// sendOutbound does t.outboundMu <- r. +// It protects against any panics or data races that that send could cause. +func (t *Wrapper) sendOutbound(r tunReadResult) { + defer allowSendOnClosedChannel() + t.outboundMu.Lock() + defer t.outboundMu.Unlock() + t.outbound <- r +} + var magicDNSIPPort = netaddr.MustParseIPPort("100.100.100.100:0") func (t *Wrapper) filterOut(p *packet.Parsed) filter.Response { @@ -357,7 +384,6 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { if res.err != nil { return 0, res.err } - defer allowSendOnClosedChannel() // for send to t.bufferConsumed pkt := res.data n := copy(buf[offset:], pkt) // t.buffer has a fixed location in memory. @@ -366,7 +392,7 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { isInjectedPacket := &pkt[0] != &t.buffer[PacketStartOffset] if !isInjectedPacket { // We are done with t.buffer. Let poll re-use it. - t.bufferConsumed <- struct{}{} + t.sendBufferConsumed() } p := parsedPacketPool.Get().(*packet.Parsed) @@ -583,8 +609,7 @@ func (t *Wrapper) InjectOutbound(packet []byte) error { if len(packet) == 0 { return nil } - defer allowSendOnClosedChannel() // for send to t.outbound - t.outbound <- tunReadResult{data: packet} + t.sendOutbound(tunReadResult{data: packet}) return nil }