net/tstun: provide exactly one buffer of readahead on tun

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker 2022-08-29 16:32:32 -07:00
parent 0c34fc7b5b
commit 7898bc9d56
No known key found for this signature in database

View File

@ -33,6 +33,7 @@
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
) )
const bufferPoolSize = 2
const maxBufferSize = device.MaxMessageSize const maxBufferSize = device.MaxMessageSize
// PacketStartOffset is the minimal amount of leading space that must exist // PacketStartOffset is the minimal amount of leading space that must exist
@ -87,18 +88,9 @@ type Wrapper struct {
destMACAtomic syncs.AtomicValue[[6]byte] destMACAtomic syncs.AtomicValue[[6]byte]
discoKey syncs.AtomicValue[key.DiscoPublic] discoKey syncs.AtomicValue[key.DiscoPublic]
// buffer stores the oldest unconsumed packet from tdev. // bufferPool contains a limited number of buffers for use pulling packets
// It is made a static buffer in order to avoid allocations. // from the tun device.
buffer [maxBufferSize]byte bufferPool chan []byte
// bufferConsumedMu protects bufferConsumed from concurrent sends and closes.
// It does not prevent send-after-close, only data races.
bufferConsumedMu sync.Mutex
// bufferConsumed synchronizes access to buffer (shared by Read and poll).
//
// Close closes bufferConsumed. There may be outstanding sends to bufferConsumed
// when that happens; we catch any resulting panics.
// This lets us avoid expensive multi-case selects.
bufferConsumed chan struct{}
// closed signals poll (by closing) when the device is closed. // closed signals poll (by closing) when the device is closed.
closed chan struct{} closed chan struct{}
@ -178,6 +170,10 @@ type tunReadResult struct {
packet *stack.PacketBuffer packet *stack.PacketBuffer
data []byte data []byte
// If poolbuf is non-nil it should be returned to the tun buffer pool when
// the read result is no longer being aliased.
poolbuf []byte
// injected is set if the read result was generated internally, and contained packets should not // injected is set if the read result was generated internally, and contained packets should not
// pass through filters. // pass through filters.
injected bool injected bool
@ -194,25 +190,25 @@ func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper {
func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper { func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper {
logf = logger.WithPrefix(logf, "tstun: ") logf = logger.WithPrefix(logf, "tstun: ")
tun := &Wrapper{ tun := &Wrapper{
logf: logf, logf: logf,
limitedLogf: logger.RateLimitedFn(logf, 1*time.Minute, 2, 10), limitedLogf: logger.RateLimitedFn(logf, 1*time.Minute, 2, 10),
isTAP: isTAP, isTAP: isTAP,
tdev: tdev, tdev: tdev,
// bufferConsumed is conceptually a condition variable: closed: make(chan struct{}),
// a goroutine should not block when setting it, even with no listeners. outbound: make(chan tunReadResult),
bufferConsumed: make(chan struct{}, 1), eventsUpDown: make(chan tun.Event),
closed: make(chan struct{}), eventsOther: make(chan tun.Event),
outbound: make(chan tunReadResult),
eventsUpDown: make(chan tun.Event),
eventsOther: make(chan tun.Event),
// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
filterFlags: filter.LogAccepts | filter.LogDrops, filterFlags: filter.LogAccepts | filter.LogDrops,
} }
tun.bufferPool = make(chan []byte, bufferPoolSize)
for i := 0; i < bufferPoolSize; i++ {
tun.bufferPool <- make([]byte, maxBufferSize)
}
go tun.poll() go tun.poll()
go tun.pumpEvents() go tun.pumpEvents()
// The buffer starts out consumed.
tun.bufferConsumed <- struct{}{}
tun.noteActivity() tun.noteActivity()
return tun return tun
@ -255,9 +251,6 @@ func (t *Wrapper) Close() error {
var err error var err error
t.closeOnce.Do(func() { t.closeOnce.Do(func() {
close(t.closed) close(t.closed)
t.bufferConsumedMu.Lock()
close(t.bufferConsumed)
t.bufferConsumedMu.Unlock()
t.outboundMu.Lock() t.outboundMu.Lock()
close(t.outbound) close(t.outbound)
t.outboundMu.Unlock() t.outboundMu.Unlock()
@ -362,7 +355,7 @@ func allowSendOnClosedChannel() {
// This is needed because t.tdev.Read in general may block (it does on Windows), // This is needed because t.tdev.Read in general may block (it does on Windows),
// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. // so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly.
func (t *Wrapper) poll() { func (t *Wrapper) poll() {
for range t.bufferConsumed { for buffer := range t.bufferPool {
DoRead: DoRead:
var n int var n int
var err error var err error
@ -379,21 +372,21 @@ func (t *Wrapper) poll() {
return return
} }
if t.isTAP { if t.isTAP {
n, err = t.tdev.Read(t.buffer[:], PacketStartOffset-ethernetFrameSize) n, err = t.tdev.Read(buffer[:], PacketStartOffset-ethernetFrameSize)
if tapDebug { if tapDebug {
s := fmt.Sprintf("% x", t.buffer[:]) s := fmt.Sprintf("% x", buffer[:])
for strings.HasSuffix(s, " 00") { for strings.HasSuffix(s, " 00") {
s = strings.TrimSuffix(s, " 00") s = strings.TrimSuffix(s, " 00")
} }
t.logf("TAP read %v, %v: %s", n, err, s) t.logf("TAP read %v, %v: %s", n, err, s)
} }
} else { } else {
n, err = t.tdev.Read(t.buffer[:], PacketStartOffset) n, err = t.tdev.Read(buffer[:], PacketStartOffset)
} }
} }
if t.isTAP { if t.isTAP {
if err == nil { if err == nil {
ethernetFrame := t.buffer[PacketStartOffset-ethernetFrameSize:][:n] ethernetFrame := buffer[PacketStartOffset-ethernetFrameSize:][:n]
if t.handleTAPFrame(ethernetFrame) { if t.handleTAPFrame(ethernetFrame) {
goto DoRead goto DoRead
} }
@ -403,20 +396,17 @@ func (t *Wrapper) poll() {
n -= ethernetFrameSize n -= ethernetFrameSize
} }
if tapDebug { if tapDebug {
t.logf("tap regular frame: %x", t.buffer[PacketStartOffset:PacketStartOffset+n]) t.logf("tap regular frame: %x", buffer[PacketStartOffset:PacketStartOffset+n])
} }
} }
t.sendOutbound(tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err}) t.sendOutbound(tunReadResult{data: buffer[PacketStartOffset : PacketStartOffset+n], poolbuf: buffer, err: err})
} }
} }
// sendBufferConsumed does t.bufferConsumed <- struct{}{}. // sendBufferConsumed returns buffer to t.bufferPool (t.bufferPool <- buffer)
// It protects against any panics or data races that that send could cause. // so that poll can reuse it for a later read from the tun device.
func (t *Wrapper) sendBufferConsumed() { func (t *Wrapper) sendBufferConsumed(buffer []byte) {
defer allowSendOnClosedChannel() t.bufferPool <- buffer
t.bufferConsumedMu.Lock()
defer t.bufferConsumedMu.Unlock()
t.bufferConsumed <- struct{}{}
} }
// sendOutbound does t.outbound <- r. // sendOutbound does t.outbound <- r.
@ -515,6 +505,9 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
// Wrapper is closed. // Wrapper is closed.
return 0, io.EOF return 0, io.EOF
} }
if res.poolbuf != nil {
defer t.sendBufferConsumed(res.poolbuf)
}
if res.err != nil { if res.err != nil {
return 0, res.err return 0, res.err
} }
@ -531,12 +524,6 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
res.packet.DecRef() res.packet.DecRef()
} else { } else {
n = copy(buf[offset:], res.data) n = copy(buf[offset:], res.data)
// t.buffer has a fixed location in memory.
if &res.data[0] == &t.buffer[PacketStartOffset] {
// We are done with t.buffer. Let poll re-use it.
t.sendBufferConsumed()
}
} }
p := parsedPacketPool.Get().(*packet.Parsed) p := parsedPacketPool.Get().(*packet.Parsed)