Move Linux client & common packages into a public repo.

This commit is contained in:
Earl Lee
2020-02-05 14:16:58 -08:00
parent c955043dfe
commit a8d8b8719a
156 changed files with 17113 additions and 0 deletions

55
wgengine/faketun.go Normal file
View File

@@ -0,0 +1,55 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"github.com/tailscale/wireguard-go/tun"
"io"
"os"
)
type fakeTun struct {
datachan chan []byte
evchan chan tun.Event
closechan chan struct{}
}
func NewFakeTun() tun.Device {
return &fakeTun{
datachan: make(chan []byte),
evchan: make(chan tun.Event),
closechan: make(chan struct{}),
}
}
func (t *fakeTun) File() *os.File {
panic("fakeTun.File() called, which makes no sense")
}
func (t *fakeTun) Close() error {
close(t.closechan)
close(t.datachan)
return nil
}
func (t *fakeTun) InsertRead(b []byte) {
t.datachan <- b
}
func (t *fakeTun) Read(out []byte, offset int) (int, error) {
select {
case <-t.closechan:
return 0, io.EOF
case b := <-t.datachan:
copy(out[offset:offset+len(b)], b)
return len(b), nil
}
}
func (t *fakeTun) Write(b []byte, n int) (int, error) { return len(b), nil }
func (t *fakeTun) Flush() error { return nil }
func (t *fakeTun) MTU() (int, error) { return 1500, nil }
func (t *fakeTun) Name() (string, error) { return "FakeTun", nil }
func (t *fakeTun) Events() chan tun.Event { return t.evchan }

218
wgengine/filter/filter.go Normal file
View File

@@ -0,0 +1,218 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import (
"fmt"
"log"
"sync"
"time"
"github.com/golang/groupcache/lru"
"tailscale.com/ratelimit"
"tailscale.com/wgengine/packet"
)
type Filter struct {
matches Matches
udpMu sync.Mutex
udplru *lru.Cache
}
type Response int
const (
Drop Response = iota
Accept
noVerdict // Returned from subfilters to continue processing.
)
func (r Response) String() string {
switch r {
case Drop:
return "Drop"
case Accept:
return "Accept"
case noVerdict:
return "noVerdict"
default:
return "???"
}
}
type RunFlags int
const (
LogDrops RunFlags = 1 << iota
LogAccepts
HexdumpDrops
HexdumpAccepts
)
type tuple struct {
SrcIP IP
DstIP IP
SrcPort uint16
DstPort uint16
}
const LRU_MAX = 512 // max entries in UDP LRU cache
var MatchAllowAll = Matches{
Match{[]IPPortRange{IPPortRangeAny}, []IP{IPAny}},
}
func NewAllowAll() *Filter {
return New(MatchAllowAll)
}
func NewAllowNone() *Filter {
return New(nil)
}
func New(matches Matches) *Filter {
f := &Filter{
matches: matches,
udplru: lru.New(LRU_MAX),
}
return f
}
func maybeHexdump(flag RunFlags, b []byte) string {
if flag != 0 {
return packet.Hexdump(b) + "\n"
} else {
return ""
}
}
// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
// Logging is a quick way to record every newly opened TCP connection, but
// we have to be cautious about flooding the logs vs letting people use
// flood protection to hide their traffic. We could use a rate limiter in
// the actual *filter* for SYN accepts, perhaps.
var acceptBucket = ratelimit.Bucket{
Burst: 3,
FillInterval: 10 * time.Second,
}
var dropBucket = ratelimit.Bucket{
Burst: 10,
FillInterval: 5 * time.Second,
}
func logRateLimit(runflags RunFlags, b []byte, q *packet.QDecode, r Response, why string) {
if r == Drop && (runflags&LogDrops) != 0 && dropBucket.TryGet() > 0 {
var qs string
if q == nil {
qs = fmt.Sprintf("(%d bytes)", len(b))
} else {
qs = q.String()
}
log.Printf("Drop: %v %v %s\n%s", qs, len(b), why, maybeHexdump(runflags&HexdumpDrops, b))
} else if r == Accept && (runflags&LogAccepts) != 0 && acceptBucket.TryGet() > 0 {
log.Printf("Accept: %v %v %s\n%s", q, len(b), why, maybeHexdump(runflags&HexdumpAccepts, b))
}
}
func (f *Filter) RunIn(b []byte, q *packet.QDecode, rf RunFlags) Response {
r := pre(b, q, rf)
if r == Accept || r == Drop {
// already logged
return r
}
r, why := f.runIn(q)
logRateLimit(rf, b, q, r, why)
return r
}
func (f *Filter) RunOut(b []byte, q *packet.QDecode, rf RunFlags) Response {
r := pre(b, q, rf)
if r == Drop || r == Accept {
// already logged
return r
}
r, why := f.runOut(q)
logRateLimit(rf, b, q, r, why)
return r
}
func (f *Filter) runIn(q *packet.QDecode) (r Response, why string) {
switch q.IPProto {
case packet.ICMP:
// If any port is open to an IP, allow ICMP to it.
if matchIPWithoutPorts(f.matches, q) {
return Accept, "icmp ok"
}
case packet.TCP:
// For TCP, we want to allow *outgoing* connections,
// which means we want to allow return packets on those
// connections. To make this restriction work, we need to
// allow non-SYN packets (continuation of an existing session)
// to arrive. This should be okay since a new incoming session
// can't be initiated without first sending a SYN.
// It happens to also be much faster.
// TODO(apenwarr): Skip the rest of decoding in this path?
if q.IPProto == packet.TCP && !q.IsTCPSyn() {
return Accept, "tcp non-syn"
}
if matchIPPorts(f.matches, q) {
return Accept, "tcp ok"
}
case packet.UDP:
t := tuple{q.SrcIP, q.DstIP, q.SrcPort, q.DstPort}
f.udpMu.Lock()
_, ok := f.udplru.Get(t)
f.udpMu.Unlock()
if ok {
return Accept, "udp cached"
}
if matchIPPorts(f.matches, q) {
return Accept, "udp ok"
}
default:
return Drop, "Unknown proto"
}
return Drop, "no rules matched"
}
func (f *Filter) runOut(q *packet.QDecode) (r Response, why string) {
if q.IPProto == packet.UDP {
t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort}
f.udpMu.Lock()
f.udplru.Add(t, t)
f.udpMu.Unlock()
}
return Accept, "ok out"
}
func pre(b []byte, q *packet.QDecode, rf RunFlags) Response {
if len(b) == 0 {
// wireguard keepalive packet, always permit.
return Accept
}
if len(b) < 20 {
logRateLimit(rf, b, nil, Drop, "too short")
return Drop
}
q.Decode(b)
if q.IPProto == packet.Junk {
// Junk packets are dangerous; always drop them.
logRateLimit(rf, b, q, Drop, "junk!")
return Drop
} else if q.IPProto == packet.Fragment {
// Fragments after the first always need to be passed through.
// Very small fragments are considered Junk by QDecode.
logRateLimit(rf, b, q, Accept, "fragment")
return Accept
}
return noVerdict
}

View File

@@ -0,0 +1,162 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import (
"encoding/binary"
"encoding/json"
"net"
"testing"
"tailscale.com/wgengine/packet"
)
type QDecode = packet.QDecode
var Junk = packet.Junk
var ICMP = packet.ICMP
var TCP = packet.TCP
var UDP = packet.UDP
var Fragment = packet.Fragment
func ippr(ip IP, start, end uint16) []IPPortRange {
return []IPPortRange{
IPPortRange{ip, PortRange{start, end}},
}
}
func TestFilter(t *testing.T) {
mm := Matches{
{SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: []IPPortRange{
IPPortRange{0x01020304, PortRange{22, 22}},
IPPortRange{0x05060708, PortRange{23, 24}},
}},
{SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: ippr(0x05060708, 27, 28)},
{SrcIPs: []IP{0x02020202}, DstPorts: ippr(0x08010101, 22, 22)},
{SrcIPs: []IP{0}, DstPorts: ippr(0x647a6232, 0, 65535)},
{SrcIPs: []IP{0}, DstPorts: ippr(0, 443, 443)},
{SrcIPs: []IP{0x99010101, 0x99010102, 0x99030303}, DstPorts: ippr(0x01020304, 999, 999)},
}
acl := New(mm)
for _, ent := range []Matches{Matches{mm[0]}, mm} {
b, err := json.Marshal(ent)
if err != nil {
t.Fatalf("marshal: %v", err)
}
mm2 := Matches{}
if err := json.Unmarshal(b, &mm2); err != nil {
t.Fatalf("unmarshal: %v (%v)", err, string(b))
}
}
// check packet filtering based on the table
type InOut struct {
want Response
p QDecode
}
tests := []InOut{
// Basic
{Accept, qdecode(TCP, 0x08010101, 0x01020304, 999, 22)},
{Accept, qdecode(UDP, 0x08010101, 0x01020304, 999, 22)},
{Accept, qdecode(ICMP, 0x08010101, 0x01020304, 0, 0)},
{Drop, qdecode(TCP, 0x08010101, 0x01020304, 0, 0)},
{Accept, qdecode(TCP, 0x08010101, 0x01020304, 0, 22)},
{Drop, qdecode(TCP, 0x08010101, 0x01020304, 0, 21)},
{Accept, qdecode(TCP, 0x11223344, 0x22334455, 0, 443)},
{Drop, qdecode(TCP, 0x11223344, 0x22334455, 0, 444)},
{Accept, qdecode(TCP, 0x11223344, 0x647a6232, 0, 999)},
{Accept, qdecode(TCP, 0x11223344, 0x647a6232, 0, 0)},
// Stateful UDP.
// Initially empty cache
{Drop, qdecode(UDP, 0x77777777, 0x66666666, 4242, 4343)},
// Return packet from previous attempt is allowed
{Accept, qdecode(UDP, 0x66666666, 0x77777777, 4343, 4242)},
// Because of the return above, initial attempt is allowed now
{Accept, qdecode(UDP, 0x77777777, 0x66666666, 4242, 4343)},
}
for i, test := range tests {
if got, _ := acl.runIn(&test.p); test.want != got {
t.Errorf("#%d got=%v want=%v packet:%v\n", i, got, test.want, test.p)
}
// Update UDP state
_, _ = acl.runOut(&test.p)
}
}
func TestPreFilter(t *testing.T) {
packets := []struct {
desc string
want Response
b []byte
}{
{"empty", Accept, []byte{}},
{"short", Drop, []byte("short")},
{"junk", Drop, rawpacket(Junk, 10)},
{"fragment", Accept, rawpacket(Fragment, 40)},
{"tcp", noVerdict, rawpacket(TCP, 200)},
{"udp", noVerdict, rawpacket(UDP, 200)},
{"icmp", noVerdict, rawpacket(ICMP, 200)},
}
for _, testPacket := range packets {
got := pre([]byte(testPacket.b), &QDecode{}, LogDrops|LogAccepts)
if got != testPacket.want {
t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b))
}
}
}
func qdecode(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) QDecode {
return QDecode{
IPProto: proto,
SrcIP: src,
DstIP: dst,
SrcPort: sport,
DstPort: dport,
TCPFlags: packet.TCPSyn,
}
}
func rawpacket(proto packet.IPProto, len uint16) []byte {
bl := len
if len < 24 {
bl = 24
}
bin := binary.BigEndian
hdr := make([]byte, bl)
hdr[0] = 0x45
bin.PutUint16(hdr[2:4], len)
hdr[8] = 64
ip := net.IPv4(8, 8, 8, 8).To4()
copy(hdr[12:16], ip)
copy(hdr[16:20], ip)
// ports
bin.PutUint16(hdr[20:22], 53)
bin.PutUint16(hdr[22:24], 53)
switch proto {
case ICMP:
hdr[9] = 1
case TCP:
hdr[9] = 6
case UDP:
hdr[9] = 17
case Fragment:
hdr[9] = 6
// flags + fragOff
bin.PutUint16(hdr[6:8], (1<<13)|1234)
case Junk:
default:
panic("unknown protocol")
}
// Truncate the header if requested
hdr = hdr[:len]
return hdr
}

