tailscale/wgengine/tstun/tun.go
Dmytro Shynkevych 02231e968e
wgengine/tstun: add tests and benchmarks (#436)
Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
2020-06-05 11:19:03 -04:00

305 lines
7.8 KiB
Go

// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tstun provides a TUN struct implementing the tun.Device interface
// with additional features as required by wgengine.
package tstun
import (
"errors"
"io"
"os"
"sync/atomic"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"tailscale.com/types/logger"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/packet"
)
const (
readMaxSize = device.MaxMessageSize
readOffset = device.MessageTransportHeaderSize
)
// MaxPacketSize is the maximum size (in bytes)
// of a packet that can be injected into a tstun.TUN.
const MaxPacketSize = device.MaxContentSize
var (
// ErrClosed is returned when attempting an operation on a closed TUN.
ErrClosed = errors.New("device closed")
// ErrFiltered is returned when the acted-on packet is rejected by a filter.
ErrFiltered = errors.New("packet dropped by filter")
)
var errPacketTooBig = errors.New("packet too big")
// TUN wraps a tun.Device from wireguard-go,
// augmenting it with filtering and packet injection.
// All the added work happens in Read and Write:
// the other methods delegate to the underlying tdev.
type TUN struct {
logf logger.Logf
// tdev is the underlying TUN device.
tdev tun.Device
// buffer stores the oldest unconsumed packet from tdev.
// It is made a static buffer in order to avoid graticious allocation.
buffer [readMaxSize]byte
// bufferConsumed synchronizes access to buffer (shared by Read and poll).
bufferConsumed chan struct{}
// closed signals poll (by closing) when the device is closed.
closed chan struct{}
// errors is the error queue populated by poll.
errors chan error
// outbound is the queue by which packets leave the TUN device.
//
// The directions are relative to the network, not the device:
// inbound packets arrive via UDP and are written into the TUN device;
// outbound packets are read from the TUN device and sent out via UDP.
// This queue is needed because although inbound writes are synchronous,
// the other direction must wait on a Wireguard goroutine to poll it.
//
// Empty reads are skipped by Wireguard, so it is always legal
// to discard an empty packet instead of sending it through t.outbound.
outbound chan []byte
// fitler stores the currently active package filter
filter atomic.Value // of *filter.Filter
// filterFlags control the verbosity of logging packet drops/accepts.
filterFlags filter.RunFlags
// insecure disables all filtering when set. This is useful in tests.
insecure bool
}
func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
tun := &TUN{
logf: logf,
tdev: tdev,
// bufferConsumed is conceptually a condition variable:
// a goroutine should not block when setting it, even with no listeners.
bufferConsumed: make(chan struct{}, 1),
closed: make(chan struct{}),
errors: make(chan error),
outbound: make(chan []byte),
filterFlags: filter.LogAccepts | filter.LogDrops,
}
go tun.poll()
// The buffer starts out consumed.
tun.bufferConsumed <- struct{}{}
return tun
}
func (t *TUN) Close() error {
select {
case <-t.closed:
// continue
default:
// Other channels need not be closed: poll will exit gracefully after this.
close(t.closed)
}
return t.tdev.Close()
}
func (t *TUN) Events() chan tun.Event {
return t.tdev.Events()
}
func (t *TUN) File() *os.File {
return t.tdev.File()
}
func (t *TUN) Flush() error {
return t.tdev.Flush()
}
func (t *TUN) MTU() (int, error) {
return t.tdev.MTU()
}
func (t *TUN) Name() (string, error) {
return t.tdev.Name()
}
// poll polls t.tdev.Read, placing the oldest unconsumed packet into t.buffer.
// This is needed because t.tdev.Read in general may block (it does on Windows),
// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly.
func (t *TUN) poll() {
for {
select {
case <-t.closed:
return
case <-t.bufferConsumed:
// continue
}
// Read may use memory in t.buffer before readOffset for mandatory headers.
// This is the rationale behind the tun.TUN.{Read,Write} interfaces
// and the reason t.buffer has size MaxMessageSize and not MaxContentSize.
n, err := t.tdev.Read(t.buffer[:], readOffset)
if err != nil {
select {
case <-t.closed:
return
case t.errors <- err:
// In principle, read errors are not fatal (but wireguard-go disagrees).
t.bufferConsumed <- struct{}{}
}
continue
}
// Wireguard will skip an empty read,
// so we might as well do it here to avoid the send through t.outbound.
if n == 0 {
t.bufferConsumed <- struct{}{}
continue
}
select {
case <-t.closed:
return
case t.outbound <- t.buffer[readOffset : readOffset+n]:
// continue
}
}
}
func (t *TUN) filterOut(buf []byte) filter.Response {
filt, _ := t.filter.Load().(*filter.Filter)
if filt == nil {
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.")
return filter.Drop
}
var p packet.ParsedPacket
if filt.RunOut(buf, &p, t.filterFlags) == filter.Accept {
return filter.Accept
}
return filter.Drop
}
func (t *TUN) Read(buf []byte, offset int) (int, error) {
var n int
select {
case <-t.closed:
return 0, io.EOF
case err := <-t.errors:
return 0, err
case packet := <-t.outbound:
n = copy(buf[offset:], packet)
// t.buffer has a fixed location in memory,
// so this is the easiest way to tell when it has been consumed.
// &packet[0] can be used because empty packets do not reach t.outbound.
if &packet[0] == &t.buffer[readOffset] {
t.bufferConsumed <- struct{}{}
}
}
if !t.insecure {
response := t.filterOut(buf[offset : offset+n])
if response != filter.Accept {
// Wireguard considers read errors fatal; pretend nothing was read
return 0, nil
}
}
return n, nil
}
func (t *TUN) filterIn(buf []byte) filter.Response {
filt, _ := t.filter.Load().(*filter.Filter)
if filt == nil {
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.")
return filter.Drop
}
var p packet.ParsedPacket
if filt.RunIn(buf, &p, t.filterFlags) == filter.Accept {
// Only in fake mode, answer any incoming pings.
if p.IsEchoRequest() {
ft, ok := t.tdev.(*fakeTUN)
if ok {
header := p.ICMPHeader()
header.ToResponse()
packet := packet.Generate(&header, p.Payload())
ft.Write(packet, 0)
// We already handled it, stop.
return filter.Drop
}
}
return filter.Accept
}
return filter.Drop
}
func (t *TUN) Write(buf []byte, offset int) (int, error) {
if !t.insecure {
response := t.filterIn(buf[offset:])
if response != filter.Accept {
return 0, ErrFiltered
}
}
return t.tdev.Write(buf, offset)
}
func (t *TUN) GetFilter() *filter.Filter {
filt, _ := t.filter.Load().(*filter.Filter)
return filt
}
func (t *TUN) SetFilter(filt *filter.Filter) {
t.filter.Store(filt)
}
// InjectInbound makes the TUN device behave as if a packet
// with the given contents was received from the network.
// It blocks and does not take ownership of the packet.
// Injecting an empty packet is a no-op.
func (t *TUN) InjectInbound(packet []byte) error {
if len(packet) > MaxPacketSize {
return errPacketTooBig
}
if len(packet) == 0 {
return nil
}
_, err := t.Write(packet, 0)
return err
}
// InjectOutbound makes the TUN device behave as if a packet
// with the given contents was sent to the network.
// It does not block, but takes ownership of the packet.
// Injecting an empty packet is a no-op.
func (t *TUN) InjectOutbound(packet []byte) error {
if len(packet) > MaxPacketSize {
return errPacketTooBig
}
if len(packet) == 0 {
return nil
}
select {
case <-t.closed:
return ErrClosed
case t.outbound <- packet:
return nil
}
}
// Unwrap returns the underlying TUN device.
func (t *TUN) Unwrap() tun.Device {
return t.tdev
}