From 7898bc9d56bd3ea899004e3d9697d6a573bf557d Mon Sep 17 00:00:00 2001 From: James Tucker Date: Mon, 29 Aug 2022 16:32:32 -0700 Subject: [PATCH] net/tstun: provide exactly one buffer of readahead on tun Signed-off-by: James Tucker --- net/tstun/wrap.go | 79 ++++++++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 46 deletions(-) diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 80b129279..e60dc6952 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -33,6 +33,7 @@ "tailscale.com/wgengine/filter" ) +const bufferPoolSize = 2 const maxBufferSize = device.MaxMessageSize // PacketStartOffset is the minimal amount of leading space that must exist @@ -87,18 +88,9 @@ type Wrapper struct { destMACAtomic syncs.AtomicValue[[6]byte] discoKey syncs.AtomicValue[key.DiscoPublic] - // buffer stores the oldest unconsumed packet from tdev. - // It is made a static buffer in order to avoid allocations. - buffer [maxBufferSize]byte - // bufferConsumedMu protects bufferConsumed from concurrent sends and closes. - // It does not prevent send-after-close, only data races. - bufferConsumedMu sync.Mutex - // bufferConsumed synchronizes access to buffer (shared by Read and poll). - // - // Close closes bufferConsumed. There may be outstanding sends to bufferConsumed - // when that happens; we catch any resulting panics. - // This lets us avoid expensive multi-case selects. - bufferConsumed chan struct{} + // bufferPool contains a limited number of buffers for use pulling packets + // from the tun device. + bufferPool chan []byte // closed signals poll (by closing) when the device is closed. closed chan struct{} @@ -178,6 +170,10 @@ type tunReadResult struct { packet *stack.PacketBuffer data []byte + // If poolbuf is non-nil it should be returned to the tun buffer pool when + // the read result is no longer being aliased. + poolbuf []byte + // injected is set if the read result was generated internally, and contained packets should not // pass through filters. injected bool @@ -194,25 +190,25 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper { logf = logger.WithPrefix(logf, "tstun: ") tun := &Wrapper{ - logf: logf, - limitedLogf: logger.RateLimitedFn(logf, 1*time.Minute, 2, 10), - isTAP: isTAP, - tdev: tdev, - // bufferConsumed is conceptually a condition variable: - // a goroutine should not block when setting it, even with no listeners. - bufferConsumed: make(chan struct{}, 1), - closed: make(chan struct{}), - outbound: make(chan tunReadResult), - eventsUpDown: make(chan tun.Event), - eventsOther: make(chan tun.Event), + logf: logf, + limitedLogf: logger.RateLimitedFn(logf, 1*time.Minute, 2, 10), + isTAP: isTAP, + tdev: tdev, + closed: make(chan struct{}), + outbound: make(chan tunReadResult), + eventsUpDown: make(chan tun.Event), + eventsOther: make(chan tun.Event), // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. filterFlags: filter.LogAccepts | filter.LogDrops, } + tun.bufferPool = make(chan []byte, bufferPoolSize) + for i := 0; i < bufferPoolSize; i++ { + tun.bufferPool <- make([]byte, maxBufferSize) + } + go tun.poll() go tun.pumpEvents() - // The buffer starts out consumed. - tun.bufferConsumed <- struct{}{} tun.noteActivity() return tun @@ -255,9 +251,6 @@ func (t *Wrapper) Close() error { var err error t.closeOnce.Do(func() { close(t.closed) - t.bufferConsumedMu.Lock() - close(t.bufferConsumed) - t.bufferConsumedMu.Unlock() t.outboundMu.Lock() close(t.outbound) t.outboundMu.Unlock() @@ -362,7 +355,7 @@ func allowSendOnClosedChannel() { // This is needed because t.tdev.Read in general may block (it does on Windows), // so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. func (t *Wrapper) poll() { - for range t.bufferConsumed { + for buffer := range t.bufferPool { DoRead: var n int var err error @@ -379,21 +372,21 @@ func (t *Wrapper) poll() { return } if t.isTAP { - n, err = t.tdev.Read(t.buffer[:], PacketStartOffset-ethernetFrameSize) + n, err = t.tdev.Read(buffer[:], PacketStartOffset-ethernetFrameSize) if tapDebug { - s := fmt.Sprintf("% x", t.buffer[:]) + s := fmt.Sprintf("% x", buffer[:]) for strings.HasSuffix(s, " 00") { s = strings.TrimSuffix(s, " 00") } t.logf("TAP read %v, %v: %s", n, err, s) } } else { - n, err = t.tdev.Read(t.buffer[:], PacketStartOffset) + n, err = t.tdev.Read(buffer[:], PacketStartOffset) } } if t.isTAP { if err == nil { - ethernetFrame := t.buffer[PacketStartOffset-ethernetFrameSize:][:n] + ethernetFrame := buffer[PacketStartOffset-ethernetFrameSize:][:n] if t.handleTAPFrame(ethernetFrame) { goto DoRead } @@ -403,20 +396,17 @@ func (t *Wrapper) poll() { n -= ethernetFrameSize } if tapDebug { - t.logf("tap regular frame: %x", t.buffer[PacketStartOffset:PacketStartOffset+n]) + t.logf("tap regular frame: %x", buffer[PacketStartOffset:PacketStartOffset+n]) } } - t.sendOutbound(tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err}) + t.sendOutbound(tunReadResult{data: buffer[PacketStartOffset : PacketStartOffset+n], poolbuf: buffer, err: err}) } } // sendBufferConsumed does t.bufferConsumed <- struct{}{}. // It protects against any panics or data races that that send could cause. -func (t *Wrapper) sendBufferConsumed() { - defer allowSendOnClosedChannel() - t.bufferConsumedMu.Lock() - defer t.bufferConsumedMu.Unlock() - t.bufferConsumed <- struct{}{} +func (t *Wrapper) sendBufferConsumed(buffer []byte) { + t.bufferPool <- buffer } // sendOutbound does t.outboundMu <- r. @@ -515,6 +505,9 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { // Wrapper is closed. return 0, io.EOF } + if res.poolbuf != nil { + defer t.sendBufferConsumed(res.poolbuf) + } if res.err != nil { return 0, res.err } @@ -531,12 +524,6 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { res.packet.DecRef() } else { n = copy(buf[offset:], res.data) - - // t.buffer has a fixed location in memory. - if &res.data[0] == &t.buffer[PacketStartOffset] { - // We are done with t.buffer. Let poll re-use it. - t.sendBufferConsumed() - } } p := parsedPacketPool.Get().(*packet.Parsed)