121
wgengine/filter/match.go Normal file
View File

@@ -0,0 +1,121 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import (
"fmt"
"strings"
"tailscale.com/wgengine/packet"
)
type IP = packet.IP
const IPAny = IP(0)
var NewIP = packet.NewIP
type PortRange struct {
First, Last uint16
}
var PortRangeAny = PortRange{0, 65535}
func (pr PortRange) String() string {
if pr.First == 0 && pr.Last == 65535 {
return "*"
} else if pr.First == pr.Last {
return fmt.Sprintf("%d", pr.First)
} else {
return fmt.Sprintf("%d-%d", pr.First, pr.Last)
}
}
type IPPortRange struct {
IP IP
Ports PortRange
}
var IPPortRangeAny = IPPortRange{IPAny, PortRangeAny}
func (ipr IPPortRange) String() string {
return fmt.Sprintf("%v:%v", ipr.IP, ipr.Ports)
}
type Match struct {
DstPorts []IPPortRange
SrcIPs []IP
}
func (m Match) String() string {
srcs := []string{}
for _, srcip := range m.SrcIPs {
srcs = append(srcs, srcip.String())
}
dsts := []string{}
for _, dst := range m.DstPorts {
dsts = append(dsts, dst.String())
}
var ss, ds string
if len(srcs) == 1 {
ss = srcs[0]
} else {
ss = "[" + strings.Join(srcs, ",") + "]"
}
if len(dsts) == 1 {
ds = dsts[0]
} else {
ds = "[" + strings.Join(dsts, ",") + "]"
}
return fmt.Sprintf("%v=>%v", ss, ds)
}
type Matches []Match
func ipInList(ip IP, iplist []IP) bool {
for _, ipp := range iplist {
if ipp == IPAny || ipp == ip {
return true
}
}
return false
}
func matchIPPorts(mm Matches, q *packet.QDecode) bool {
for _, acl := range mm {
for _, dst := range acl.DstPorts {
if dst.IP != IPAny && dst.IP != q.DstIP {
continue
}
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
continue
}
if !ipInList(q.SrcIP, acl.SrcIPs) {
// Skip other dests in this acl, since
// the src will never match.
break
}
return true
}
}
return false
}
func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool {
for _, acl := range mm {
for _, dst := range acl.DstPorts {
if dst.IP != IPAny && dst.IP != q.DstIP {
continue
}
if !ipInList(q.SrcIP, acl.SrcIPs) {
// Skip other dests in this acl, since
// the src will never match.
break
}
return true
}
}
return false
}

View File

@@ -0,0 +1,411 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wgengine
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"log"
"net"
"sort"
"time"
"unsafe"
"github.com/go-ole/go-ole"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"github.com/tailscale/wireguard-go/wgcfg"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"golang.zx2c4.com/winipcfg"
"tailscale.com/wgengine/winnet"
)
const (
sockoptIP_UNICAST_IF = 31
sockoptIPV6_UNICAST_IF = 31
)
func htonl(val uint32) uint32 {
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, val)
return *(*uint32)(unsafe.Pointer(&bytes[0]))
}
func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLuid uint64, lastLuid *uint64) error {
routes, err := winipcfg.GetRoutes(family)
if err != nil {
return err
}
lowestMetric := ^uint32(0)
index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
for _, route := range routes {
if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLuid == ourLuid {
continue
}
if route.Metric < lowestMetric {
lowestMetric = route.Metric
index = route.InterfaceIndex
luid = route.InterfaceLuid
}
}
if luid == *lastLuid {
return nil
}
*lastLuid = luid
if false {
// TODO(apenwarr): doesn't work with magic socket yet.
if family == winipcfg.AF_INET {
return device.BindSocketToInterface4(index, false)
} else if family == winipcfg.AF_INET6 {
return device.BindSocketToInterface6(index, false)
}
} else {
log.Printf("WARNING: skipping windows socket binding.\n")
}
return nil
}
func MonitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
guid := tun.GUID()
ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid)
lastLuid4 := uint64(0)
lastLuid6 := uint64(0)
lastMtu := uint32(0)
if err != nil {
return nil, err
}
doIt := func() error {
err = bindSocketRoute(winipcfg.AF_INET, device, ourLuid, &lastLuid4)
if err != nil {
return err
}
err = bindSocketRoute(winipcfg.AF_INET6, device, ourLuid, &lastLuid6)
if err != nil {
return err
}
if !autoMTU {
return nil
}
mtu := uint32(0)
if lastLuid4 != 0 {
iface, err := winipcfg.InterfaceFromLUID(lastLuid4)
if err != nil {
return err
}
if iface.Mtu > 0 {
mtu = iface.Mtu
}
}
if lastLuid6 != 0 {
iface, err := winipcfg.InterfaceFromLUID(lastLuid6)
if err != nil {
return err
}
if iface.Mtu > 0 && iface.Mtu < mtu {
mtu = iface.Mtu
}
}
if mtu > 0 && (lastMtu == 0 || lastMtu != mtu) {
iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET)
if err != nil {
return err
}
iface.NlMtu = mtu - 80
if iface.NlMtu < 576 {
iface.NlMtu = 576
}
err = iface.Set()
if err != nil {
return err
}
tun.ForceMTU(int(iface.NlMtu)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now.
iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6)
if err != nil {
return err
}
iface.NlMtu = mtu - 80
if iface.NlMtu < 1280 {
iface.NlMtu = 1280
}
err = iface.Set()
if err != nil {
return err
}
lastMtu = mtu
}
return nil
}
err = doIt()
if err != nil {
return nil, err
}
cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) {
//fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix)
if route.DestinationPrefix.PrefixLength == 0 {
_ = doIt()
}
})
if err != nil {
return nil, err
}
return cb, nil
}
func setDNSDomains(g windows.GUID, dnsDomains []string) {
gs := g.String()
log.Printf("setDNSDomains(%v) guid=%v\n", dnsDomains, gs)
p := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + gs
key, err := registry.OpenKey(registry.LOCAL_MACHINE, p, registry.READ|registry.SET_VALUE)
if err != nil {
log.Printf("setDNSDomains(%v): open: %v\n", p, err)
return
}
defer key.Close()
// Windows only supports a single per-interface DNS domain.
dom := ""
if len(dnsDomains) > 0 {
dom = dnsDomains[0]
}
err = key.SetStringValue("Domain", dom)
if err != nil {
log.Printf("setDNSDomains(%v): SetStringValue: %v\n", p, err)
}
}
func setFirewall(ifcGUID *windows.GUID) (bool, error) {
c := ole.Connection{}
err := c.Initialize()
if err != nil {
panic(err)
}
defer c.Uninitialize()
m, err := winnet.NewNetworkListManager(&c)
if err != nil {
panic(err)
}
defer m.Release()
cl, err := m.GetNetworkConnections()
if err != nil {
panic(err)
}
defer cl.Release()
for _, nco := range cl {
aid, err := nco.GetAdapterId()
if err != nil {
panic(err)
}
if aid != ifcGUID.String() {
log.Printf("skipping adapter id: %v\n", aid)
continue
}
log.Printf("found! adapter id: %v\n", aid)
n, err := nco.GetNetwork()
if err != nil {
return false, fmt.Errorf("GetNetwork: %v", err)
}
defer n.Release()
cat, err := n.GetCategory()
if err != nil {
return false, fmt.Errorf("GetCategory: %v", err)
}
if cat == 0 {
err = n.SetCategory(1)
if err != nil {
return false, fmt.Errorf("SetCategory: %v", err)
}
} else {
log.Printf("setFirewall: already category %v\n", cat)
}
return true, nil
}
return false, nil
}
func ConfigureInterface(m *wgcfg.Config, tun *tun.NativeTun, dns []net.IP, dnsDomains []string) error {
const mtu = 0
guid := tun.GUID()
log.Printf("wintun GUID is %v\n", guid)
iface, err := winipcfg.InterfaceFromGUID(&guid)
if err != nil {
return err
}
go func() {
// It takes a weirdly long time for Windows to notice the
// new interface has come up. Poll periodically until it
// does.
for i := 0; i < 20; i++ {
found, err := setFirewall(&guid)
if err != nil {
log.Printf("setFirewall: %v\n", err)
// fall through anyway, this isn't fatal.
}
if found {
break
}
time.Sleep(1 * time.Second)
}
}()
setDNSDomains(guid, dnsDomains)
routes := []winipcfg.RouteData{}
var firstGateway4 *net.IP
var firstGateway6 *net.IP
addresses := make([]*net.IPNet, len(m.Interface.Addresses))
for i, addr := range m.Interface.Addresses {
ipnet := addr.IPNet()
addresses[i] = ipnet
gateway := ipnet.IP
if addr.IP.Is4() && firstGateway4 == nil {
firstGateway4 = &gateway
} else if addr.IP.Is6() && firstGateway6 == nil {
firstGateway6 = &gateway
}
}
foundDefault4 := false
foundDefault6 := false
for _, peer := range m.Peers {
for _, allowedip := range peer.AllowedIPs {
if (allowedip.IP.Is4() && firstGateway4 == nil) || (allowedip.IP.Is6() && firstGateway6 == nil) {
return errors.New("Due to a Windows limitation, one cannot have interface routes without an interface address")
}
ipn := allowedip.IPNet()
var gateway net.IP
if allowedip.IP.Is4() {
gateway = *firstGateway4
} else if allowedip.IP.Is6() {
gateway = *firstGateway6
}
r := winipcfg.RouteData{
Destination: net.IPNet{
IP: ipn.IP.Mask(ipn.Mask),
Mask: ipn.Mask,
},
NextHop: gateway,
Metric: 0,
}
if bytes.Compare(r.Destination.IP, gateway) == 0 {
// no need to add a route for the interface's
// own IP. The kernel does that for us.
// If we try to replace it, we'll fail to
// add the route unless NextHop is set, but
// then the interface's IP won't be pingable.
continue
}
if allowedip.IP.Is4() {
if allowedip.Mask == 0 {
foundDefault4 = true
}
r.NextHop = *firstGateway4
} else if allowedip.IP.Is6() {
if allowedip.Mask == 0 {
foundDefault6 = true
}
r.NextHop = *firstGateway6
}
routes = append(routes, r)
}
}
err = iface.SetAddresses(addresses)
if err != nil {
return err
}
sort.Slice(routes, func(i, j int) bool {
return (bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 ||
// Narrower masks first
bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == 1 ||
// No nexthop before non-empty nexthop
bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 ||
// Lower metrics first
routes[i].Metric < routes[j].Metric)
})
deduplicatedRoutes := []*winipcfg.RouteData{}
for i := 0; i < len(routes); i++ {
// There's only one way to get to a given IP+Mask, so delete
// all matches after the first.
if i > 0 &&
bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
continue
}
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
log.Printf("routes: %v\n", routes)
var errAcc error
err = iface.SetRoutes(deduplicatedRoutes)
if err != nil && errAcc == nil {
log.Printf("setroutes: %v\n", err)
errAcc = err
}
err = iface.SetDNS(dns)
if err != nil && errAcc == nil {
log.Printf("setdns: %v\n", err)
errAcc = err
}
ipif, err := iface.GetIpInterface(winipcfg.AF_INET)
if err != nil {
log.Printf("getipif: %v\n", err)
return err
}
log.Printf("foundDefault4: %v\n", foundDefault4)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if mtu > 0 {
ipif.NlMtu = uint32(mtu)
tun.ForceMTU(int(ipif.NlMtu))
}
err = ipif.Set()
if err != nil && errAcc == nil {
errAcc = err
}
ipif, err = iface.GetIpInterface(winipcfg.AF_INET6)
if err != nil {
return err
}
if err != nil && errAcc == nil {
errAcc = err
}
if foundDefault6 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
if mtu > 0 {
ipif.NlMtu = uint32(mtu)
}
ipif.DadTransmits = 0
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
err = ipif.Set()
if err != nil && errAcc == nil {
errAcc = err
}
return errAcc
}

