diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index c11ecec08..756e68a0e 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -19,6 +19,7 @@ "go4.org/mem" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + "gvisor.dev/gvisor/pkg/tcpip/stack" "inet.af/netaddr" "tailscale.com/disco" "tailscale.com/net/packet" @@ -154,12 +155,19 @@ type Wrapper struct { disableTSMPRejected bool } -// tunReadResult is the result of a TUN read: Some data and an error. -// The byte slice is not interpreted in the usual way for a Read method. +// tunReadResult is the result of a TUN read, or an injected result pretending to be a TUN read. +// The data is not interpreted in the usual way for a Read method. // See the comment in the middle of Wrap.Read. type tunReadResult struct { - data []byte - err error + // Only one of err, packet or data should be set, and are read in that order + // of precendence. + err error + packet *stack.PacketBuffer + data []byte + + // injected is set if the read result was generated internally, and contained packets should not + // pass through filters. + injected bool } func WrapTAP(logf logger.Logf, tdev tun.Device) *Wrapper { @@ -494,15 +502,22 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { } metricPacketOut.Add(1) - pkt := res.data - n := copy(buf[offset:], pkt) - // t.buffer has a fixed location in memory. - // If the packet is not from t.buffer, then it is an injected packet. - // &pkt[0] can be used because empty packets do not reach t.outbound. - isInjectedPacket := &pkt[0] != &t.buffer[PacketStartOffset] - if !isInjectedPacket { - // We are done with t.buffer. Let poll re-use it. - t.sendBufferConsumed() + + var n int + if res.packet != nil { + n = copy(buf[offset:], res.packet.NetworkHeader().View()) + n += copy(buf[offset+n:], res.packet.TransportHeader().View()) + n += copy(buf[offset+n:], res.packet.Data().AsRange().AsView()) + + res.packet.DecRef() + } else { + n = copy(buf[offset:], res.data) + + // t.buffer has a fixed location in memory. + if &res.data[0] == &t.buffer[PacketStartOffset] { + // We are done with t.buffer. Let poll re-use it. + t.sendBufferConsumed() + } } p := parsedPacketPool.Get().(*packet.Parsed) @@ -516,7 +531,7 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { } // Do not filter injected packets. - if !isInjectedPacket && !t.disableFilter { + if !res.injected && !t.disableFilter { response := t.filterOut(p) if response != filter.Accept { metricPacketOutDrop.Add(1) @@ -741,7 +756,24 @@ func (t *Wrapper) InjectOutbound(packet []byte) error { if len(packet) == 0 { return nil } - t.sendOutbound(tunReadResult{data: packet}) + t.sendOutbound(tunReadResult{data: packet, injected: true}) + return nil +} + +// InjectOutboundPacketBuffer logically behaves as InjectOutbound. It takes ownership of one +// reference count on the packet, and the packet may be mutated. The packet refcount will be +// decremented after the injected buffer has been read. +func (t *Wrapper) InjectOutboundPacketBuffer(packet *stack.PacketBuffer) error { + size := packet.Size() + if size > MaxPacketSize { + packet.DecRef() + return errPacketTooBig + } + if size == 0 { + packet.DecRef() + return nil + } + t.sendOutbound(tunReadResult{packet: packet, injected: true}) return nil } diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index bbfd47512..670d27488 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -382,17 +382,14 @@ func (ns *Impl) injectOutbound() { ns.logf("[v2] ReadContext-for-write = ok=false") continue } - hdrNetwork := pkt.NetworkHeader() - hdrTransport := pkt.TransportHeader() - full := make([]byte, 0, pkt.Size()) - full = append(full, hdrNetwork.View()...) - full = append(full, hdrTransport.View()...) - full = append(full, pkt.Data().AsRange().AsView()...) if debugPackets { - ns.logf("[v2] packet Write out: % x", full) + ns.logf("[v2] packet Write out: % x", stack.PayloadSince(pkt.NetworkHeader())) } - if err := ns.tundev.InjectOutbound(full); err != nil { + + // pkt has a non-zero refcount, InjectOutboundPacketBuffer takes + // ownership of one count and will decrement on completion. + if err := ns.tundev.InjectOutboundPacketBuffer(pkt); err != nil { log.Printf("netstack inject outbound: %v", err) return }