diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index c65b65c88..25c585671 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -72,12 +72,14 @@ type Wrapper struct { // It is made a static buffer in order to avoid allocations. buffer [maxBufferSize]byte // bufferConsumed synchronizes access to buffer (shared by Read and poll). + // + // Close closes bufferConsumed. There may be outstanding sends to bufferConsumed + // when that happens; we catch any resulting panics. + // This lets us avoid expensive multi-case selects. bufferConsumed chan struct{} // closed signals poll (by closing) when the device is closed. closed chan struct{} - // errors is the error queue populated by poll. - errors chan error // outbound is the queue by which packets leave the TUN device. // // The directions are relative to the network, not the device: @@ -88,7 +90,11 @@ type Wrapper struct { // // Empty reads are skipped by Wireguard, so it is always legal // to discard an empty packet instead of sending it through t.outbound. - outbound chan []byte + // + // Close closes outbound. There may be outstanding sends to outbound + // when that happens; we catch any resulting panics. + // This lets us avoid expensive multi-case selects. + outbound chan tunReadResult // eventsUpDown yields up and down tun.Events that arrive on a Wrapper's events channel. eventsUpDown chan tun.Event @@ -125,6 +131,14 @@ type Wrapper struct { disableTSMPRejected bool } +// tunReadResult is the result of a TUN read: Some data and an error. +// The byte slice is not interpreted in the usual way for a Read method. +// See the comment in the middle of Wrap.Read. +type tunReadResult struct { + data []byte + err error +} + func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { tun := &Wrapper{ logf: logger.WithPrefix(logf, "tstun: "), @@ -133,8 +147,7 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { // a goroutine should not block when setting it, even with no listeners. bufferConsumed: make(chan struct{}, 1), closed: make(chan struct{}), - errors: make(chan error), - outbound: make(chan []byte), + outbound: make(chan tunReadResult), eventsUpDown: make(chan tun.Event), eventsOther: make(chan tun.Event), // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. @@ -160,9 +173,9 @@ func (t *Wrapper) SetDestIPActivityFuncs(m map[netaddr.IP]func()) { func (t *Wrapper) Close() error { var err error t.closeOnce.Do(func() { - // Other channels need not be closed: poll will exit gracefully after this. close(t.closed) - + close(t.bufferConsumed) + close(t.outbound) err = t.tdev.Close() }) return err @@ -230,30 +243,40 @@ func (t *Wrapper) Name() (string, error) { return t.tdev.Name() } +// allowSendOnClosedChannel suppresses panics due to sending on a closed channel. +// This allows us to avoid synchronization between poll and Close. +// Such synchronization (particularly multi-case selects) is too expensive +// for code like poll or Read that is on the hot path of every packet. +// If this makes you sad or angry, you may want to join our +// weekly Go Performance Delinquents Anonymous meetings on Monday nights. +func allowSendOnClosedChannel() { + r := recover() + if r == nil { + return + } + e, _ := r.(error) + if e != nil && e.Error() == "send on closed channel" { + return + } + panic(r) +} + // poll polls t.tdev.Read, placing the oldest unconsumed packet into t.buffer. // This is needed because t.tdev.Read in general may block (it does on Windows), // so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. func (t *Wrapper) poll() { + defer allowSendOnClosedChannel() // for send to t.outbound for { - select { - case <-t.closed: - return - case <-t.bufferConsumed: - // continue - } + <-t.bufferConsumed // Read may use memory in t.buffer before PacketStartOffset for mandatory headers. // This is the rationale behind the tun.Wrapper.{Read,Write} interfaces // and the reason t.buffer has size MaxMessageSize and not MaxContentSize. n, err := t.tdev.Read(t.buffer[:], PacketStartOffset) if err != nil { - select { - case <-t.closed: - return - case t.errors <- err: - // In principle, read errors are not fatal (but wireguard-go disagrees). - t.bufferConsumed <- struct{}{} - } + t.outbound <- tunReadResult{err: err} + // In principle, read errors are not fatal (but wireguard-go disagrees). + t.bufferConsumed <- struct{}{} continue } @@ -264,12 +287,7 @@ func (t *Wrapper) poll() { continue } - select { - case <-t.closed: - return - case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]: - // continue - } + t.outbound <- tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n]} } } @@ -325,26 +343,26 @@ func (t *Wrapper) IdleDuration() time.Duration { } func (t *Wrapper) Read(buf []byte, offset int) (int, error) { - var n int - - wasInjectedPacket := false - - select { - case <-t.closed: + res, ok := <-t.outbound + if !ok { + // Wrapper is closed. return 0, io.EOF - case err := <-t.errors: - return 0, err - case pkt := <-t.outbound: - n = copy(buf[offset:], pkt) - // t.buffer has a fixed location in memory, - // so this is the easiest way to tell when it has been consumed. - // &pkt[0] can be used because empty packets do not reach t.outbound. - if &pkt[0] == &t.buffer[PacketStartOffset] { - t.bufferConsumed <- struct{}{} - } else { - // If the packet is not from t.buffer, then it is an injected packet. - wasInjectedPacket = true - } + } + if res.err != nil { + return 0, res.err + } + defer allowSendOnClosedChannel() // for send to t.bufferConsumed + pkt := res.data + n := copy(buf[offset:], pkt) + wasInjectedPacket := false + // t.buffer has a fixed location in memory, + // so this is the easiest way to tell when it has been consumed. + // &pkt[0] can be used because empty packets do not reach t.outbound. + if &pkt[0] == &t.buffer[PacketStartOffset] { + t.bufferConsumed <- struct{}{} + } else { + // If the packet is not from t.buffer, then it is an injected packet. + wasInjectedPacket = true } p := parsedPacketPool.Get().(*packet.Parsed) @@ -566,12 +584,9 @@ func (t *Wrapper) InjectOutbound(packet []byte) error { if len(packet) == 0 { return nil } - select { - case <-t.closed: - return ErrClosed - case t.outbound <- packet: - return nil - } + defer allowSendOnClosedChannel() // for send to t.outbound + t.outbound <- tunReadResult{data: packet} + return nil } // Unwrap returns the underlying tun.Device.