View File

@@ -0,0 +1,815 @@
// Copyright 2019 Tailscale & AUTHORS. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package magicsock implements a socket that can change its communication path while
// in use, actively searching for the best way to communicate.
package magicsock
import (
"context"
"encoding/binary"
"fmt"
"log"
"net"
"strings"
"sync"
"syscall"
"time"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/wgcfg"
"tailscale.com/derp/derphttp"
"tailscale.com/stun"
"tailscale.com/stunner"
)
// A Conn routes UDP packets and actively manages a list of its endpoints.
// It implements wireguard/device.Bind.
type Conn struct {
pconn *RebindingUDPConn
pconnPort uint16
stunServers []string
derpServer string
startEpUpdate chan struct{} // send to trigger endpoint update
epUpdateCancel func()
epFunc func(endpoints []string)
logf func(format string, args ...interface{})
// indexedAddrs is a map of every remote ip:port to a priority
// list of endpoint addresses for a peer.
// The priority list is provided by wgengine configuration.
//
// Given a wgcfg describing:
// machineA: 10.0.0.1:1, 10.0.0.2:2
// machineB: 10.0.0.3:3
// the indexedAddrs map contains:
// 10.0.0.1:1 -> [10.0.0.1:1, 10.0.0.2:2], index:0
// 10.0.0.2:2 -> [10.0.0.1:1, 10.0.0.2:2], index:1
// 10.0.0.3:3 -> [10.0.0.3:3], index:0
indexedAddrsMu sync.Mutex
indexedAddrs map[udpAddr]indexedAddrSet
stunReceiveMu sync.Mutex
stunReceive func(p []byte, fromAddr *net.UDPAddr)
derpMu sync.Mutex
derp *derphttp.Client
}
// udpAddr is the key in the indexedAddrs map.
// It maps an ip:port onto an indexedAddr.
type udpAddr struct {
ip wgcfg.IP
port uint16
}
// indexedAddrSet is an AddrSet (a priority list of ip:ports for a peer and the
// current favored ip:port for communicating with the peer) and an index
// number saying which element of the priority list is this map entry.
type indexedAddrSet struct {
addr *AddrSet
index int // index of map key in addr.Addrs
}
const DefaultPort = 0
const DefaultDERP = "https://derp.tailscale.com/derp"
var DefaultSTUN = []string{
"stun.l.google.com:19302",
"stun3.l.google.com:19302",
}
// Options contains options for Listen.
type Options struct {
// Port is the port to listen on.
// Zero means to pick one automatically.
Port uint16
STUN []string
DERP string
// EndpointsFunc optionally provides a func to be called when
// endpoints change. The called func does not own the slice.
EndpointsFunc func(endpoint []string)
}
func (o *Options) endpointsFunc() func([]string) {
if o == nil || o.EndpointsFunc == nil {
return func([]string) {}
}
return o.EndpointsFunc
}
// Listen creates a magic Conn listening on opts.Port.
// As the set of possible endpoints for a Conn changes, the
// callback opts.EndpointsFunc is called.
func Listen(opts Options) (*Conn, error) {
var packetConn net.PacketConn
var err error
if opts.Port == 0 {
// Our choice of port. Start with DefaultPort.
// If unavailable, pick any port.
want := fmt.Sprintf(":%d", DefaultPort)
log.Printf("magicsock: bind: trying %v\n", want)
packetConn, err = net.ListenPacket("udp4", want)
if err != nil {
want = ":0"
log.Printf("magicsock: bind: falling back to %v (%v)\n", want, err)
packetConn, err = net.ListenPacket("udp4", want)
}
} else {
packetConn, err = net.ListenPacket("udp4", fmt.Sprintf(":%d", opts.Port))
}
if err != nil {
return nil, fmt.Errorf("magicsock.Listen: %v", err)
}
epUpdateCtx, epUpdateCancel := context.WithCancel(context.Background())
c := &Conn{
pconn: new(RebindingUDPConn),
stunServers: append([]string{}, opts.STUN...),
derpServer: opts.DERP,
startEpUpdate: make(chan struct{}, 1),
epUpdateCancel: epUpdateCancel,
epFunc: opts.endpointsFunc(),
logf: log.Printf,
indexedAddrs: make(map[udpAddr]indexedAddrSet),
}
c.pconn.Reset(packetConn.(*net.UDPConn))
c.startEpUpdate <- struct{}{} // STUN immediately on start
go c.epUpdate(epUpdateCtx)
return c, nil
}
func (c *Conn) epUpdate(ctx context.Context) {
var lastEndpoints []string
var lastCancel func()
var lastDone chan struct{}
for {
select {
case <-ctx.Done():
if lastCancel != nil {
lastCancel()
}
return
case <-c.startEpUpdate:
}
if lastCancel != nil {
lastCancel()
<-lastDone
}
var epCtx context.Context
epCtx, lastCancel = context.WithCancel(ctx)
lastDone = make(chan struct{})
go func() {
defer close(lastDone)
endpoints, err := c.determineEndpoints(epCtx)
if err != nil {
c.logf("magicsock.Conn: endpoint update failed: %v", err)
// TODO(crawshaw): are there any conditions under which
// we should trigger a retry based on the error here?
return
}
if stringsEqual(endpoints, lastEndpoints) {
return
}
lastEndpoints = endpoints
c.epFunc(endpoints)
}()
}
}
func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) {
var alreadyMu sync.Mutex
already := make(map[string]struct{})
var eps []string
addAddr := func(s, reason string) {
log.Printf("magicsock: found local %s (%s)\n", s, reason)
alreadyMu.Lock()
defer alreadyMu.Unlock()
if _, ok := already[s]; !ok {
already[s] = struct{}{}
eps = append(eps, s)
}
}
s := &stunner.Stunner{
Send: c.pconn.WriteTo,
Endpoint: func(s string) { addAddr(s, "stun") },
Servers: c.stunServers,
Logf: c.logf,
}
c.stunReceiveMu.Lock()
c.stunReceive = s.Receive
c.stunReceiveMu.Unlock()
if err := s.Run(ctx); err != nil {
return nil, err
}
c.stunReceiveMu.Lock()
c.stunReceive = nil
c.stunReceiveMu.Unlock()
if localAddr := c.pconn.LocalAddr(); localAddr.IP.IsUnspecified() {
localPort := fmt.Sprintf("%d", localAddr.Port)
loopbacks, err := localAddresses(localPort, func(s string) {
addAddr(s, "localAddresses")
})
if err != nil {
return nil, err
}
if len(eps) == 0 {
// Only include loopback addresses if we have no
// interfaces at all to use as endpoints. This allows
// for localhost testing when you're on a plane and
// offline, for example.
for _, s := range loopbacks {
addAddr(s, "loopback")
}
}
} else {
// Our local endpoint is bound to a particular address.
// Do not offer addresses on other local interfaces.
addAddr(localAddr.String(), "socket")
}
// Note: the endpoints are intentionally returned in priority order,
// from "farthest but most reliable" to "closest but least
// reliable." Addresses returned from STUN should be globally
// addressable, but might go farther on the network than necessary.
// Local interface addresses might have lower latency, but not be
// globally addressable.
//
// The STUN address(es) are always first so that legacy wireguard
// can use eps[0] as its only known endpoint address (although that's
// obviously non-ideal).
return eps, nil
}
func stringsEqual(x, y []string) bool {
if len(x) != len(y) {
return false
}
for i := range x {
if x[i] != y[i] {
return false
}
}
return true
}
func localAddresses(localPort string, addAddr func(s string)) ([]string, error) {
var loopback []string
// TODO(crawshaw): don't serve interface addresses that we are routing
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, i := range ifaces {
if (i.Flags & net.FlagUp) == 0 {
// Down interfaces don't count
continue
}
ifcIsLoopback := (i.Flags & net.FlagLoopback) != 0
addrs, err := i.Addrs()
if err != nil {
return nil, err
}
for _, a := range addrs {
switch v := a.(type) {
case *net.IPNet:
// TODO(crawshaw): IPv6 support.
// Easy to do here, but we need good endpoint ordering logic.
ip := v.IP.To4()
if ip == nil {
continue
}
// TODO(apenwarr): don't special case cgNAT.
// In the general wireguard case, it might
// very well be something we can route to
// directly, because both nodes are
// behind the same CGNAT router.
if cgNAT.Contains(ip) {
continue
}
if linkLocalIPv4.Contains(ip) {
continue
}
ep := net.JoinHostPort(ip.String(), localPort)
if ip.IsLoopback() || ifcIsLoopback {
loopback = append(loopback, ep)
continue
}
addAddr(ep)
}
}
}
return loopback, nil
}
var cgNAT = func() *net.IPNet {
_, ipNet, err := net.ParseCIDR("100.64.0.0/10")
if err != nil {
panic(err)
}
return ipNet
}()
var linkLocalIPv4 = func() *net.IPNet {
_, ipNet, err := net.ParseCIDR("169.254.0.0/16")
if err != nil {
panic(err)
}
return ipNet
}()
func (c *Conn) LocalPort() uint16 {
laddr := c.pconn.LocalAddr()
return uint16(laddr.Port)
}
func (c *Conn) Send(b []byte, ep device.Endpoint) error {
a := ep.(*AddrSet)
msgType := binary.LittleEndian.Uint32(b[:4])
switch msgType {
case device.MessageInitiationType, device.MessageResponseType, device.MessageCookieReplyType:
// Part of the wireguard handshake.
// Send to every potential endpoint we have for a peer.
a.mu.Lock()
roamAddr := a.roamAddr
a.mu.Unlock()
var err error
var success bool
if roamAddr != nil {
_, err = c.pconn.WriteTo(b, roamAddr)
if err == nil {
success = true
}
}
for i := len(a.addrs) - 1; i >= 0; i-- {
addr := &a.addrs[i]
_, err = c.pconn.WriteTo(b, addr)
if err == nil {
success = true
}
}
if msgType == device.MessageInitiationType {
// Send initial handshake messages via DERP.
c.derpMu.Lock()
derp := c.derp
c.derpMu.Unlock()
if derp != nil {
if err := derp.Send(a.publicKey, b); err != nil {
log.Printf("derp send failed: %v", err)
}
}
}
if success {
return nil
}
}
// Write to the highest-priority address we have seen so far.
_, err := c.pconn.WriteTo(b, a.dst())
return err
}
func (c *Conn) findIndexedAddrSet(addr *net.UDPAddr) (addrSet *AddrSet, index int) {
var epAddr udpAddr
copy(epAddr.ip.Addr[:], addr.IP.To16())
epAddr.port = uint16(addr.Port)
c.indexedAddrsMu.Lock()
defer c.indexedAddrsMu.Unlock()
indAddr := c.indexedAddrs[epAddr]
if indAddr.addr == nil {
return nil, 0
}
return indAddr.addr, indAddr.index
}
func (c *Conn) ReceiveIPv4(b []byte) (n int, ep device.Endpoint, addr *net.UDPAddr, err error) {
// Read a packet, and process any STUN packets before returning.
for {
var pAddr net.Addr
n, pAddr, err = c.pconn.ReadFrom(b)
if err != nil {
return n, nil, nil, err
}
addr = pAddr.(*net.UDPAddr)
addr.IP = addr.IP.To4()
if !stun.Is(b[:n]) {
break
}
c.stunReceiveMu.Lock()
fn := c.stunReceive
c.stunReceiveMu.Unlock()
if fn != nil {
fn(b, addr)
}
}
// TODO(crawshaw): remove all the indexed-addr logic
addrSet, _ := c.findIndexedAddrSet(addr)
if addrSet == nil {
// The peer that sent this packet has roamed beyond the
// knowledge provided by the control server.
// If the packet is valid wireguard will call UpdateDst
// on the original endpoint using this addr.
return n, (*singleEndpoint)(addr), addr, nil
}
return n, addrSet, addr, nil
}
func (c *Conn) ReceiveIPv6(buff []byte) (int, device.Endpoint, *net.UDPAddr, error) {
// TODO(crawshaw): IPv6 support
return 0, nil, nil, syscall.EAFNOSUPPORT
}
func (c *Conn) SetPrivateKey(privateKey [32]byte) error {
if c.derpServer == "" {
return nil
}
derp, err := derphttp.NewClient(privateKey, c.derpServer, log.Printf)
if err != nil {
return err
}
go func() {
var b [1 << 16]byte
for {
n, err := derp.Recv(b[:])
if err != nil {
if err == derphttp.ErrClientClosed {
return
}
log.Printf("%v", err)
time.Sleep(250 * time.Millisecond)
}
// Trigger re-STUN.
c.startEpUpdate <- struct{}{}
addr := c.pconn.LocalAddr()
if _, err := c.pconn.WriteToUDP(b[:n], addr); err != nil {
log.Printf("%v", err)
}
}
}()
c.derpMu.Lock()
if c.derp != nil {
if err := c.derp.Close(); err != nil {
log.Printf("derp.Close: %v", err)
}
}
c.derp = derp
c.derpMu.Unlock()
return nil
}
func (c *Conn) SetMark(value uint32) error { return nil }
func (c *Conn) Close() error {
c.epUpdateCancel()
return c.pconn.Close()
}
func (c *Conn) LinkChange() {
defer func() {
c.startEpUpdate <- struct{}{} // re-STUN
}()
if c.pconnPort != 0 {
c.pconn.mu.Lock()
if err := c.pconn.pconn.Close(); err != nil {
log.Printf("magicsock: link change close failed: %v", err)
}
packetConn, err := net.ListenPacket("udp4", fmt.Sprintf(":%d", c.pconnPort))
if err == nil {
log.Printf("magicsock: link change rebound port: %d", c.pconnPort)
c.pconn.pconn = packetConn.(*net.UDPConn)
c.pconn.mu.Unlock()
return
}
log.Printf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.pconnPort, err)
c.pconn.mu.Unlock()
}
log.Printf("magicsock: link change, binding new port")
packetConn, err := net.ListenPacket("udp4", ":0")
if err != nil {
log.Printf("magicsock: link change failed to bind new port: %v", err)
return
}
c.pconn.Reset(packetConn.(*net.UDPConn))
}
// AddrSet is a set of UDP addresses that implements wireguard/device.Endpoint.
type AddrSet struct {
publicKey [32]byte // peer public key used for DERP communication
addrs []net.UDPAddr // ordered priority list provided by wgengine
mu sync.Mutex // guards roamAddr and curAddr
roamAddr *net.UDPAddr // peer addr determined from incoming packets
// curAddr is an index into addrs of the highest-priority
// address a valid packet has been received from so far.
// If no valid packet from addrs has been received, curAddr is -1.
curAddr int
}
var noAddr = &net.UDPAddr{
IP: net.ParseIP("127.127.127.127"),
Port: 127,
}
func (a *AddrSet) dst() *net.UDPAddr {
a.mu.Lock()
defer a.mu.Unlock()
if a.roamAddr != nil {
return a.roamAddr
}
if len(a.addrs) == 0 {
return noAddr
}
i := a.curAddr
if i == -1 {
i = 0
}
return &a.addrs[i]
}
func (a *AddrSet) DstToBytes() []byte {
dst := a.dst()
b := append([]byte(nil), dst.IP.To4()...)
if len(b) == 0 {
b = append([]byte(nil), dst.IP...)
}
b = append(b, byte(dst.Port&0xff))
b = append(b, byte((dst.Port>>8)&0xff))
return b
}
func (a *AddrSet) DstToString() string {
dst := a.dst()
return dst.String()
}
func (a *AddrSet) DstIP() net.IP {
return a.dst().IP
}
func (a *AddrSet) SrcIP() net.IP { return nil }
func (a *AddrSet) SrcToString() string { return "" }
func (a *AddrSet) ClearSrc() {}
func (a *AddrSet) UpdateDst(new *net.UDPAddr) error {
a.mu.Lock()
defer a.mu.Unlock()
if a.roamAddr != nil {
if equalUDPAddr(a.roamAddr, new) {
// Packet from the current roaming address, no logging.
// This is a hot path for established connections.
return nil
}
} else if a.curAddr >= 0 && equalUDPAddr(new, &a.addrs[a.curAddr]) {
// Packet from current-priority address, no logging.
// This is a hot path for established connections.
return nil
}
index := -1
for i := range a.addrs {
if equalUDPAddr(new, &a.addrs[i]) {
index = i
break
}
}
publicKey := wgcfg.Key(a.publicKey)
pk := publicKey.ShortString()
old := "<none>"
if a.curAddr >= 0 {
old = a.addrs[a.curAddr].String()
}
switch {
case index == -1:
if a.roamAddr == nil {
log.Printf("magicsock: rx %s from roaming address %s, set as new priority", pk, new)
} else {
log.Printf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr)
}
a.roamAddr = new
case a.roamAddr != nil:
log.Printf("magicsock: rx %s from known %s (%d), replacs roaming address %s", pk, new, index, a.roamAddr)
a.roamAddr = nil
a.curAddr = index
case a.curAddr == -1:
log.Printf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.addrs))
a.curAddr = index
case index < a.curAddr:
log.Printf("magicsock: rx %s from low-pri %s (%d), keeping current %s (%d)", pk, new, index, old, a.curAddr)
default: // index > a.curAddr
log.Printf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.addrs), old)
a.curAddr = index
}
return nil
}
func equalUDPAddr(x, y *net.UDPAddr) bool {
return x.Port == y.Port && x.IP.Equal(y.IP)
}
func (a *AddrSet) String() string {
a.mu.Lock()
defer a.mu.Unlock()
buf := new(strings.Builder)
buf.WriteByte('[')
if a.roamAddr != nil {
fmt.Fprintf(buf, "roam:%s:%d", a.roamAddr.IP, a.roamAddr.Port)
}
for i, addr := range a.addrs {
if i > 0 || a.roamAddr != nil {
buf.WriteString(", ")
}
fmt.Fprintf(buf, "%s:%d", addr.IP, addr.Port)
if a.curAddr == i {
buf.WriteByte('*')
}
}
buf.WriteByte(']')
return buf.String()
}
func (c *Conn) CreateEndpoint(key [32]byte, s string) (device.Endpoint, error) {
pk := wgcfg.Key(key)
log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), s)
a := &AddrSet{
publicKey: key,
curAddr: -1,
}
if s != "" {
for _, ep := range strings.Split(s, ",") {
addr, err := net.ResolveUDPAddr("udp", ep)
if err != nil {
return nil, err
}
if ip4 := addr.IP.To4(); ip4 != nil {
addr.IP = ip4
}
a.addrs = append(a.addrs, *addr)
}
}
c.indexedAddrsMu.Lock()
for i, addr := range a.addrs {
var epAddr udpAddr
copy(epAddr.ip.Addr[:], addr.IP.To16())
epAddr.port = uint16(addr.Port)
c.indexedAddrs[epAddr] = indexedAddrSet{
addr: a,
index: i,
}
}
c.indexedAddrsMu.Unlock()
return a, nil
}
type singleEndpoint net.UDPAddr
func (e *singleEndpoint) ClearSrc() {}
func (e *singleEndpoint) DstIP() net.IP { return (*net.UDPAddr)(e).IP }
func (e *singleEndpoint) SrcIP() net.IP { return nil }
func (e *singleEndpoint) SrcToString() string { return "" }
func (e *singleEndpoint) DstToString() string { return (*net.UDPAddr)(e).String() }
func (e *singleEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e)
out := addr.IP.To4()
if out == nil {
out = addr.IP
}
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
}
func (e *singleEndpoint) UpdateDst(dst *net.UDPAddr) error {
return fmt.Errorf("magicsock.singleEndpoint(%s).UpdateDst(%s): should never be called", (*net.UDPAddr)(e), dst)
}
// RebindingUDPConn is a UDP socket that can be re-bound.
// Unix has no notion of re-binding a socket, so we swap it out for a new one.
type RebindingUDPConn struct {
mu sync.Mutex
pconn *net.UDPConn
}
func (c *RebindingUDPConn) Reset(pconn *net.UDPConn) {
c.mu.Lock()
old := c.pconn
c.pconn = pconn
c.mu.Unlock()
if old != nil {
old.Close()
}
}
func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
for {
c.mu.Lock()
pconn := c.pconn
c.mu.Unlock()
n, addr, err := pconn.ReadFrom(b)
if err != nil {
c.mu.Lock()
pconn2 := c.pconn
c.mu.Unlock()
if pconn != pconn2 {
continue
}
}
return n, addr, err
}
}
func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr {
c.mu.Lock()
defer c.mu.Unlock()
return c.pconn.LocalAddr().(*net.UDPAddr)
}
func (c *RebindingUDPConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.pconn.Close()
}
func (c *RebindingUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
for {
c.mu.Lock()
pconn := c.pconn
c.mu.Unlock()
n, err := pconn.WriteToUDP(b, addr)
if err != nil {
c.mu.Lock()
pconn2 := c.pconn
c.mu.Unlock()
if pconn != pconn2 {
continue
}
}
return n, err
}
}
func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
for {
c.mu.Lock()
pconn := c.pconn
c.mu.Unlock()
n, err := pconn.WriteTo(b, addr)
if err != nil {
c.mu.Lock()
pconn2 := c.pconn
c.mu.Unlock()
if pconn != pconn2 {
continue
}
}
return n, err
}
}

