mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-12 05:37:32 +00:00
net/tstun: merge in wgengine/tstun.
Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:

committed by
Dave Anderson

parent
018200aeba
commit
588b70f468
56
net/tstun/fake.go
Normal file
56
net/tstun/fake.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tstun
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
)
|
||||
|
||||
type fakeTUN struct {
|
||||
evchan chan tun.Event
|
||||
closechan chan struct{}
|
||||
}
|
||||
|
||||
// NewFakeTUN returns a fake TUN device that does not depend on the
|
||||
// operating system or any special permissions.
|
||||
// It primarily exists for testing.
|
||||
func NewFakeTUN() tun.Device {
|
||||
return &fakeTUN{
|
||||
evchan: make(chan tun.Event),
|
||||
closechan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *fakeTUN) File() *os.File {
|
||||
panic("fakeTUN.File() called, which makes no sense")
|
||||
}
|
||||
|
||||
func (t *fakeTUN) Close() error {
|
||||
close(t.closechan)
|
||||
close(t.evchan)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *fakeTUN) Read(out []byte, offset int) (int, error) {
|
||||
<-t.closechan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (t *fakeTUN) Write(b []byte, n int) (int, error) {
|
||||
select {
|
||||
case <-t.closechan:
|
||||
return 0, ErrClosed
|
||||
default:
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *fakeTUN) Flush() error { return nil }
|
||||
func (t *fakeTUN) MTU() (int, error) { return 1500, nil }
|
||||
func (t *fakeTUN) Name() (string, error) { return "FakeTUN", nil }
|
||||
func (t *fakeTUN) Events() chan tun.Event { return t.evchan }
|
501
net/tstun/wrap.go
Normal file
501
net/tstun/wrap.go
Normal file
@@ -0,0 +1,501 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tstun provides a TUN struct implementing the tun.Device interface
|
||||
// with additional features as required by wgengine.
|
||||
package tstun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
const maxBufferSize = device.MaxMessageSize
|
||||
|
||||
// PacketStartOffset is the minimal amount of leading space that must exist
|
||||
// before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect.
|
||||
// This is necessary to avoid reallocation in wireguard-go internals.
|
||||
const PacketStartOffset = device.MessageTransportHeaderSize
|
||||
|
||||
// MaxPacketSize is the maximum size (in bytes)
|
||||
// of a packet that can be injected into a tstun.TUN.
|
||||
const MaxPacketSize = device.MaxContentSize
|
||||
|
||||
var (
|
||||
// ErrClosed is returned when attempting an operation on a closed TUN.
|
||||
ErrClosed = errors.New("device closed")
|
||||
// ErrFiltered is returned when the acted-on packet is rejected by a filter.
|
||||
ErrFiltered = errors.New("packet dropped by filter")
|
||||
)
|
||||
|
||||
var (
|
||||
errPacketTooBig = errors.New("packet too big")
|
||||
errOffsetTooBig = errors.New("offset larger than buffer length")
|
||||
errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset")
|
||||
)
|
||||
|
||||
// parsedPacketPool holds a pool of Parsed structs for use in filtering.
|
||||
// This is needed because escape analysis cannot see that parsed packets
|
||||
// do not escape through {Pre,Post}Filter{In,Out}.
|
||||
var parsedPacketPool = sync.Pool{New: func() interface{} { return new(packet.Parsed) }}
|
||||
|
||||
// FilterFunc is a packet-filtering function with access to the TUN device.
|
||||
// It must not hold onto the packet struct, as its backing storage will be reused.
|
||||
type FilterFunc func(*packet.Parsed, *TUN) filter.Response
|
||||
|
||||
// TUN wraps a tun.Device from wireguard-go,
|
||||
// augmenting it with filtering and packet injection.
|
||||
// All the added work happens in Read and Write:
|
||||
// the other methods delegate to the underlying tdev.
|
||||
type TUN struct {
|
||||
logf logger.Logf
|
||||
// tdev is the underlying TUN device.
|
||||
tdev tun.Device
|
||||
|
||||
closeOnce sync.Once
|
||||
|
||||
lastActivityAtomic int64 // unix seconds of last send or receive
|
||||
|
||||
destIPActivity atomic.Value // of map[netaddr.IP]func()
|
||||
|
||||
// buffer stores the oldest unconsumed packet from tdev.
|
||||
// It is made a static buffer in order to avoid allocations.
|
||||
buffer [maxBufferSize]byte
|
||||
// bufferConsumed synchronizes access to buffer (shared by Read and poll).
|
||||
bufferConsumed chan struct{}
|
||||
|
||||
// closed signals poll (by closing) when the device is closed.
|
||||
closed chan struct{}
|
||||
// errors is the error queue populated by poll.
|
||||
errors chan error
|
||||
// outbound is the queue by which packets leave the TUN device.
|
||||
//
|
||||
// The directions are relative to the network, not the device:
|
||||
// inbound packets arrive via UDP and are written into the TUN device;
|
||||
// outbound packets are read from the TUN device and sent out via UDP.
|
||||
// This queue is needed because although inbound writes are synchronous,
|
||||
// the other direction must wait on a Wireguard goroutine to poll it.
|
||||
//
|
||||
// Empty reads are skipped by Wireguard, so it is always legal
|
||||
// to discard an empty packet instead of sending it through t.outbound.
|
||||
outbound chan []byte
|
||||
|
||||
// fitler stores the currently active package filter
|
||||
filter atomic.Value // of *filter.Filter
|
||||
// filterFlags control the verbosity of logging packet drops/accepts.
|
||||
filterFlags filter.RunFlags
|
||||
|
||||
// PreFilterIn is the inbound filter function that runs before the main filter
|
||||
// and therefore sees the packets that may be later dropped by it.
|
||||
PreFilterIn FilterFunc
|
||||
// PostFilterIn is the inbound filter function that runs after the main filter.
|
||||
PostFilterIn FilterFunc
|
||||
// PreFilterOut is the outbound filter function that runs before the main filter
|
||||
// and therefore sees the packets that may be later dropped by it.
|
||||
PreFilterOut FilterFunc
|
||||
// PostFilterOut is the outbound filter function that runs after the main filter.
|
||||
PostFilterOut FilterFunc
|
||||
|
||||
// OnTSMPPongReceived, if non-nil, is called whenever a TSMP pong arrives.
|
||||
OnTSMPPongReceived func(data [8]byte)
|
||||
|
||||
// disableFilter disables all filtering when set. This should only be used in tests.
|
||||
disableFilter bool
|
||||
}
|
||||
|
||||
func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
|
||||
tun := &TUN{
|
||||
logf: logger.WithPrefix(logf, "tstun: "),
|
||||
tdev: tdev,
|
||||
// bufferConsumed is conceptually a condition variable:
|
||||
// a goroutine should not block when setting it, even with no listeners.
|
||||
bufferConsumed: make(chan struct{}, 1),
|
||||
closed: make(chan struct{}),
|
||||
errors: make(chan error),
|
||||
outbound: make(chan []byte),
|
||||
// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
|
||||
filterFlags: filter.LogAccepts | filter.LogDrops,
|
||||
}
|
||||
|
||||
go tun.poll()
|
||||
// The buffer starts out consumed.
|
||||
tun.bufferConsumed <- struct{}{}
|
||||
|
||||
return tun
|
||||
}
|
||||
|
||||
// SetDestIPActivityFuncs sets a map of funcs to run per packet
|
||||
// destination (the map keys).
|
||||
//
|
||||
// The map ownership passes to the TUN. It must be non-nil.
|
||||
func (t *TUN) SetDestIPActivityFuncs(m map[netaddr.IP]func()) {
|
||||
t.destIPActivity.Store(m)
|
||||
}
|
||||
|
||||
func (t *TUN) Close() error {
|
||||
var err error
|
||||
t.closeOnce.Do(func() {
|
||||
// Other channels need not be closed: poll will exit gracefully after this.
|
||||
close(t.closed)
|
||||
|
||||
err = t.tdev.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *TUN) Events() chan tun.Event {
|
||||
return t.tdev.Events()
|
||||
}
|
||||
|
||||
func (t *TUN) File() *os.File {
|
||||
return t.tdev.File()
|
||||
}
|
||||
|
||||
func (t *TUN) Flush() error {
|
||||
return t.tdev.Flush()
|
||||
}
|
||||
|
||||
func (t *TUN) MTU() (int, error) {
|
||||
return t.tdev.MTU()
|
||||
}
|
||||
|
||||
func (t *TUN) Name() (string, error) {
|
||||
return t.tdev.Name()
|
||||
}
|
||||
|
||||
// poll polls t.tdev.Read, placing the oldest unconsumed packet into t.buffer.
|
||||
// This is needed because t.tdev.Read in general may block (it does on Windows),
|
||||
// so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly.
|
||||
func (t *TUN) poll() {
|
||||
for {
|
||||
select {
|
||||
case <-t.closed:
|
||||
return
|
||||
case <-t.bufferConsumed:
|
||||
// continue
|
||||
}
|
||||
|
||||
// Read may use memory in t.buffer before PacketStartOffset for mandatory headers.
|
||||
// This is the rationale behind the tun.TUN.{Read,Write} interfaces
|
||||
// and the reason t.buffer has size MaxMessageSize and not MaxContentSize.
|
||||
n, err := t.tdev.Read(t.buffer[:], PacketStartOffset)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-t.closed:
|
||||
return
|
||||
case t.errors <- err:
|
||||
// In principle, read errors are not fatal (but wireguard-go disagrees).
|
||||
t.bufferConsumed <- struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Wireguard will skip an empty read,
|
||||
// so we might as well do it here to avoid the send through t.outbound.
|
||||
if n == 0 {
|
||||
t.bufferConsumed <- struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case <-t.closed:
|
||||
return
|
||||
case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]:
|
||||
// continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var magicDNSIPPort = netaddr.MustParseIPPort("100.100.100.100:0")
|
||||
|
||||
func (t *TUN) filterOut(p *packet.Parsed) filter.Response {
|
||||
// Fake ICMP echo responses to MagicDNS (100.100.100.100).
|
||||
if p.IsEchoRequest() && p.Dst == magicDNSIPPort {
|
||||
header := p.ICMP4Header()
|
||||
header.ToResponse()
|
||||
outp := packet.Generate(&header, p.Payload())
|
||||
t.InjectInboundCopy(outp)
|
||||
return filter.DropSilently // don't pass on to OS; already handled
|
||||
}
|
||||
|
||||
if t.PreFilterOut != nil {
|
||||
if res := t.PreFilterOut(p, t); res.IsDrop() {
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
filt, _ := t.filter.Load().(*filter.Filter)
|
||||
|
||||
if filt == nil {
|
||||
return filter.Drop
|
||||
}
|
||||
|
||||
if filt.RunOut(p, t.filterFlags) != filter.Accept {
|
||||
return filter.Drop
|
||||
}
|
||||
|
||||
if t.PostFilterOut != nil {
|
||||
if res := t.PostFilterOut(p, t); res.IsDrop() {
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
return filter.Accept
|
||||
}
|
||||
|
||||
// noteActivity records that there was a read or write at the current time.
|
||||
func (t *TUN) noteActivity() {
|
||||
atomic.StoreInt64(&t.lastActivityAtomic, time.Now().Unix())
|
||||
}
|
||||
|
||||
// IdleDuration reports how long it's been since the last read or write to this device.
|
||||
//
|
||||
// Its value is only accurate to roughly second granularity.
|
||||
// If there's never been activity, the duration is since 1970.
|
||||
func (t *TUN) IdleDuration() time.Duration {
|
||||
sec := atomic.LoadInt64(&t.lastActivityAtomic)
|
||||
return time.Since(time.Unix(sec, 0))
|
||||
}
|
||||
|
||||
func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
||||
var n int
|
||||
|
||||
wasInjectedPacket := false
|
||||
|
||||
select {
|
||||
case <-t.closed:
|
||||
return 0, io.EOF
|
||||
case err := <-t.errors:
|
||||
return 0, err
|
||||
case pkt := <-t.outbound:
|
||||
n = copy(buf[offset:], pkt)
|
||||
// t.buffer has a fixed location in memory,
|
||||
// so this is the easiest way to tell when it has been consumed.
|
||||
// &pkt[0] can be used because empty packets do not reach t.outbound.
|
||||
if &pkt[0] == &t.buffer[PacketStartOffset] {
|
||||
t.bufferConsumed <- struct{}{}
|
||||
} else {
|
||||
// If the packet is not from t.buffer, then it is an injected packet.
|
||||
wasInjectedPacket = true
|
||||
}
|
||||
}
|
||||
|
||||
p := parsedPacketPool.Get().(*packet.Parsed)
|
||||
defer parsedPacketPool.Put(p)
|
||||
p.Decode(buf[offset : offset+n])
|
||||
|
||||
if m, ok := t.destIPActivity.Load().(map[netaddr.IP]func()); ok {
|
||||
if fn := m[p.Dst.IP]; fn != nil {
|
||||
fn()
|
||||
}
|
||||
}
|
||||
|
||||
// For injected packets, we return early to bypass filtering.
|
||||
if wasInjectedPacket {
|
||||
t.noteActivity()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if !t.disableFilter {
|
||||
response := t.filterOut(p)
|
||||
if response != filter.Accept {
|
||||
// Wireguard considers read errors fatal; pretend nothing was read
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
t.noteActivity()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (t *TUN) filterIn(buf []byte) filter.Response {
|
||||
p := parsedPacketPool.Get().(*packet.Parsed)
|
||||
defer parsedPacketPool.Put(p)
|
||||
p.Decode(buf)
|
||||
|
||||
if p.IPProto == ipproto.TSMP {
|
||||
if pingReq, ok := p.AsTSMPPing(); ok {
|
||||
t.noteActivity()
|
||||
t.injectOutboundPong(p, pingReq)
|
||||
return filter.DropSilently
|
||||
} else if data, ok := p.AsTSMPPong(); ok {
|
||||
if f := t.OnTSMPPongReceived; f != nil {
|
||||
f(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if t.PreFilterIn != nil {
|
||||
if res := t.PreFilterIn(p, t); res.IsDrop() {
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
filt, _ := t.filter.Load().(*filter.Filter)
|
||||
|
||||
if filt == nil {
|
||||
return filter.Drop
|
||||
}
|
||||
|
||||
if filt.RunIn(p, t.filterFlags) != filter.Accept {
|
||||
|
||||
// Tell them, via TSMP, we're dropping them due to the ACL.
|
||||
// Their host networking stack can translate this into ICMP
|
||||
// or whatnot as required. But notably, their GUI or tailscale CLI
|
||||
// can show them a rejection history with reasons.
|
||||
if p.IPVersion == 4 && p.IPProto == ipproto.TCP && p.TCPFlags&packet.TCPSyn != 0 {
|
||||
rj := packet.TailscaleRejectedHeader{
|
||||
IPSrc: p.Dst.IP,
|
||||
IPDst: p.Src.IP,
|
||||
Src: p.Src,
|
||||
Dst: p.Dst,
|
||||
Proto: p.IPProto,
|
||||
Reason: packet.RejectedDueToACLs,
|
||||
}
|
||||
if filt.ShieldsUp() {
|
||||
rj.Reason = packet.RejectedDueToShieldsUp
|
||||
}
|
||||
pkt := packet.Generate(rj, nil)
|
||||
t.InjectOutbound(pkt)
|
||||
|
||||
// TODO(bradfitz): also send a TCP RST, after the TSMP message.
|
||||
}
|
||||
|
||||
return filter.Drop
|
||||
}
|
||||
|
||||
if t.PostFilterIn != nil {
|
||||
if res := t.PostFilterIn(p, t); res.IsDrop() {
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
return filter.Accept
|
||||
}
|
||||
|
||||
// Write accepts an incoming packet. The packet begins at buf[offset:],
|
||||
// like wireguard-go/tun.Device.Write.
|
||||
func (t *TUN) Write(buf []byte, offset int) (int, error) {
|
||||
if !t.disableFilter {
|
||||
res := t.filterIn(buf[offset:])
|
||||
if res == filter.DropSilently {
|
||||
return len(buf), nil
|
||||
}
|
||||
if res != filter.Accept {
|
||||
return 0, ErrFiltered
|
||||
}
|
||||
}
|
||||
|
||||
t.noteActivity()
|
||||
return t.tdev.Write(buf, offset)
|
||||
}
|
||||
|
||||
func (t *TUN) GetFilter() *filter.Filter {
|
||||
filt, _ := t.filter.Load().(*filter.Filter)
|
||||
return filt
|
||||
}
|
||||
|
||||
func (t *TUN) SetFilter(filt *filter.Filter) {
|
||||
t.filter.Store(filt)
|
||||
}
|
||||
|
||||
// InjectInboundDirect makes the TUN device behave as if a packet
|
||||
// with the given contents was received from the network.
|
||||
// It blocks and does not take ownership of the packet.
|
||||
// The injected packet will not pass through inbound filters.
|
||||
//
|
||||
// The packet contents are to start at &buf[offset].
|
||||
// offset must be greater or equal to PacketStartOffset.
|
||||
// The space before &buf[offset] will be used by Wireguard.
|
||||
func (t *TUN) InjectInboundDirect(buf []byte, offset int) error {
|
||||
if len(buf) > MaxPacketSize {
|
||||
return errPacketTooBig
|
||||
}
|
||||
if len(buf) < offset {
|
||||
return errOffsetTooBig
|
||||
}
|
||||
if offset < PacketStartOffset {
|
||||
return errOffsetTooSmall
|
||||
}
|
||||
|
||||
// Write to the underlying device to skip filters.
|
||||
_, err := t.tdev.Write(buf, offset)
|
||||
return err
|
||||
}
|
||||
|
||||
// InjectInboundCopy takes a packet without leading space,
|
||||
// reallocates it to conform to the InjectInboundDirect interface
|
||||
// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op.
|
||||
func (t *TUN) InjectInboundCopy(packet []byte) error {
|
||||
// We duplicate this check from InjectInboundDirect here
|
||||
// to avoid wasting an allocation on an oversized packet.
|
||||
if len(packet) > MaxPacketSize {
|
||||
return errPacketTooBig
|
||||
}
|
||||
if len(packet) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := make([]byte, PacketStartOffset+len(packet))
|
||||
copy(buf[PacketStartOffset:], packet)
|
||||
|
||||
return t.InjectInboundDirect(buf, PacketStartOffset)
|
||||
}
|
||||
|
||||
func (t *TUN) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingRequest) {
|
||||
pong := packet.TSMPPongReply{
|
||||
Data: req.Data,
|
||||
}
|
||||
switch pp.IPVersion {
|
||||
case 4:
|
||||
h4 := pp.IP4Header()
|
||||
h4.ToResponse()
|
||||
pong.IPHeader = h4
|
||||
case 6:
|
||||
h6 := pp.IP6Header()
|
||||
h6.ToResponse()
|
||||
pong.IPHeader = h6
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
t.InjectOutbound(packet.Generate(pong, nil))
|
||||
}
|
||||
|
||||
// InjectOutbound makes the TUN device behave as if a packet
|
||||
// with the given contents was sent to the network.
|
||||
// It does not block, but takes ownership of the packet.
|
||||
// The injected packet will not pass through outbound filters.
|
||||
// Injecting an empty packet is a no-op.
|
||||
func (t *TUN) InjectOutbound(packet []byte) error {
|
||||
if len(packet) > MaxPacketSize {
|
||||
return errPacketTooBig
|
||||
}
|
||||
if len(packet) == 0 {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-t.closed:
|
||||
return ErrClosed
|
||||
case t.outbound <- packet:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying TUN device.
|
||||
func (t *TUN) Unwrap() tun.Device {
|
||||
return t.tdev
|
||||
}
|
387
net/tstun/wrap_test.go
Normal file
387
net/tstun/wrap_test.go
Normal file
@@ -0,0 +1,387 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tstun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tun/tuntest"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
func udp4(src, dst string, sport, dport uint16) []byte {
|
||||
sip, err := netaddr.ParseIP(src)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
dip, err := netaddr.ParseIP(dst)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
header := &packet.UDP4Header{
|
||||
IP4Header: packet.IP4Header{
|
||||
Src: sip,
|
||||
Dst: dip,
|
||||
IPID: 0,
|
||||
},
|
||||
SrcPort: sport,
|
||||
DstPort: dport,
|
||||
}
|
||||
return packet.Generate(header, []byte("udp_payload"))
|
||||
}
|
||||
|
||||
func nets(nets ...string) (ret []netaddr.IPPrefix) {
|
||||
for _, s := range nets {
|
||||
if i := strings.IndexByte(s, '/'); i == -1 {
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
bits := uint8(32)
|
||||
if ip.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
ret = append(ret, netaddr.IPPrefix{IP: ip, Bits: bits})
|
||||
} else {
|
||||
pfx, err := netaddr.ParseIPPrefix(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ret = append(ret, pfx)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func ports(s string) filter.PortRange {
|
||||
if s == "*" {
|
||||
return filter.PortRange{First: 0, Last: 65535}
|
||||
}
|
||||
|
||||
var fs, ls string
|
||||
i := strings.IndexByte(s, '-')
|
||||
if i == -1 {
|
||||
fs = s
|
||||
ls = fs
|
||||
} else {
|
||||
fs = s[:i]
|
||||
ls = s[i+1:]
|
||||
}
|
||||
first, err := strconv.ParseInt(fs, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
last, err := strconv.ParseInt(ls, 10, 16)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
return filter.PortRange{First: uint16(first), Last: uint16(last)}
|
||||
}
|
||||
|
||||
func netports(netPorts ...string) (ret []filter.NetPortRange) {
|
||||
for _, s := range netPorts {
|
||||
i := strings.LastIndexByte(s, ':')
|
||||
if i == -1 {
|
||||
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
||||
}
|
||||
|
||||
npr := filter.NetPortRange{
|
||||
Net: nets(s[:i])[0],
|
||||
Ports: ports(s[i+1:]),
|
||||
}
|
||||
ret = append(ret, npr)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func setfilter(logf logger.Logf, tun *TUN) {
|
||||
protos := []ipproto.Proto{
|
||||
ipproto.TCP,
|
||||
ipproto.UDP,
|
||||
}
|
||||
matches := []filter.Match{
|
||||
{IPProto: protos, Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
|
||||
{IPProto: protos, Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
|
||||
}
|
||||
var sb netaddr.IPSetBuilder
|
||||
sb.AddPrefix(netaddr.MustParseIPPrefix("1.2.0.0/16"))
|
||||
tun.SetFilter(filter.New(matches, sb.IPSet(), sb.IPSet(), nil, logf))
|
||||
}
|
||||
|
||||
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) {
|
||||
chtun := tuntest.NewChannelTUN()
|
||||
tun := WrapTUN(logf, chtun.TUN())
|
||||
if secure {
|
||||
setfilter(logf, tun)
|
||||
} else {
|
||||
tun.disableFilter = true
|
||||
}
|
||||
return chtun, tun
|
||||
}
|
||||
|
||||
func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) {
|
||||
ftun := NewFakeTUN()
|
||||
tun := WrapTUN(logf, ftun)
|
||||
if secure {
|
||||
setfilter(logf, tun)
|
||||
} else {
|
||||
tun.disableFilter = true
|
||||
}
|
||||
return ftun.(*fakeTUN), tun
|
||||
}
|
||||
|
||||
func TestReadAndInject(t *testing.T) {
|
||||
chtun, tun := newChannelTUN(t.Logf, false)
|
||||
defer tun.Close()
|
||||
|
||||
const size = 2 // all payloads have this size
|
||||
written := []string{"w0", "w1"}
|
||||
injected := []string{"i0", "i1"}
|
||||
|
||||
go func() {
|
||||
for _, packet := range written {
|
||||
payload := []byte(packet)
|
||||
chtun.Outbound <- payload
|
||||
}
|
||||
}()
|
||||
|
||||
for _, packet := range injected {
|
||||
go func(packet string) {
|
||||
payload := []byte(packet)
|
||||
err := tun.InjectOutbound(payload)
|
||||
if err != nil {
|
||||
t.Errorf("%s: error: %v", packet, err)
|
||||
}
|
||||
}(packet)
|
||||
}
|
||||
|
||||
var buf [MaxPacketSize]byte
|
||||
var seen = make(map[string]bool)
|
||||
// We expect the same packets back, in no particular order.
|
||||
for i := 0; i < len(written)+len(injected); i++ {
|
||||
n, err := tun.Read(buf[:], 0)
|
||||
if err != nil {
|
||||
t.Errorf("read %d: error: %v", i, err)
|
||||
}
|
||||
if n != size {
|
||||
t.Errorf("read %d: got size %d; want %d", i, n, size)
|
||||
}
|
||||
got := string(buf[:n])
|
||||
t.Logf("read %d: got %s", i, got)
|
||||
seen[got] = true
|
||||
}
|
||||
|
||||
for _, packet := range written {
|
||||
if !seen[packet] {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
for _, packet := range injected {
|
||||
if !seen[packet] {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAndInject(t *testing.T) {
|
||||
chtun, tun := newChannelTUN(t.Logf, false)
|
||||
defer tun.Close()
|
||||
|
||||
const size = 2 // all payloads have this size
|
||||
written := []string{"w0", "w1"}
|
||||
injected := []string{"i0", "i1"}
|
||||
|
||||
go func() {
|
||||
for _, packet := range written {
|
||||
payload := []byte(packet)
|
||||
n, err := tun.Write(payload, 0)
|
||||
if err != nil {
|
||||
t.Errorf("%s: error: %v", packet, err)
|
||||
}
|
||||
if n != size {
|
||||
t.Errorf("%s: got size %d; want %d", packet, n, size)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, packet := range injected {
|
||||
go func(packet string) {
|
||||
payload := []byte(packet)
|
||||
err := tun.InjectInboundCopy(payload)
|
||||
if err != nil {
|
||||
t.Errorf("%s: error: %v", packet, err)
|
||||
}
|
||||
}(packet)
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
// We expect the same packets back, in no particular order.
|
||||
for i := 0; i < len(written)+len(injected); i++ {
|
||||
packet := <-chtun.Inbound
|
||||
got := string(packet)
|
||||
t.Logf("read %d: got %s", i, got)
|
||||
seen[got] = true
|
||||
}
|
||||
|
||||
for _, packet := range written {
|
||||
if !seen[packet] {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
for _, packet := range injected {
|
||||
if !seen[packet] {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
chtun, tun := newChannelTUN(t.Logf, true)
|
||||
defer tun.Close()
|
||||
|
||||
type direction int
|
||||
|
||||
const (
|
||||
in direction = iota
|
||||
out
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dir direction
|
||||
drop bool
|
||||
data []byte
|
||||
}{
|
||||
{"junk_in", in, true, []byte("\x45not a valid IPv4 packet")},
|
||||
{"junk_out", out, true, []byte("\x45not a valid IPv4 packet")},
|
||||
{"bad_port_in", in, true, udp4("5.6.7.8", "1.2.3.4", 22, 22)},
|
||||
{"bad_port_out", out, false, udp4("1.2.3.4", "5.6.7.8", 22, 22)},
|
||||
{"bad_ip_in", in, true, udp4("8.1.1.1", "1.2.3.4", 89, 89)},
|
||||
{"bad_ip_out", out, false, udp4("1.2.3.4", "8.1.1.1", 98, 98)},
|
||||
{"good_packet_in", in, false, udp4("5.6.7.8", "1.2.3.4", 89, 89)},
|
||||
{"good_packet_out", out, false, udp4("1.2.3.4", "5.6.7.8", 98, 98)},
|
||||
}
|
||||
|
||||
// A reader on the other end of the TUN.
|
||||
go func() {
|
||||
var recvbuf []byte
|
||||
for {
|
||||
select {
|
||||
case <-tun.closed:
|
||||
return
|
||||
case recvbuf = <-chtun.Inbound:
|
||||
// continue
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.drop && bytes.Equal(recvbuf, tt.data) {
|
||||
t.Errorf("did not drop %s", tt.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var buf [MaxPacketSize]byte
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n int
|
||||
var err error
|
||||
var filtered bool
|
||||
|
||||
if tt.dir == in {
|
||||
_, err = tun.Write(tt.data, 0)
|
||||
if err == ErrFiltered {
|
||||
filtered = true
|
||||
err = nil
|
||||
}
|
||||
} else {
|
||||
chtun.Outbound <- tt.data
|
||||
n, err = tun.Read(buf[:], 0)
|
||||
// In the read direction, errors are fatal, so we return n = 0 instead.
|
||||
filtered = (n == 0)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("got err %v; want nil", err)
|
||||
}
|
||||
|
||||
if filtered {
|
||||
if !tt.drop {
|
||||
t.Errorf("got drop; want accept")
|
||||
}
|
||||
} else {
|
||||
if tt.drop {
|
||||
t.Errorf("got accept; want drop")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocs(t *testing.T) {
|
||||
ftun, tun := newFakeTUN(t.Logf, false)
|
||||
defer tun.Close()
|
||||
|
||||
buf := []byte{0x00}
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
_, err := ftun.Write(buf, 0)
|
||||
if err != nil {
|
||||
t.Errorf("write: error: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
if allocs > 0 {
|
||||
t.Errorf("read allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
ftun, tun := newFakeTUN(t.Logf, false)
|
||||
|
||||
data := udp4("1.2.3.4", "5.6.7.8", 98, 98)
|
||||
_, err := ftun.Write(data, 0)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
tun.Close()
|
||||
_, err = ftun.Write(data, 0)
|
||||
if err == nil {
|
||||
t.Error("Expected error from ftun.Write() after Close()")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWrite(b *testing.B) {
|
||||
ftun, tun := newFakeTUN(b.Logf, true)
|
||||
defer tun.Close()
|
||||
|
||||
packet := udp4("5.6.7.8", "1.2.3.4", 89, 89)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := ftun.Write(packet, 0)
|
||||
if err != nil {
|
||||
b.Errorf("err = %v; want nil", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomic64Alignment(t *testing.T) {
|
||||
off := unsafe.Offsetof(TUN{}.lastActivityAtomic)
|
||||
if off%8 != 0 {
|
||||
t.Errorf("offset %v not 8-byte aligned", off)
|
||||
}
|
||||
|
||||
c := new(TUN)
|
||||
atomic.StoreInt64(&c.lastActivityAtomic, 123)
|
||||
}
|
24
net/tstun/wrap_windows.go
Normal file
24
net/tstun/wrap_windows.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tstun
|
||||
|
||||
import (
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"github.com/tailscale/wireguard-go/tun/wintun"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
tun.WintunPool, err = wintun.MakePool("Tailscale")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
guid, err := windows.GUIDFromString("{37217669-42da-4657-a55b-0d995d328250}")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tun.WintunStaticRequestedGUID = &guid
|
||||
}
|
Reference in New Issue
Block a user