View File

@@ -0,0 +1,73 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package magicsock
import (
"fmt"
"net"
"strings"
"testing"
"time"
)
func TestListen(t *testing.T) {
epCh := make(chan string, 16)
epFunc := func(endpoints []string) {
for _, ep := range endpoints {
epCh <- ep
}
}
// TODO(crawshaw): break test dependency on the network
// using "gortc.io/stun" (like stunner_test.go).
stunServers := DefaultSTUN
port := pickPort(t)
conn, err := Listen(Options{
Port: port,
STUN: stunServers,
EndpointsFunc: epFunc,
})
if err != nil {
t.Fatal(err)
}
defer conn.Close()
go func() {
var pkt [1 << 16]byte
for {
_, _, _, err := conn.ReceiveIPv4(pkt[:])
if err != nil {
return
}
}
}()
timeout := time.After(10 * time.Second)
var endpoints []string
suffix := fmt.Sprintf(":%d", port)
collectEndpoints:
for {
select {
case ep := <-epCh:
endpoints = append(endpoints, ep)
if strings.HasSuffix(ep, suffix) {
break collectEndpoints
}
case <-timeout:
t.Fatalf("timeout with endpoints: %v", endpoints)
}
}
}
func pickPort(t *testing.T) uint16 {
t.Helper()
conn, err := net.ListenPacket("udp4", ":0")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
return uint16(conn.LocalAddr().(*net.UDPAddr).Port)
}

363
wgengine/packet/packet.go Normal file
View File

@@ -0,0 +1,363 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package packet
import (
"encoding/binary"
"encoding/json"
"fmt"
"log"
"net"
"strings"
)
type IPProto int
const (
Junk IPProto = iota
Fragment
ICMP
UDP
TCP
)
// RFC1858: prevent overlapping fragment attacks.
const MIN_FRAG = 60 + 20 // max IPv4 header + basic TCP header
func (p IPProto) String() string {
switch p {
case Fragment:
return "Frag"
case ICMP:
return "ICMP"
case UDP:
return "UDP"
case TCP:
return "TCP"
default:
return "Junk"
}
}
type IP uint32
const IPAny = IP(0)
func NewIP(b net.IP) IP {
b4 := b.To4()
if b4 == nil {
panic(fmt.Sprintf("To4(%v) failed", b))
}
return IP(binary.BigEndian.Uint32(b4))
}
func (ip IP) String() string {
if ip == 0 {
return "*"
}
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(ip))
return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3])
}
func (ipp *IP) MarshalJSON() ([]byte, error) {
s := "\"" + (*ipp).String() + "\""
return []byte(s), nil
}
func (ipp *IP) UnmarshalJSON(b []byte) error {
var hostp *string
err := json.Unmarshal(b, &hostp)
if err != nil {
return err
}
host := *hostp
ip := net.ParseIP(host)
if ip != nil && ip.IsUnspecified() {
// For clarity, reject 0.0.0.0 as an input
return fmt.Errorf("Ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
} else if ip == nil && host == "*" {
// User explicitly requested wildcard dst ip
*ipp = IPAny
} else {
if ip != nil {
ip = ip.To4()
}
if ip == nil || len(ip) != 4 {
return fmt.Errorf("Ports=%#v: invalid IPv4 address", host)
}
*ipp = NewIP(ip)
}
return nil
}
const (
EchoReply uint8 = 0x00
EchoRequest uint8 = 0x08
)
const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPSynAck uint8 = TCPSyn | TCPAck
)
type QDecode struct {
b []byte // Packet buffer that this decodes
subofs int // byte offset of IP subprotocol
IPProto IPProto // IP subprotocol (UDP, TCP, etc)
SrcIP IP // IP source address
DstIP IP // IP destination address
SrcPort uint16 // TCP/UDP source port
DstPort uint16 // TCP/UDP destination port
TCPFlags uint8 // TCP flags (SYN, ACK, etc)
}
func (q QDecode) String() string {
if q.IPProto == Junk {
return "Junk{}"
}
srcip := make([]byte, 4)
dstip := make([]byte, 4)
binary.BigEndian.PutUint32(srcip, uint32(q.SrcIP))
binary.BigEndian.PutUint32(dstip, uint32(q.DstIP))
return fmt.Sprintf("%v{%d.%d.%d.%d:%d > %d.%d.%d.%d:%d}",
q.IPProto,
srcip[0], srcip[1], srcip[2], srcip[3], q.SrcPort,
dstip[0], dstip[1], dstip[2], dstip[3], q.DstPort)
}
// based on https://tools.ietf.org/html/rfc1071
func ipChecksum(b []byte) uint16 {
var ac uint32
i := 0
n := len(b)
for n >= 2 {
ac += uint32(binary.BigEndian.Uint16(b[i : i+2]))
n -= 2
i += 2
}
if n == 1 {
ac += uint32(b[i]) << 8
}
for (ac >> 16) > 0 {
ac = (ac >> 16) + (ac & 0xffff)
}
return uint16(^ac)
}
func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType uint8, icmpCode uint8, payload []byte) []byte {
if len(payload) < 4 {
return nil
}
if len(payload) > 65535-24 {
return nil
}
sz := 24 + len(payload)
out := make([]byte, 24+len(payload))
out[0] = 0x45 // IPv4, 20-byte header
out[1] = 0x00 // DHCP, ECN
binary.BigEndian.PutUint16(out[2:4], uint16(sz))
binary.BigEndian.PutUint16(out[4:6], ipid)
binary.BigEndian.PutUint16(out[6:8], 0) // flags, offset
out[8] = 64 // TTL
out[9] = 0x01 // ICMPv4
// out[10:12] = 0x00 // blank IP header checksum
binary.BigEndian.PutUint32(out[12:16], uint32(srcIP))
binary.BigEndian.PutUint32(out[16:20], uint32(dstIP))
out[20] = icmpType
out[21] = icmpCode
//out[22:24] = 0x00 // blank ICMP checksum
copy(out[24:len(out)], payload)
binary.BigEndian.PutUint16(out[10:12], ipChecksum(out[0:20]))
binary.BigEndian.PutUint16(out[22:24], ipChecksum(out))
return out
}
// An extremely simple packet decoder for basic IPv4 packet types.
// It extracts only the subprotocol id, IP addresses, and (if any) ports,
// and shouldn't need any memory allocation.
func (q *QDecode) Decode(b []byte) {
q.b = nil
if len(b) < 20 {
q.IPProto = Junk
return
}
// Check that it's IPv4.
// TODO(apenwarr): consider IPv6 support
if ((b[0] & 0xF0) >> 4) != 4 {
q.IPProto = Junk
return
}
n := int(binary.BigEndian.Uint16(b[2:4]))
if len(b) < n {
// Packet was cut off before full IPv4 length.
q.IPProto = Junk
return
}
// If it's valid IPv4, then the IP addresses are valid
q.SrcIP = IP(binary.BigEndian.Uint32(b[12:16]))
q.DstIP = IP(binary.BigEndian.Uint32(b[16:20]))
q.subofs = int((b[0] & 0x0F) * 4)
sub := b[q.subofs:]
// We don't care much about IP fragmentation, except insofar as it's
// used for firewall bypass attacks. The trick is make the first
// fragment of a TCP or UDP packet so short that it doesn't fit
// the TCP or UDP header, so we can't read the port, in hope that
// it'll sneak past. Then subsequent fragments fill it in, but we're
// missing the first part of the header, so we can't read that either.
//
// A "perfectly correct" implementation would have to reassemble
// fragments before deciding what to do. But the truth is there's
// zero reason to send such a short first fragment, so we can treat
// it as Junk. We can also treat any subsequent fragment that starts
// at such a low offset as Junk.
fragFlags := binary.BigEndian.Uint16(b[6:8])
moreFrags := (fragFlags & 0x20) != 0
fragOfs := fragFlags & 0x1FFF
if fragOfs == 0 {
// This is the first fragment
if moreFrags && len(sub) < MIN_FRAG {
// Suspiciously short first fragment, dump it.
log.Printf("junk1!\n")
q.IPProto = Junk
return
}
// otherwise, this is either non-fragmented (the usual case)
// or a big enough initial fragment that we can read the
// whole subprotocol header.
proto := b[9]
switch proto {
case 1: // ICMPv4
if len(sub) < 8 {
q.IPProto = Junk
return
}
q.IPProto = ICMP
q.SrcPort = 0
q.DstPort = 0
q.b = b
return
case 6: // TCP
if len(sub) < 20 {
q.IPProto = Junk
return
}
q.IPProto = TCP
q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
q.DstPort = binary.BigEndian.Uint16(sub[2:4])
q.TCPFlags = sub[13] & 0x3F
q.b = b
return
case 17: // UDP
if len(sub) < 8 {
q.IPProto = Junk
return
}
q.IPProto = UDP
q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
q.DstPort = binary.BigEndian.Uint16(sub[2:4])
q.b = b
return
default:
q.IPProto = Junk
return
}
} else {
// This is a fragment other than the first one.
if fragOfs < MIN_FRAG {
// First frag was suspiciously short, so we can't
// trust the followup either.
q.IPProto = Junk
return
}
// otherwise, we have to permit the fragment to slide through.
// Second and later fragments don't have sub-headers.
// Ideally, we would drop fragments that we can't identify,
// but that would require statefulness. Anyway, receivers'
// kernels know to drop fragments where the initial fragment
// doesn't arrive.
q.IPProto = Fragment
return
}
}
// Returns a subset of the IP subprotocol section.
func (q *QDecode) Sub(begin, n int) []byte {
return q.b[q.subofs+begin : q.subofs+begin+n]
}
// For a packet that is known to be IPv4, trim the buffer to its IPv4 length.
// Sometimes packets arrive from an interface with extra bytes on the end.
// This removes them.
func (q *QDecode) Trim() []byte {
n := binary.BigEndian.Uint16(q.b[2:4])
return q.b[0:n]
}
// For a decoded TCP packet, return true if it's a TCP SYN packet (ie. the
// first packet in a new connection).
func (q *QDecode) IsTCPSyn() bool {
const Syn = 0x02
const Ack = 0x10
const SynAck = Syn | Ack
return (q.TCPFlags & SynAck) == Syn
}
// For a packet that has already been decoded, check if it's an IPv4 ICMP
// Echo Request.
func (q *QDecode) IsEchoRequest() bool {
if q.IPProto == ICMP && len(q.b) >= q.subofs+8 {
return q.b[q.subofs] == EchoRequest && q.b[q.subofs+1] == 0
}
return false
}
func (q *QDecode) EchoRespond() []byte {
icmpid := binary.BigEndian.Uint16(q.Sub(4, 2))
b := q.Trim()
return GenICMP(q.DstIP, q.SrcIP, icmpid, EchoReply, 0, b[q.subofs+4:])
}
func Hexdump(b []byte) string {
out := new(strings.Builder)
for i := 0; i < len(b); i += 16 {
if i > 0 {
fmt.Fprintf(out, "\n")
}
fmt.Fprintf(out, " %04x ", i)
j := 0
for ; j < 16 && i+j < len(b); j++ {
if j == 8 {
fmt.Fprintf(out, " ")
}
fmt.Fprintf(out, "%02x ", b[i+j])
}
for ; j < 16; j++ {
if j == 8 {
fmt.Fprintf(out, " ")
}
fmt.Fprintf(out, " ")
}
fmt.Fprintf(out, " ")
for j = 0; j < 16 && i+j < len(b); j++ {
if b[i+j] >= 32 && b[i+j] < 128 {
fmt.Fprintf(out, "%c", b[i+j])
} else {
fmt.Fprintf(out, ".")
}
}
}
return out.String()
}

36
wgengine/router_darwin.go Normal file
View File

@@ -0,0 +1,36 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"tailscale.com/logger"
)
type darwinRouter struct {
tunname string
}
func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
r := darwinRouter{
tunname: tunname,
}
return &r
}
func (r *darwinRouter) Up() error {
return nil
}
func (r *darwinRouter) SetRoutes(rs RouteSettings) error {
if SetRoutesFunc != nil {
return SetRoutesFunc(rs)
}
return nil
}
func (r *darwinRouter) Close() {
}

View File

@@ -0,0 +1,17 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows,!linux,!darwin
package wgengine
import (
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"tailscale.com/logger"
)
func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
return NewFakeRouter(logf, tunname, dev, tuntap)
}

38
wgengine/router_fake.go Normal file
View File

@@ -0,0 +1,38 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"tailscale.com/logger"
)
type fakeRouter struct {
tunname string
logf logger.Logf
}
func NewFakeRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
r := fakeRouter{
logf: logf,
tunname: tunname,
}
return &r
}
func (r *fakeRouter) Up() error {
r.logf("Warning: fakeRouter.Up: not implemented.\n")
return nil
}
func (r *fakeRouter) SetRoutes(rs RouteSettings) error {
r.logf("Warning: fakeRouter.SetRoutes: not implemented.\n")
return nil
}
func (r *fakeRouter) Close() {
r.logf("Warning: fakeRouter.Close: not implemented.\n")
}

267
wgengine/router_linux.go Normal file
View File

@@ -0,0 +1,267 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"net"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"github.com/tailscale/wireguard-go/wgcfg"
"tailscale.com/atomicfile"
"tailscale.com/logger"
"tailscale.com/wgengine/rtnlmon"
)
type linuxRouter struct {
logf func(fmt string, args ...interface{})
tunname string
mon *rtnlmon.Mon
netChanged func()
local wgcfg.CIDR
routes map[wgcfg.CIDR]struct{}
}
func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
mon, err := rtnlmon.New(logf, netChanged)
if err != nil {
log.Fatalf("rtnlmon.New() failed: %v", err)
}
r := linuxRouter{
logf: logf,
tunname: tunname,
mon: mon,
netChanged: netChanged,
}
return &r
}
func cmd(args ...string) *exec.Cmd {
if len(args) == 0 {
log.Fatalf("exec.Cmd(%#v) invalid; need argv[0]\n", args)
}
return exec.Command(args[0], args[1:]...)
}
func (r *linuxRouter) Up() error {
out, err := cmd("ip", "link", "set", r.tunname, "up").CombinedOutput()
if err != nil {
log.Fatalf("running ip link failed: %v\n%s", err, out)
}
// TODO(apenwarr): This never cleans up after itself!
out, err = cmd("iptables",
"-A", "FORWARD",
"-i", r.tunname,
"-j", "ACCEPT").CombinedOutput()
if err != nil {
r.logf("iptables forward failed: %v\n%s", err, out)
}
// TODO(apenwarr): hardcoded eth0 interface is obviously not right.
out, err = cmd("iptables",
"-t", "nat",
"-A", "POSTROUTING",
"-o", "eth0",
"-j", "MASQUERADE").CombinedOutput()
if err != nil {
r.logf("iptables nat failed: %v\n%s", err, out)
}
return nil
}
func (r *linuxRouter) SetRoutes(rs RouteSettings) error {
var errq error
if rs.LocalAddr != r.local {
if r.local != (wgcfg.CIDR{}) {
addrdel := []string{"ip", "addr",
"del", r.local.String(),
"dev", r.tunname}
out, err := cmd(addrdel...).CombinedOutput()
if err != nil {
r.logf("addr del failed: %v: %v\n%s", addrdel, err, out)
if errq == nil {
errq = err
}
}
}
addradd := []string{"ip", "addr",
"add", rs.LocalAddr.String(),
"dev", r.tunname}
out, err := cmd(addradd...).CombinedOutput()
if err != nil {
r.logf("addr add failed: %v: %v\n%s", addradd, err, out)
if errq == nil {
errq = err
}
}
}
newRoutes := make(map[wgcfg.CIDR]struct{})
for _, peer := range rs.Cfg.Peers {
for _, route := range peer.AllowedIPs {
newRoutes[route] = struct{}{}
}
}
for route := range r.routes {
if _, keep := newRoutes[route]; !keep {
net := route.IPNet()
nip := net.IP.Mask(net.Mask)
nstr := fmt.Sprintf("%v/%d", nip, route.Mask)
addrdel := []string{"ip", "route",
"del", nstr,
"via", r.local.IP.String(),
"dev", r.tunname}
out, err := cmd(addrdel...).CombinedOutput()
if err != nil {
r.logf("addr del failed: %v: %v\n%s", addrdel, err, out)
if errq == nil {
errq = err
}
}
}
}
for route := range newRoutes {
if _, exists := r.routes[route]; !exists {
net := route.IPNet()
nip := net.IP.Mask(net.Mask)
nstr := fmt.Sprintf("%v/%d", nip, route.Mask)
addradd := []string{"ip", "route",
"add", nstr,
"via", rs.LocalAddr.IP.String(),
"dev", r.tunname}
out, err := cmd(addradd...).CombinedOutput()
if err != nil {
r.logf("addr add failed: %v: %v\n%s", addradd, err, out)
if errq == nil {
errq = err
}
}
}
}
r.local = rs.LocalAddr
r.routes = newRoutes
if false {
if err := r.replaceResolvConf(rs.DNS, rs.DNSDomains); err != nil {
errq = fmt.Errorf("replacing resolv.conf failed: %v", err)
}
}
return errq
}
func (r *linuxRouter) Close() {
r.mon.Close()
if err := r.restoreResolvConf(); err != nil {
r.logf("failed to restore system resolv.conf: %v", err)
}
// TODO(apenwarr): clean up iptables etc.
}
const (
tsConf = "/etc/resolv.tailscale.conf"
backupConf = "/etc/resolv.pre-tailscale-backup.conf"
resolvConf = "/etc/resolv.conf"
)
func (r *linuxRouter) replaceResolvConf(servers []net.IP, domains []string) error {
if len(servers) == 0 {
return r.restoreResolvConf()
}
// First write the tsConf file.
buf := new(bytes.Buffer)
fmt.Fprintf(buf, "# resolv.conf(5) file generated by tailscale\n")
fmt.Fprintf(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n")
for _, ns := range servers {
fmt.Fprintf(buf, "nameserver %s\n", ns)
}
if len(domains) > 0 {
fmt.Fprintf(buf, "search "+strings.Join(domains, " ")+"\n")
}
f, err := ioutil.TempFile(filepath.Dir(tsConf), filepath.Base(tsConf)+".*")
if err != nil {
return err
}
f.Close()
if err := atomicfile.WriteFile(f.Name(), buf.Bytes(), 0644); err != nil {
return err
}
os.Chmod(f.Name(), 0644) // ioutil.TempFile creates the file with 0600
if err := os.Rename(f.Name(), tsConf); err != nil {
return err
}
if linkPath, err := os.Readlink(resolvConf); err != nil {
// Remove any old backup that may exist.
os.Remove(backupConf)
// Backup the existing /etc/resolv.conf file.
contents, err := ioutil.ReadFile(resolvConf)
if os.IsNotExist(err) {
// No existing /etc/resolve.conf file to backup.
// Nothing to do.
return nil
} else if err != nil {
return err
}
if err := atomicfile.WriteFile(backupConf, contents, 0644); err != nil {
return err
}
} else if linkPath != tsConf {
// Backup the existing symlink.
os.Remove(backupConf)
if err := os.Symlink(linkPath, backupConf); err != nil {
return err
}
} else {
// Nothing to do, resolvConf already points to tsConf.
return nil
}
os.Remove(resolvConf)
if err := os.Symlink(tsConf, resolvConf); err != nil {
return nil
}
out, _ := exec.Command("service", "systemd-resolved", "restart").CombinedOutput()
if len(out) > 0 {
r.logf("service systemd-resolved restart: %s", out)
}
return nil
}
func (r *linuxRouter) restoreResolvConf() error {
if _, err := os.Stat(backupConf); err != nil {
if os.IsNotExist(err) {
return nil // no backup resolve.conf to restore
}
return err
}
if ln, err := os.Readlink(resolvConf); err != nil {
return err
} else if ln != tsConf {
return fmt.Errorf("resolve.conf is not a symlink to %s", tsConf)
}
if err := os.Rename(backupConf, resolvConf); err != nil {
return err
}
os.Remove(tsConf) // best effort removal of tsConf file
out, _ := exec.Command("service", "systemd-resolved", "restart").CombinedOutput()
if len(out) > 0 {
r.logf("service systemd-resolved restart: %s", out)
}
return nil
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"log"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"golang.zx2c4.com/winipcfg"
"tailscale.com/logger"
)
type winRouter struct {
logf func(fmt string, args ...interface{})
tunname string
dev *device.Device
nativeTun *tun.NativeTun
routeChangeCallback *winipcfg.RouteChangeCallback
}
func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
r := winRouter{
logf: logf,
tunname: tunname,
dev: dev,
nativeTun: tuntap.(*tun.NativeTun),
}
return &r
}
func (r *winRouter) Up() error {
// MonitorDefaultRoutes handles making sure our wireguard UDP
// traffic goes through the old route, not recursively through the VPN.
var err error
r.routeChangeCallback, err = MonitorDefaultRoutes(r.dev, true, r.nativeTun)
if err != nil {
log.Fatalf("MonitorDefaultRoutes: %v\n", err)
}
return nil
}
func (r *winRouter) SetRoutes(rs RouteSettings) error {
err := ConfigureInterface(&rs.Cfg, r.nativeTun, rs.DNS, rs.DNSDomains)
if err != nil {
r.logf("ConfigureInterface: %v\n", err)
return err
}
return nil
}
func (r *winRouter) Close() {
if r.routeChangeCallback != nil {
r.routeChangeCallback.Unregister()
}
}

114
wgengine/rtnlmon/mon.go Normal file
View File

@@ -0,0 +1,114 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rtnlmon watches for "interesting" changes to the network
// stack and fires a callback.
package rtnlmon
import (
"fmt"
"time"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
"tailscale.com/logger"
)
// Netlink is not a great protocol for *knowing* things. The protocol
// design makes it impossible to track changes precisely. You can see
// this by looking at things like Quagga or Bird, which all include
// keeping a local impression of what they think is in the kernel, and
// periodically doing a full state dump to find errors. They do use
// events, but explicitly only as an optimization, because they can't
// be trusted.
//
// Fortunately, we don't really need to know what exactly changed. We
// just want to know that network conditions may have changed, and we
// should re-explore connectivity. This is why we subscribe to events,
// and then blindly fire our callback without looking at the content
// of the notifications.
type ChangeFunc func()
type Mon struct {
logf logger.Logf
cb ChangeFunc
nl *netlink.Conn
change chan struct{}
stop chan struct{}
}
func New(logf logger.Logf, callback ChangeFunc) (*Mon, error) {
conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
// IPv4 address and route changes. Routes get us most of the
// events of interest, but we need address as well to cover
// things like DHCP deciding to give us a new address upon
// renewal - routing wouldn't change, but all reachability
// would.
//
// Why magic numbers? These aren't exposed in x/sys/unix
// yet. The values come from rtnetlink.h, RTMGRP_IPV4_IFADDR
// and RTMGRP_IPV4_ROUTE.
Groups: 0x10 | 0x40,
})
if err != nil {
return nil, fmt.Errorf("dialing netlink socket: %v", err)
}
ret := &Mon{
logf: logf,
cb: callback,
nl: conn,
change: make(chan struct{}, 1),
stop: make(chan struct{}),
}
go ret.pump()
go ret.debounce()
return ret, nil
}
func (m *Mon) Close() error {
close(m.stop)
return m.nl.Close()
}
func (m *Mon) pump() {
for {
_, err := m.nl.Receive()
if err != nil {
select {
case <-m.stop:
return
default:
}
// Keep retrying while we're not closed.
m.logf("Error receiving from netlink: %v", err)
time.Sleep(time.Second)
continue
}
select {
case m.change <- struct{}{}:
default:
}
}
}
func (m *Mon) debounce() {
for {
select {
case <-m.stop:
return
case <-m.change:
}
m.cb()
select {
case <-m.stop:
return
case <-time.After(100 * time.Millisecond):
}
}
}

21
wgengine/rusage.go Normal file
View File

@@ -0,0 +1,21 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"fmt"
"runtime"
)
func RusagePrefixLog(logf func(f string, argv ...interface{})) func(f string, argv ...interface{}) {
return func(f string, argv ...interface{}) {
var m runtime.MemStats
runtime.ReadMemStats(&m)
goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20)
maxRSS := rusageMaxRSS()
pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f)
logf(pf, argv...)
}
}

View File

@@ -0,0 +1,29 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package wgengine
import (
"runtime"
"syscall"
)
func rusageMaxRSS() float64 {
var ru syscall.Rusage
err := syscall.Getrusage(syscall.RUSAGE_SELF, &ru)
if err != nil {
return 0
}
rss := float64(ru.Maxrss)
if runtime.GOOS == "darwin" {
rss /= 1 << 20 // ru_maxrss is bytes on darwin
} else {
// ru_maxrss is kilobytes elsewhere (linux, openbsd, etc)
rss /= 1024
}
return rss
}

View File

@@ -0,0 +1,10 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
func rusageMaxRSS() float64 {
// TODO(apenwarr): Substitute Windows equivalent of Getrusage() here.
return 0
}

477
wgengine/userspace.go Normal file
View File

@@ -0,0 +1,477 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"bufio"
"fmt"
"log"
"strconv"
"strings"
"sync"
"time"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"github.com/tailscale/wireguard-go/wgcfg"
"tailscale.com/logger"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/magicsock"
"tailscale.com/wgengine/packet"
)
type userspaceEngine struct {
logf logger.Logf
statusCallback StatusCallback
reqCh chan struct{}
waitCh chan struct{}
tuntap tun.Device
wgdev *device.Device
router Router
magicConn *magicsock.Conn
wgLock sync.Mutex // serializes all wgdev operations
lastReconfig string
lastRoutes string
mu sync.Mutex
peerSequence []wgcfg.Key
endpoints []string
}
type Loggify struct {
f logger.Logf
}
func (l *Loggify) Write(b []byte) (int, error) {
l.f(string(b))
return len(b), nil
}
func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16, derp bool) (Engine, error) {
logf("Starting userspace wireguard engine (FAKE tuntap device).")
tun := NewFakeTun()
return NewUserspaceEngineAdvanced(logf, tun, NewFakeRouter, listenPort, derp)
}
func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16, derp bool) (Engine, error) {
logf("Starting userspace wireguard engine.")
logf("external packet routing via --tun=%s enabled", tunname)
if tunname == "" {
return nil, fmt.Errorf("--tun name must not be blank.")
}
tuntap, err := tun.CreateTUN(tunname, device.DefaultMTU)
if err != nil {
log.Printf("CreateTUN: %v\n", err)
return nil, err
}
log.Printf("CreateTUN ok.\n")
e, err := NewUserspaceEngineAdvanced(logf, tuntap, NewUserspaceRouter, listenPort, derp)
if err != nil {
log.Printf("NewUserspaceEngineAdv: %v\n", err)
return nil, err
}
return e, err
}
type RouterGen func(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netStateChanged func()) Router
func NewUserspaceEngineAdvanced(logf logger.Logf, tuntap tun.Device, routerGen RouterGen, listenPort uint16, derp bool) (Engine, error) {
e := &userspaceEngine{
logf: logf,
reqCh: make(chan struct{}, 1),
waitCh: make(chan struct{}),
tuntap: tuntap,
}
tunname, err := tuntap.Name()
if err != nil {
return nil, err
}
endpointsFn := func(endpoints []string) {
e.mu.Lock()
if e.endpoints != nil {
e.endpoints = e.endpoints[:0]
}
e.endpoints = append(e.endpoints, endpoints...)
e.mu.Unlock()
e.RequestStatus()
}
magicsockOpts := magicsock.Options{
Port: listenPort,
STUN: magicsock.DefaultSTUN,
// TODO(crawshaw): DERP: magicsock.DefaultDERP,
EndpointsFunc: endpointsFn,
}
if derp {
magicsockOpts.DERP = magicsock.DefaultDERP
}
e.magicConn, err = magicsock.Listen(magicsockOpts)
if err != nil {
return nil, fmt.Errorf("wgengine: %v", err)
}
// flags==0 because logf is already nested in another logger.
// The outer one can display the preferred log prefixes, etc.
dlog := log.New(&Loggify{logf}, "", 0)
logger := device.Logger{
Debug: dlog,
Info: dlog,
Error: dlog,
}
nofilter := func(b []byte) device.FilterResult {
// for safety, default to dropping all packets
logf("Warning: you forgot to use wgengine.SetFilterInOut()! Packet dropped.\n")
return device.FilterDrop
}
opts := &device.DeviceOptions{
Logger: &logger,
FilterIn: nofilter,
FilterOut: nofilter,
HandshakeDone: func() {
// Send an unsolicited status event every time a
// handshake completes. This makes sure our UI can
// update quickly as soon as it connects to a peer.
//
// We use a goroutine here to avoid deadlocking
// wireguard, since RequestStatus() will call back
// into it, and wireguard is what called us to get
// here.
go e.RequestStatus()
},
CreateBind: func(uint16) (device.Bind, uint16, error) {
return e.magicConn, e.magicConn.LocalPort(), nil
},
CreateEndpoint: e.magicConn.CreateEndpoint,
SkipBindUpdate: true,
}
e.wgdev = device.NewDevice(e.tuntap, opts)
go func() {
up := false
for event := range e.tuntap.Events() {
if event&tun.EventMTUUpdate != 0 {
mtu, err := e.tuntap.MTU()
e.logf("external route MTU: %d (%v)", mtu, err)
}
if event&tun.EventUp != 0 && !up {
e.logf("external route: up")
e.RequestStatus()
up = true
}
if event&tun.EventDown != 0 && up {
e.logf("external route: down")
e.RequestStatus()
up = false
}
}
}()
e.router = routerGen(logf, tunname, e.wgdev, e.tuntap, func() { e.LinkChange(false) })
e.wgdev.Up()
if err := e.router.Up(); err != nil {
e.wgdev.Close()
return nil, err
}
if err := e.router.SetRoutes(RouteSettings{}); err != nil {
e.wgdev.Close()
return nil, err
}
return e, nil
}
// TODO(apenwarr): dnsDomains really ought to be in wgcfg.Config.
// However, we don't actually ever provide it to wireguard and it's not in
// the traditional wireguard config format. On the other hand, wireguard
// itself doesn't use the traditional 'dns =' setting either.
func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, dnsDomains []string) error {
e.logf("Reconfig(): configuring userspace wireguard engine.\n")
e.wgLock.Lock()
defer e.wgLock.Unlock()
e.peerSequence = make([]wgcfg.Key, len(cfg.Peers))
for i, p := range cfg.Peers {
e.peerSequence[i] = p.PublicKey
}
// TODO(apenwarr): get rid of silly uapi stuff for in-process comms
uapi, err := cfg.ToUAPI()
if err != nil {
return err
}
rc := uapi + "\x00" + strings.Join(dnsDomains, "\x00")
if rc == e.lastReconfig {
e.logf("...unchanged config, skipping.\n")
return nil
}
e.lastReconfig = rc
r := bufio.NewReader(strings.NewReader(uapi))
if err = e.wgdev.IpcSetOperation(r); err != nil {
e.logf("IpcSetOperation: %v\n", err)
return err
}
if err := e.magicConn.SetPrivateKey(cfg.Interface.PrivateKey); err != nil {
e.logf("magicsock: %v\n", err)
}
// TODO(apenwarr): only handling the first local address.
// Currently we never use more than one anyway.
var cidr wgcfg.CIDR
if len(cfg.Interface.Addresses) > 0 {
cidr = cfg.Interface.Addresses[0]
// TODO(apenwarr): this shouldn't be hardcoded in the client
cidr.Mask = 10 // route the whole cgnat range
}
rs := RouteSettings{
LocalAddr: cidr,
Cfg: *cfg,
DNS: cfg.Interface.Dns,
DNSDomains: dnsDomains,
}
e.logf("Reconfiguring router. la=%v dns=%v dom=%v\n",
rs.LocalAddr, rs.DNS, rs.DNSDomains)
// TODO(apenwarr): all the parts of RouteSettings should be "relevant."
// We're checking only the "relevant" parts to see if they have
// changed, and if not, skipping SetRoutes(). But if SetRoutes()
// is getting the non-relevant parts of Cfg, it might act on them,
// and this optimization is unsafe. Probably we should not pass
// a whole Cfg object as part of RouteSettings; instead, trim it to
// just what's absolutely needed (the set of actual routes).
rss := rs.OnlyRelevantParts()
e.logf("New routes: %v\n", rss)
if rss == e.lastRoutes {
e.logf("...unchanged routes, skipping.\n")
return nil
}
e.lastRoutes = rss
err = e.router.SetRoutes(rs)
e.logf("Reconfig() done.\n")
return err
}
func (e *userspaceEngine) SetFilter(filt *filter.Filter) {
var filtin, filtout func(b []byte) device.FilterResult
if filt == nil {
e.logf("wgengine: nil filter provided; no access restrictions.\n")
} else {
ft, ft_ok := e.tuntap.(*fakeTun)
filtin = func(b []byte) device.FilterResult {
runf := filter.LogDrops
//runf |= filter.HexdumpDrops
runf |= filter.LogAccepts
//runf |= filter.HexdumpAccepts
q := &packet.QDecode{}
if filt.RunIn(b, q, runf) == filter.Accept {
// Only in fake mode, answer any incoming pings
if ft_ok && q.IsEchoRequest() {
pb := q.EchoRespond()
ft.InsertRead(pb)
// We already handled it, stop.
return device.FilterDrop
}
return device.FilterAccept
}
return device.FilterDrop
}
filtout = func(b []byte) device.FilterResult {
runf := filter.LogDrops
//runf |= filter.HexdumpDrops
runf |= filter.LogAccepts
//runf |= filter.HexdumpAccepts
q := &packet.QDecode{}
if filt.RunOut(b, q, runf) == filter.Accept {
return device.FilterAccept
}
return device.FilterDrop
}
}
e.wgLock.Lock()
defer e.wgLock.Unlock()
e.wgdev.SetFilterInOut(filtin, filtout)
}
func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
e.statusCallback = cb
}
func (e *userspaceEngine) getStatus() (*Status, error) {
e.wgLock.Lock()
defer e.wgLock.Unlock()
if e.wgdev == nil {
// RequestStatus was invoked before the wgengine has
// finished initializing. This can happen when wgegine
// provides a callback to magicsock for endpoint
// updates that calls RequestStatus.
return nil, nil
}
// TODO(apenwarr): get rid of silly uapi stuff for in-process comms
// FIXME: get notified of status changes instead of polling.
var bb strings.Builder
bio := bufio.NewWriter(&bb)
ipcErr := e.wgdev.IpcGetOperation(bio)
if ipcErr != nil {
log.Fatalf("IpcGetOperation: %v\n", ipcErr)
}
bio.Flush()
s := Status{}
pp := make(map[wgcfg.Key]*PeerStatus)
var p *PeerStatus = &PeerStatus{}
bbs := bb.String()
lines := strings.Split(bbs, "\n")
var hst1, hst2, n int64
var err error
for _, line := range lines {
kv := strings.SplitN(line, "=", 2)
var k, v string
k = kv[0]
if len(kv) > 1 {
v = kv[1]
}
switch k {
case "public_key":
pk, err := wgcfg.ParseHexKey(v)
if err != nil {
log.Fatalf("IpcGetOperation: invalid key %#v\n", v)
}
p = &PeerStatus{}
pp[*pk] = p
key := tailcfg.NodeKey(*pk)
p.NodeKey = key
case "rx_bytes":
n, err = strconv.ParseInt(v, 10, 64)
p.RxBytes = ByteCount(n)
if err != nil {
log.Fatalf("IpcGetOperation: rx_bytes invalid: %#v\n", line)
}
case "tx_bytes":
n, err = strconv.ParseInt(v, 10, 64)
p.TxBytes = ByteCount(n)
if err != nil {
log.Fatalf("IpcGetOperation: tx_bytes invalid: %#v\n", line)
}
case "last_handshake_time_sec":
hst1, err = strconv.ParseInt(v, 10, 64)
if err != nil {
log.Fatalf("IpcGetOperation: hst1 invalid: %#v\n", line)
}
case "last_handshake_time_nsec":
hst2, err = strconv.ParseInt(v, 10, 64)
if err != nil {
log.Fatalf("IpcGetOperation: hst2 invalid: %#v\n", line)
}
if hst1 != 0 || hst2 != 0 {
p.LastHandshake = time.Unix(hst1, hst2)
} // else leave at time.IsZero()
}
}
e.mu.Lock()
defer e.mu.Unlock()
var peers []PeerStatus
for _, pk := range e.peerSequence {
p := pp[pk]
if p == nil {
p = &PeerStatus{}
}
peers = append(peers, *p)
}
if len(pp) != len(e.peerSequence) {
e.logf("wg status returned %v peers, expected %v\n", len(s.Peers), len(e.peerSequence))
}
return &Status{
LocalAddrs: append([]string(nil), e.endpoints...),
Peers: peers,
}, nil
}
func (e *userspaceEngine) RequestStatus() {
// This is slightly tricky. e.getStatus() can theoretically get
// blocked inside wireguard for a while, and RequestStatus() is
// sometimes called from a goroutine, so we don't want a lot of
// them hanging around. On the other hand, requesting multiple
// status updates simultaneously is pointless anyway; they will
// all say the same thing.
// Enqueue at most one request. If one is in progress already, this
// adds one more to the queue. If one has been requested but not
// started, it is a no-op.
select {
case e.reqCh <- struct{}{}:
default:
}
// Dequeue at most one request. Another thread may have already
// dequeued the request we enqueued above, which is fine, since the
// information is guaranteed to be at least as recent as the current
// call to RequestStatus().
select {
case <-e.reqCh:
s, err := e.getStatus()
if s == nil && err == nil {
e.logf("RequestStatus: weird: both s and err are nil\n")
return
}
if e.statusCallback != nil {
e.statusCallback(s, err)
}
default:
}
}
func (e *userspaceEngine) Close() {
e.Reconfig(&wgcfg.Config{}, nil)
e.router.Close()
e.magicConn.Close()
close(e.waitCh)
}
func (e *userspaceEngine) Wait() {
<-e.waitCh
}
func (e *userspaceEngine) LinkChange(isExpensive bool) {
e.logf("LinkChange(isExpensive=%v): rebinding socket", isExpensive)
e.wgLock.Lock()
defer e.wgLock.Unlock()
// TODO(crawshaw): use isExpensive=true to switch into "client mode" on macOS?
e.magicConn.LinkChange()
// TODO(crawshaw): when we have an incremental notion of reconfig,
// be gentler here. No need to smash in-progress connections,
// we just need to handshake again.
if e.lastReconfig == "" {
return
}
uapi := e.lastReconfig[:strings.Index(e.lastReconfig, "\x00")]
r := bufio.NewReader(strings.NewReader(uapi))
if err := e.wgdev.IpcSetOperation(r); err != nil {
e.logf("IpcSetOperation: %v\n", err)
}
}

83
wgengine/watchdog.go Normal file
View File

@@ -0,0 +1,83 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"bytes"
"log"
"runtime/pprof"
"time"
"github.com/tailscale/wireguard-go/wgcfg"
"tailscale.com/wgengine/filter"
)
// NewWatchdog wraps an Engine and makes sure that all methods complete
// within a reasonable amount of time.
//
// If they do not, the watchdog crashes the process.
func NewWatchdog(e Engine) Engine {
return &watchdogEngine{
wrap: e,
logf: log.Printf,
fatalf: log.Fatalf,
maxWait: 45 * time.Second,
}
}
type watchdogEngine struct {
wrap Engine
logf func(format string, args ...interface{})
fatalf func(format string, args ...interface{})
maxWait time.Duration
}
func (e *watchdogEngine) watchdogErr(name string, fn func() error) error {
errCh := make(chan error)
go func() {
errCh <- fn()
}()
t := time.NewTimer(e.maxWait)
select {
case err := <-errCh:
t.Stop()
return err
case <-t.C:
buf := new(bytes.Buffer)
pprof.Lookup("goroutine").WriteTo(buf, 1)
e.logf("wgengine watchdog stacks:\n%s", buf.String())
e.fatalf("wgengine: watchdog timeout on %s", name)
return nil
}
}
func (e *watchdogEngine) watchdog(name string, fn func()) {
e.watchdogErr(name, func() error {
fn()
return nil
})
}
func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, dnsDomains []string) error {
return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, dnsDomains) })
}
func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
}
func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
}
func (e *watchdogEngine) RequestStatus() {
e.watchdog("RequestStatus", func() { e.wrap.RequestStatus() })
}
func (e *watchdogEngine) LinkChange(isExpensive bool) {
e.watchdog("LinkChange", func() { e.wrap.LinkChange(isExpensive) })
}
func (e *watchdogEngine) Close() {
e.watchdog("Close", e.wrap.Close)
}
func (e *watchdogEngine) Wait() {
e.wrap.Wait()
}

71
wgengine/watchdog_test.go Normal file
View File

@@ -0,0 +1,71 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"bytes"
"fmt"
"strings"
"testing"
"time"
)
func TestWatchdog(t *testing.T) {
t.Parallel()
t.Run("default watchdog does not fire", func(t *testing.T) {
t.Parallel()
tun := NewFakeTun()
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, NewFakeRouter, 0, false)
if err != nil {
t.Fatal(err)
}
e = NewWatchdog(e)
e.(*watchdogEngine).maxWait = 150 * time.Millisecond
e.RequestStatus()
e.RequestStatus()
e.RequestStatus()
e.Close()
})
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
t.Parallel()
tun := NewFakeTun()
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, NewFakeRouter, 0, false)
if err != nil {
t.Fatal(err)
}
usEngine := e.(*userspaceEngine)
e = NewWatchdog(e)
wdEngine := e.(*watchdogEngine)
wdEngine.maxWait = 100 * time.Millisecond
logBuf := new(bytes.Buffer)
fatalCalled := make(chan struct{})
wdEngine.logf = func(format string, args ...interface{}) {
fmt.Fprintf(logBuf, format+"\n", args...)
}
wdEngine.fatalf = func(format string, args ...interface{}) {
t.Logf("FATAL: %s", fmt.Sprintf(format, args...))
fatalCalled <- struct{}{}
}
usEngine.wgLock.Lock() // blocks getStatus so the watchdog will fire
go e.RequestStatus()
select {
case <-fatalCalled:
if !strings.Contains(logBuf.String(), "goroutine profile: total ") {
t.Errorf("fatal called without watchdog stacks, got: %s", logBuf.String())
}
// expected
case <-time.After(3 * time.Second):
t.Fatalf("watchdog failed to fire")
}
})
}

79
wgengine/wgengine.go Normal file
View File

@@ -0,0 +1,79 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgengine
import (
"fmt"
"net"
"time"
"github.com/tailscale/wireguard-go/wgcfg"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/filter"
)
type ByteCount int64
type PeerStatus struct {
TxBytes, RxBytes ByteCount
LastHandshake time.Time
NodeKey tailcfg.NodeKey
}
type Status struct {
Peers []PeerStatus
LocalAddrs []string // TODO(crawshaw): []wgcfg.Endpoint?
}
type StatusCallback func(s *Status, err error)
type RouteSettings struct {
LocalAddr wgcfg.CIDR
DNS []net.IP
DNSDomains []string
Cfg wgcfg.Config
}
// Only used on darwin for now
// TODO(apenwarr): This probably belongs in the darwinRouter struct.
var SetRoutesFunc func(rs RouteSettings) error
func (rs *RouteSettings) OnlyRelevantParts() string {
var peers [][]wgcfg.CIDR
for _, p := range rs.Cfg.Peers {
peers = append(peers, p.AllowedIPs)
}
return fmt.Sprintf("%v %v %v %v",
rs.LocalAddr, rs.DNS, rs.DNSDomains, peers)
}
type Router interface {
Up() error
SetRoutes(rs RouteSettings) error
Close()
}
type Engine interface {
// Reconfigure wireguard and make sure it's running.
// This also handles setting up any kernel routes.
Reconfig(cfg *wgcfg.Config, dnsDomains []string) error
// Update the packet filter.
SetFilter(filt *filter.Filter)
// Set the function to call when wireguard status changes.
SetStatusCallback(cb StatusCallback)
// Request a wireguard status update right away, sent to the callback.
RequestStatus()
// Shut down this wireguard instance, remove any routes it added, etc.
// To bring it up again later, you'll need a new Engine.
Close()
// Wait until the Engine is .Close()ed or aborts with an error.
// You don't have to call this.
Wait()
// LinkChange informs the engine that the system network
// link has changed. The isExpensive parameter is set on links
// where sending packets uses substantial power or dollars
// (such as LTE on a phone).
LinkChange(isExpensive bool)
}

153
wgengine/winnet/winnet.go Normal file
View File

@@ -0,0 +1,153 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package winnet
import (
"fmt"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
"unsafe"
)
const CLSID_NetworkListManager = "{DCB00C01-570F-4A9B-8D69-199FDBA5723B}"
var IID_INetwork = ole.NewGUID("{8A40A45D-055C-4B62-ABD7-6D613E2CEAEC}")
var IID_INetworkConnection = ole.NewGUID("{DCB00005-570F-4A9B-8D69-199FDBA5723B}")
type NetworkListManager struct {
d *ole.Dispatch
}
type INetworkConnection struct {
ole.IDispatch
}
type ConnectionList []*INetworkConnection
type INetworkConnectionVtbl struct {
ole.IDispatchVtbl
GetNetwork uintptr
Get_IsConnectedToInternet uintptr
Get_IsConnected uintptr
GetConnectivity uintptr
GetConnectionId uintptr
GetAdapterId uintptr
GetDomainType uintptr
}
type INetwork struct {
ole.IDispatch
}
func NewNetworkListManager(c *ole.Connection) (*NetworkListManager, error) {
err := c.Create(CLSID_NetworkListManager)
if err != nil {
return nil, err
}
defer c.Release()
d, err := c.Dispatch()
if err != nil {
return nil, err
}
return &NetworkListManager{
d: d,
}, nil
}
func (m *NetworkListManager) Release() {
m.d.Release()
}
func (cl ConnectionList) Release() {
for _, v := range cl {
v.Release()
}
}
func asIID(u ole.UnknownLike, iid *ole.GUID) (*ole.IDispatch, error) {
if u == nil {
return nil, fmt.Errorf("asIID: nil UnknownLike")
}
d, err := u.QueryInterface(iid)
u.Release()
if err != nil {
return nil, err
}
return d, nil
}
func (m *NetworkListManager) GetNetworkConnections() (ConnectionList, error) {
ncraw, err := m.d.Call("GetNetworkConnections")
if err != nil {
return nil, err
}
nli := ncraw.ToIDispatch()
if nli == nil {
return nil, fmt.Errorf("GetNetworkConnections: not IDispatch")
}
cl := ConnectionList{}
err = oleutil.ForEach(nli, func(v *ole.VARIANT) error {
nc, err := asIID(v.ToIUnknown(), IID_INetworkConnection)
if err != nil {
return err
}
nco := (*INetworkConnection)(unsafe.Pointer(nc))
cl = append(cl, nco)
return nil
})
if err != nil {
cl.Release()
return nil, err
}
return cl, nil
}
func (n *INetwork) GetName() (string, error) {
v, err := n.CallMethod("GetName")
if err != nil {
return "", err
}
return v.ToString(), err
}
func (n *INetwork) GetCategory() (int32, error) {
v, err := n.CallMethod("GetCategory")
if err != nil {
return 0, err
}
return v.Value().(int32), err
}
func (n *INetwork) SetCategory(v uint32) error {
_, err := n.CallMethod("SetCategory", v)
return err
}
func (v *INetworkConnection) VTable() *INetworkConnectionVtbl {
return (*INetworkConnectionVtbl)(unsafe.Pointer(v.RawVTable))
}
func (v *INetworkConnection) GetNetwork() (*INetwork, error) {
nraw, err := v.CallMethod("GetNetwork")
if err != nil {
return nil, err
}
n := nraw.ToIDispatch()
if n == nil {
return nil, fmt.Errorf("GetNetwork: nil IDispatch")
}
if err != nil {
return nil, err
}
return (*INetwork)(unsafe.Pointer(n)), nil
}

View File

@@ -0,0 +1,26 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package winnet
import (
"fmt"
"github.com/go-ole/go-ole"
"syscall"
"unsafe"
)
func (v *INetworkConnection) GetAdapterId() (string, error) {
buf := ole.GUID{}
hr, _, _ := syscall.Syscall(
v.VTable().GetAdapterId,
2,
uintptr(unsafe.Pointer(v)),
uintptr(unsafe.Pointer(&buf)),
0)
if hr != 0 {
return "", fmt.Errorf("GetAdapterId failed: %08x", hr)
}
return buf.String(), nil
}