mirror of
https://github.com/tailscale/tailscale.git
synced 2025-10-09 08:01:31 +00:00
Move Linux client & common packages into a public repo.
This commit is contained in:
55
wgengine/faketun.go
Normal file
55
wgengine/faketun.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
type fakeTun struct {
|
||||
datachan chan []byte
|
||||
evchan chan tun.Event
|
||||
closechan chan struct{}
|
||||
}
|
||||
|
||||
func NewFakeTun() tun.Device {
|
||||
return &fakeTun{
|
||||
datachan: make(chan []byte),
|
||||
evchan: make(chan tun.Event),
|
||||
closechan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *fakeTun) File() *os.File {
|
||||
panic("fakeTun.File() called, which makes no sense")
|
||||
}
|
||||
|
||||
func (t *fakeTun) Close() error {
|
||||
close(t.closechan)
|
||||
close(t.datachan)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *fakeTun) InsertRead(b []byte) {
|
||||
t.datachan <- b
|
||||
}
|
||||
|
||||
func (t *fakeTun) Read(out []byte, offset int) (int, error) {
|
||||
select {
|
||||
case <-t.closechan:
|
||||
return 0, io.EOF
|
||||
case b := <-t.datachan:
|
||||
copy(out[offset:offset+len(b)], b)
|
||||
return len(b), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *fakeTun) Write(b []byte, n int) (int, error) { return len(b), nil }
|
||||
func (t *fakeTun) Flush() error { return nil }
|
||||
func (t *fakeTun) MTU() (int, error) { return 1500, nil }
|
||||
func (t *fakeTun) Name() (string, error) { return "FakeTun", nil }
|
||||
func (t *fakeTun) Events() chan tun.Event { return t.evchan }
|
218
wgengine/filter/filter.go
Normal file
218
wgengine/filter/filter.go
Normal file
@@ -0,0 +1,218 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/groupcache/lru"
|
||||
"tailscale.com/ratelimit"
|
||||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
type Filter struct {
|
||||
matches Matches
|
||||
|
||||
udpMu sync.Mutex
|
||||
udplru *lru.Cache
|
||||
}
|
||||
|
||||
type Response int
|
||||
|
||||
const (
|
||||
Drop Response = iota
|
||||
Accept
|
||||
noVerdict // Returned from subfilters to continue processing.
|
||||
)
|
||||
|
||||
func (r Response) String() string {
|
||||
switch r {
|
||||
case Drop:
|
||||
return "Drop"
|
||||
case Accept:
|
||||
return "Accept"
|
||||
case noVerdict:
|
||||
return "noVerdict"
|
||||
default:
|
||||
return "???"
|
||||
}
|
||||
}
|
||||
|
||||
type RunFlags int
|
||||
|
||||
const (
|
||||
LogDrops RunFlags = 1 << iota
|
||||
LogAccepts
|
||||
HexdumpDrops
|
||||
HexdumpAccepts
|
||||
)
|
||||
|
||||
type tuple struct {
|
||||
SrcIP IP
|
||||
DstIP IP
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
const LRU_MAX = 512 // max entries in UDP LRU cache
|
||||
|
||||
var MatchAllowAll = Matches{
|
||||
Match{[]IPPortRange{IPPortRangeAny}, []IP{IPAny}},
|
||||
}
|
||||
|
||||
func NewAllowAll() *Filter {
|
||||
return New(MatchAllowAll)
|
||||
}
|
||||
|
||||
func NewAllowNone() *Filter {
|
||||
return New(nil)
|
||||
}
|
||||
|
||||
func New(matches Matches) *Filter {
|
||||
f := &Filter{
|
||||
matches: matches,
|
||||
udplru: lru.New(LRU_MAX),
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func maybeHexdump(flag RunFlags, b []byte) string {
|
||||
if flag != 0 {
|
||||
return packet.Hexdump(b) + "\n"
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging?
|
||||
// Logging is a quick way to record every newly opened TCP connection, but
|
||||
// we have to be cautious about flooding the logs vs letting people use
|
||||
// flood protection to hide their traffic. We could use a rate limiter in
|
||||
// the actual *filter* for SYN accepts, perhaps.
|
||||
var acceptBucket = ratelimit.Bucket{
|
||||
Burst: 3,
|
||||
FillInterval: 10 * time.Second,
|
||||
}
|
||||
var dropBucket = ratelimit.Bucket{
|
||||
Burst: 10,
|
||||
FillInterval: 5 * time.Second,
|
||||
}
|
||||
|
||||
func logRateLimit(runflags RunFlags, b []byte, q *packet.QDecode, r Response, why string) {
|
||||
if r == Drop && (runflags&LogDrops) != 0 && dropBucket.TryGet() > 0 {
|
||||
var qs string
|
||||
if q == nil {
|
||||
qs = fmt.Sprintf("(%d bytes)", len(b))
|
||||
} else {
|
||||
qs = q.String()
|
||||
}
|
||||
log.Printf("Drop: %v %v %s\n%s", qs, len(b), why, maybeHexdump(runflags&HexdumpDrops, b))
|
||||
} else if r == Accept && (runflags&LogAccepts) != 0 && acceptBucket.TryGet() > 0 {
|
||||
log.Printf("Accept: %v %v %s\n%s", q, len(b), why, maybeHexdump(runflags&HexdumpAccepts, b))
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filter) RunIn(b []byte, q *packet.QDecode, rf RunFlags) Response {
|
||||
r := pre(b, q, rf)
|
||||
if r == Accept || r == Drop {
|
||||
// already logged
|
||||
return r
|
||||
}
|
||||
|
||||
r, why := f.runIn(q)
|
||||
logRateLimit(rf, b, q, r, why)
|
||||
return r
|
||||
}
|
||||
|
||||
func (f *Filter) RunOut(b []byte, q *packet.QDecode, rf RunFlags) Response {
|
||||
r := pre(b, q, rf)
|
||||
if r == Drop || r == Accept {
|
||||
// already logged
|
||||
return r
|
||||
}
|
||||
r, why := f.runOut(q)
|
||||
logRateLimit(rf, b, q, r, why)
|
||||
return r
|
||||
}
|
||||
|
||||
func (f *Filter) runIn(q *packet.QDecode) (r Response, why string) {
|
||||
switch q.IPProto {
|
||||
case packet.ICMP:
|
||||
// If any port is open to an IP, allow ICMP to it.
|
||||
if matchIPWithoutPorts(f.matches, q) {
|
||||
return Accept, "icmp ok"
|
||||
}
|
||||
case packet.TCP:
|
||||
// For TCP, we want to allow *outgoing* connections,
|
||||
// which means we want to allow return packets on those
|
||||
// connections. To make this restriction work, we need to
|
||||
// allow non-SYN packets (continuation of an existing session)
|
||||
// to arrive. This should be okay since a new incoming session
|
||||
// can't be initiated without first sending a SYN.
|
||||
// It happens to also be much faster.
|
||||
// TODO(apenwarr): Skip the rest of decoding in this path?
|
||||
if q.IPProto == packet.TCP && !q.IsTCPSyn() {
|
||||
return Accept, "tcp non-syn"
|
||||
}
|
||||
if matchIPPorts(f.matches, q) {
|
||||
return Accept, "tcp ok"
|
||||
}
|
||||
case packet.UDP:
|
||||
t := tuple{q.SrcIP, q.DstIP, q.SrcPort, q.DstPort}
|
||||
|
||||
f.udpMu.Lock()
|
||||
_, ok := f.udplru.Get(t)
|
||||
f.udpMu.Unlock()
|
||||
|
||||
if ok {
|
||||
return Accept, "udp cached"
|
||||
}
|
||||
if matchIPPorts(f.matches, q) {
|
||||
return Accept, "udp ok"
|
||||
}
|
||||
default:
|
||||
return Drop, "Unknown proto"
|
||||
}
|
||||
return Drop, "no rules matched"
|
||||
}
|
||||
|
||||
func (f *Filter) runOut(q *packet.QDecode) (r Response, why string) {
|
||||
if q.IPProto == packet.UDP {
|
||||
t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort}
|
||||
|
||||
f.udpMu.Lock()
|
||||
f.udplru.Add(t, t)
|
||||
f.udpMu.Unlock()
|
||||
}
|
||||
return Accept, "ok out"
|
||||
}
|
||||
|
||||
func pre(b []byte, q *packet.QDecode, rf RunFlags) Response {
|
||||
if len(b) == 0 {
|
||||
// wireguard keepalive packet, always permit.
|
||||
return Accept
|
||||
}
|
||||
if len(b) < 20 {
|
||||
logRateLimit(rf, b, nil, Drop, "too short")
|
||||
return Drop
|
||||
}
|
||||
q.Decode(b)
|
||||
|
||||
if q.IPProto == packet.Junk {
|
||||
// Junk packets are dangerous; always drop them.
|
||||
logRateLimit(rf, b, q, Drop, "junk!")
|
||||
return Drop
|
||||
} else if q.IPProto == packet.Fragment {
|
||||
// Fragments after the first always need to be passed through.
|
||||
// Very small fragments are considered Junk by QDecode.
|
||||
logRateLimit(rf, b, q, Accept, "fragment")
|
||||
return Accept
|
||||
}
|
||||
|
||||
return noVerdict
|
||||
}
|
162
wgengine/filter/filter_test.go
Normal file
162
wgengine/filter/filter_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
type QDecode = packet.QDecode
|
||||
|
||||
var Junk = packet.Junk
|
||||
var ICMP = packet.ICMP
|
||||
var TCP = packet.TCP
|
||||
var UDP = packet.UDP
|
||||
var Fragment = packet.Fragment
|
||||
|
||||
func ippr(ip IP, start, end uint16) []IPPortRange {
|
||||
return []IPPortRange{
|
||||
IPPortRange{ip, PortRange{start, end}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
mm := Matches{
|
||||
{SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: []IPPortRange{
|
||||
IPPortRange{0x01020304, PortRange{22, 22}},
|
||||
IPPortRange{0x05060708, PortRange{23, 24}},
|
||||
}},
|
||||
{SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: ippr(0x05060708, 27, 28)},
|
||||
{SrcIPs: []IP{0x02020202}, DstPorts: ippr(0x08010101, 22, 22)},
|
||||
{SrcIPs: []IP{0}, DstPorts: ippr(0x647a6232, 0, 65535)},
|
||||
{SrcIPs: []IP{0}, DstPorts: ippr(0, 443, 443)},
|
||||
{SrcIPs: []IP{0x99010101, 0x99010102, 0x99030303}, DstPorts: ippr(0x01020304, 999, 999)},
|
||||
}
|
||||
acl := New(mm)
|
||||
|
||||
for _, ent := range []Matches{Matches{mm[0]}, mm} {
|
||||
b, err := json.Marshal(ent)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
|
||||
mm2 := Matches{}
|
||||
if err := json.Unmarshal(b, &mm2); err != nil {
|
||||
t.Fatalf("unmarshal: %v (%v)", err, string(b))
|
||||
}
|
||||
}
|
||||
|
||||
// check packet filtering based on the table
|
||||
|
||||
type InOut struct {
|
||||
want Response
|
||||
p QDecode
|
||||
}
|
||||
tests := []InOut{
|
||||
// Basic
|
||||
{Accept, qdecode(TCP, 0x08010101, 0x01020304, 999, 22)},
|
||||
{Accept, qdecode(UDP, 0x08010101, 0x01020304, 999, 22)},
|
||||
{Accept, qdecode(ICMP, 0x08010101, 0x01020304, 0, 0)},
|
||||
{Drop, qdecode(TCP, 0x08010101, 0x01020304, 0, 0)},
|
||||
{Accept, qdecode(TCP, 0x08010101, 0x01020304, 0, 22)},
|
||||
{Drop, qdecode(TCP, 0x08010101, 0x01020304, 0, 21)},
|
||||
{Accept, qdecode(TCP, 0x11223344, 0x22334455, 0, 443)},
|
||||
{Drop, qdecode(TCP, 0x11223344, 0x22334455, 0, 444)},
|
||||
{Accept, qdecode(TCP, 0x11223344, 0x647a6232, 0, 999)},
|
||||
{Accept, qdecode(TCP, 0x11223344, 0x647a6232, 0, 0)},
|
||||
|
||||
// Stateful UDP.
|
||||
// Initially empty cache
|
||||
{Drop, qdecode(UDP, 0x77777777, 0x66666666, 4242, 4343)},
|
||||
// Return packet from previous attempt is allowed
|
||||
{Accept, qdecode(UDP, 0x66666666, 0x77777777, 4343, 4242)},
|
||||
// Because of the return above, initial attempt is allowed now
|
||||
{Accept, qdecode(UDP, 0x77777777, 0x66666666, 4242, 4343)},
|
||||
}
|
||||
for i, test := range tests {
|
||||
if got, _ := acl.runIn(&test.p); test.want != got {
|
||||
t.Errorf("#%d got=%v want=%v packet:%v\n", i, got, test.want, test.p)
|
||||
}
|
||||
// Update UDP state
|
||||
_, _ = acl.runOut(&test.p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreFilter(t *testing.T) {
|
||||
packets := []struct {
|
||||
desc string
|
||||
want Response
|
||||
b []byte
|
||||
}{
|
||||
{"empty", Accept, []byte{}},
|
||||
{"short", Drop, []byte("short")},
|
||||
{"junk", Drop, rawpacket(Junk, 10)},
|
||||
{"fragment", Accept, rawpacket(Fragment, 40)},
|
||||
{"tcp", noVerdict, rawpacket(TCP, 200)},
|
||||
{"udp", noVerdict, rawpacket(UDP, 200)},
|
||||
{"icmp", noVerdict, rawpacket(ICMP, 200)},
|
||||
}
|
||||
for _, testPacket := range packets {
|
||||
got := pre([]byte(testPacket.b), &QDecode{}, LogDrops|LogAccepts)
|
||||
if got != testPacket.want {
|
||||
t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func qdecode(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) QDecode {
|
||||
return QDecode{
|
||||
IPProto: proto,
|
||||
SrcIP: src,
|
||||
DstIP: dst,
|
||||
SrcPort: sport,
|
||||
DstPort: dport,
|
||||
TCPFlags: packet.TCPSyn,
|
||||
}
|
||||
}
|
||||
|
||||
func rawpacket(proto packet.IPProto, len uint16) []byte {
|
||||
bl := len
|
||||
if len < 24 {
|
||||
bl = 24
|
||||
}
|
||||
bin := binary.BigEndian
|
||||
hdr := make([]byte, bl)
|
||||
hdr[0] = 0x45
|
||||
bin.PutUint16(hdr[2:4], len)
|
||||
hdr[8] = 64
|
||||
ip := net.IPv4(8, 8, 8, 8).To4()
|
||||
copy(hdr[12:16], ip)
|
||||
copy(hdr[16:20], ip)
|
||||
// ports
|
||||
bin.PutUint16(hdr[20:22], 53)
|
||||
bin.PutUint16(hdr[22:24], 53)
|
||||
|
||||
switch proto {
|
||||
case ICMP:
|
||||
hdr[9] = 1
|
||||
case TCP:
|
||||
hdr[9] = 6
|
||||
case UDP:
|
||||
hdr[9] = 17
|
||||
case Fragment:
|
||||
hdr[9] = 6
|
||||
// flags + fragOff
|
||||
bin.PutUint16(hdr[6:8], (1<<13)|1234)
|
||||
case Junk:
|
||||
default:
|
||||
panic("unknown protocol")
|
||||
}
|
||||
|
||||
// Truncate the header if requested
|
||||
hdr = hdr[:len]
|
||||
|
||||
return hdr
|
||||
}
|
121
wgengine/filter/match.go
Normal file
121
wgengine/filter/match.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
type IP = packet.IP
|
||||
|
||||
const IPAny = IP(0)
|
||||
|
||||
var NewIP = packet.NewIP
|
||||
|
||||
type PortRange struct {
|
||||
First, Last uint16
|
||||
}
|
||||
|
||||
var PortRangeAny = PortRange{0, 65535}
|
||||
|
||||
func (pr PortRange) String() string {
|
||||
if pr.First == 0 && pr.Last == 65535 {
|
||||
return "*"
|
||||
} else if pr.First == pr.Last {
|
||||
return fmt.Sprintf("%d", pr.First)
|
||||
} else {
|
||||
return fmt.Sprintf("%d-%d", pr.First, pr.Last)
|
||||
}
|
||||
}
|
||||
|
||||
type IPPortRange struct {
|
||||
IP IP
|
||||
Ports PortRange
|
||||
}
|
||||
|
||||
var IPPortRangeAny = IPPortRange{IPAny, PortRangeAny}
|
||||
|
||||
func (ipr IPPortRange) String() string {
|
||||
return fmt.Sprintf("%v:%v", ipr.IP, ipr.Ports)
|
||||
}
|
||||
|
||||
type Match struct {
|
||||
DstPorts []IPPortRange
|
||||
SrcIPs []IP
|
||||
}
|
||||
|
||||
func (m Match) String() string {
|
||||
srcs := []string{}
|
||||
for _, srcip := range m.SrcIPs {
|
||||
srcs = append(srcs, srcip.String())
|
||||
}
|
||||
dsts := []string{}
|
||||
for _, dst := range m.DstPorts {
|
||||
dsts = append(dsts, dst.String())
|
||||
}
|
||||
|
||||
var ss, ds string
|
||||
if len(srcs) == 1 {
|
||||
ss = srcs[0]
|
||||
} else {
|
||||
ss = "[" + strings.Join(srcs, ",") + "]"
|
||||
}
|
||||
if len(dsts) == 1 {
|
||||
ds = dsts[0]
|
||||
} else {
|
||||
ds = "[" + strings.Join(dsts, ",") + "]"
|
||||
}
|
||||
return fmt.Sprintf("%v=>%v", ss, ds)
|
||||
}
|
||||
|
||||
type Matches []Match
|
||||
|
||||
func ipInList(ip IP, iplist []IP) bool {
|
||||
for _, ipp := range iplist {
|
||||
if ipp == IPAny || ipp == ip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchIPPorts(mm Matches, q *packet.QDecode) bool {
|
||||
for _, acl := range mm {
|
||||
for _, dst := range acl.DstPorts {
|
||||
if dst.IP != IPAny && dst.IP != q.DstIP {
|
||||
continue
|
||||
}
|
||||
if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last {
|
||||
continue
|
||||
}
|
||||
if !ipInList(q.SrcIP, acl.SrcIPs) {
|
||||
// Skip other dests in this acl, since
|
||||
// the src will never match.
|
||||
break
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool {
|
||||
for _, acl := range mm {
|
||||
for _, dst := range acl.DstPorts {
|
||||
if dst.IP != IPAny && dst.IP != q.DstIP {
|
||||
continue
|
||||
}
|
||||
if !ipInList(q.SrcIP, acl.SrcIPs) {
|
||||
// Skip other dests in this acl, since
|
||||
// the src will never match.
|
||||
break
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
411
wgengine/ifconfig_windows.go
Normal file
411
wgengine/ifconfig_windows.go
Normal file
@@ -0,0 +1,411 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sort"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/go-ole/go-ole"
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/winipcfg"
|
||||
"tailscale.com/wgengine/winnet"
|
||||
)
|
||||
|
||||
const (
|
||||
sockoptIP_UNICAST_IF = 31
|
||||
sockoptIPV6_UNICAST_IF = 31
|
||||
)
|
||||
|
||||
func htonl(val uint32) uint32 {
|
||||
bytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(bytes, val)
|
||||
return *(*uint32)(unsafe.Pointer(&bytes[0]))
|
||||
}
|
||||
|
||||
func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLuid uint64, lastLuid *uint64) error {
|
||||
routes, err := winipcfg.GetRoutes(family)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lowestMetric := ^uint32(0)
|
||||
index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
|
||||
luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
|
||||
for _, route := range routes {
|
||||
if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLuid == ourLuid {
|
||||
continue
|
||||
}
|
||||
if route.Metric < lowestMetric {
|
||||
lowestMetric = route.Metric
|
||||
index = route.InterfaceIndex
|
||||
luid = route.InterfaceLuid
|
||||
}
|
||||
}
|
||||
if luid == *lastLuid {
|
||||
return nil
|
||||
}
|
||||
*lastLuid = luid
|
||||
if false {
|
||||
// TODO(apenwarr): doesn't work with magic socket yet.
|
||||
if family == winipcfg.AF_INET {
|
||||
return device.BindSocketToInterface4(index, false)
|
||||
} else if family == winipcfg.AF_INET6 {
|
||||
return device.BindSocketToInterface6(index, false)
|
||||
}
|
||||
} else {
|
||||
log.Printf("WARNING: skipping windows socket binding.\n")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func MonitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
|
||||
guid := tun.GUID()
|
||||
ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid)
|
||||
lastLuid4 := uint64(0)
|
||||
lastLuid6 := uint64(0)
|
||||
lastMtu := uint32(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
doIt := func() error {
|
||||
err = bindSocketRoute(winipcfg.AF_INET, device, ourLuid, &lastLuid4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = bindSocketRoute(winipcfg.AF_INET6, device, ourLuid, &lastLuid6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !autoMTU {
|
||||
return nil
|
||||
}
|
||||
mtu := uint32(0)
|
||||
if lastLuid4 != 0 {
|
||||
iface, err := winipcfg.InterfaceFromLUID(lastLuid4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if iface.Mtu > 0 {
|
||||
mtu = iface.Mtu
|
||||
}
|
||||
}
|
||||
if lastLuid6 != 0 {
|
||||
iface, err := winipcfg.InterfaceFromLUID(lastLuid6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if iface.Mtu > 0 && iface.Mtu < mtu {
|
||||
mtu = iface.Mtu
|
||||
}
|
||||
}
|
||||
if mtu > 0 && (lastMtu == 0 || lastMtu != mtu) {
|
||||
iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
iface.NlMtu = mtu - 80
|
||||
if iface.NlMtu < 576 {
|
||||
iface.NlMtu = 576
|
||||
}
|
||||
err = iface.Set()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tun.ForceMTU(int(iface.NlMtu)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now.
|
||||
iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
iface.NlMtu = mtu - 80
|
||||
if iface.NlMtu < 1280 {
|
||||
iface.NlMtu = 1280
|
||||
}
|
||||
err = iface.Set()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lastMtu = mtu
|
||||
}
|
||||
return nil
|
||||
}
|
||||
err = doIt()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) {
|
||||
//fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix)
|
||||
if route.DestinationPrefix.PrefixLength == 0 {
|
||||
_ = doIt()
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cb, nil
|
||||
}
|
||||
|
||||
func setDNSDomains(g windows.GUID, dnsDomains []string) {
|
||||
gs := g.String()
|
||||
log.Printf("setDNSDomains(%v) guid=%v\n", dnsDomains, gs)
|
||||
p := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + gs
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, p, registry.READ|registry.SET_VALUE)
|
||||
if err != nil {
|
||||
log.Printf("setDNSDomains(%v): open: %v\n", p, err)
|
||||
return
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
// Windows only supports a single per-interface DNS domain.
|
||||
dom := ""
|
||||
if len(dnsDomains) > 0 {
|
||||
dom = dnsDomains[0]
|
||||
}
|
||||
err = key.SetStringValue("Domain", dom)
|
||||
if err != nil {
|
||||
log.Printf("setDNSDomains(%v): SetStringValue: %v\n", p, err)
|
||||
}
|
||||
}
|
||||
|
||||
func setFirewall(ifcGUID *windows.GUID) (bool, error) {
|
||||
c := ole.Connection{}
|
||||
err := c.Initialize()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Uninitialize()
|
||||
|
||||
m, err := winnet.NewNetworkListManager(&c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer m.Release()
|
||||
|
||||
cl, err := m.GetNetworkConnections()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer cl.Release()
|
||||
|
||||
for _, nco := range cl {
|
||||
aid, err := nco.GetAdapterId()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if aid != ifcGUID.String() {
|
||||
log.Printf("skipping adapter id: %v\n", aid)
|
||||
continue
|
||||
}
|
||||
log.Printf("found! adapter id: %v\n", aid)
|
||||
|
||||
n, err := nco.GetNetwork()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("GetNetwork: %v", err)
|
||||
}
|
||||
defer n.Release()
|
||||
|
||||
cat, err := n.GetCategory()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("GetCategory: %v", err)
|
||||
}
|
||||
|
||||
if cat == 0 {
|
||||
err = n.SetCategory(1)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("SetCategory: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Printf("setFirewall: already category %v\n", cat)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func ConfigureInterface(m *wgcfg.Config, tun *tun.NativeTun, dns []net.IP, dnsDomains []string) error {
|
||||
const mtu = 0
|
||||
guid := tun.GUID()
|
||||
log.Printf("wintun GUID is %v\n", guid)
|
||||
iface, err := winipcfg.InterfaceFromGUID(&guid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
// It takes a weirdly long time for Windows to notice the
|
||||
// new interface has come up. Poll periodically until it
|
||||
// does.
|
||||
for i := 0; i < 20; i++ {
|
||||
found, err := setFirewall(&guid)
|
||||
if err != nil {
|
||||
log.Printf("setFirewall: %v\n", err)
|
||||
// fall through anyway, this isn't fatal.
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}()
|
||||
|
||||
setDNSDomains(guid, dnsDomains)
|
||||
|
||||
routes := []winipcfg.RouteData{}
|
||||
var firstGateway4 *net.IP
|
||||
var firstGateway6 *net.IP
|
||||
addresses := make([]*net.IPNet, len(m.Interface.Addresses))
|
||||
for i, addr := range m.Interface.Addresses {
|
||||
ipnet := addr.IPNet()
|
||||
addresses[i] = ipnet
|
||||
gateway := ipnet.IP
|
||||
if addr.IP.Is4() && firstGateway4 == nil {
|
||||
firstGateway4 = &gateway
|
||||
} else if addr.IP.Is6() && firstGateway6 == nil {
|
||||
firstGateway6 = &gateway
|
||||
}
|
||||
}
|
||||
|
||||
foundDefault4 := false
|
||||
foundDefault6 := false
|
||||
for _, peer := range m.Peers {
|
||||
for _, allowedip := range peer.AllowedIPs {
|
||||
if (allowedip.IP.Is4() && firstGateway4 == nil) || (allowedip.IP.Is6() && firstGateway6 == nil) {
|
||||
return errors.New("Due to a Windows limitation, one cannot have interface routes without an interface address")
|
||||
}
|
||||
|
||||
ipn := allowedip.IPNet()
|
||||
var gateway net.IP
|
||||
if allowedip.IP.Is4() {
|
||||
gateway = *firstGateway4
|
||||
} else if allowedip.IP.Is6() {
|
||||
gateway = *firstGateway6
|
||||
}
|
||||
r := winipcfg.RouteData{
|
||||
Destination: net.IPNet{
|
||||
IP: ipn.IP.Mask(ipn.Mask),
|
||||
Mask: ipn.Mask,
|
||||
},
|
||||
NextHop: gateway,
|
||||
Metric: 0,
|
||||
}
|
||||
if bytes.Compare(r.Destination.IP, gateway) == 0 {
|
||||
// no need to add a route for the interface's
|
||||
// own IP. The kernel does that for us.
|
||||
// If we try to replace it, we'll fail to
|
||||
// add the route unless NextHop is set, but
|
||||
// then the interface's IP won't be pingable.
|
||||
continue
|
||||
}
|
||||
if allowedip.IP.Is4() {
|
||||
if allowedip.Mask == 0 {
|
||||
foundDefault4 = true
|
||||
}
|
||||
r.NextHop = *firstGateway4
|
||||
} else if allowedip.IP.Is6() {
|
||||
if allowedip.Mask == 0 {
|
||||
foundDefault6 = true
|
||||
}
|
||||
r.NextHop = *firstGateway6
|
||||
}
|
||||
routes = append(routes, r)
|
||||
}
|
||||
}
|
||||
|
||||
err = iface.SetAddresses(addresses)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
return (bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 ||
|
||||
// Narrower masks first
|
||||
bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == 1 ||
|
||||
// No nexthop before non-empty nexthop
|
||||
bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 ||
|
||||
// Lower metrics first
|
||||
routes[i].Metric < routes[j].Metric)
|
||||
})
|
||||
|
||||
deduplicatedRoutes := []*winipcfg.RouteData{}
|
||||
for i := 0; i < len(routes); i++ {
|
||||
// There's only one way to get to a given IP+Mask, so delete
|
||||
// all matches after the first.
|
||||
if i > 0 &&
|
||||
bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
|
||||
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
|
||||
continue
|
||||
}
|
||||
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
|
||||
}
|
||||
log.Printf("routes: %v\n", routes)
|
||||
|
||||
var errAcc error
|
||||
err = iface.SetRoutes(deduplicatedRoutes)
|
||||
if err != nil && errAcc == nil {
|
||||
log.Printf("setroutes: %v\n", err)
|
||||
errAcc = err
|
||||
}
|
||||
|
||||
err = iface.SetDNS(dns)
|
||||
if err != nil && errAcc == nil {
|
||||
log.Printf("setdns: %v\n", err)
|
||||
errAcc = err
|
||||
}
|
||||
|
||||
ipif, err := iface.GetIpInterface(winipcfg.AF_INET)
|
||||
if err != nil {
|
||||
log.Printf("getipif: %v\n", err)
|
||||
return err
|
||||
}
|
||||
log.Printf("foundDefault4: %v\n", foundDefault4)
|
||||
if foundDefault4 {
|
||||
ipif.UseAutomaticMetric = false
|
||||
ipif.Metric = 0
|
||||
}
|
||||
if mtu > 0 {
|
||||
ipif.NlMtu = uint32(mtu)
|
||||
tun.ForceMTU(int(ipif.NlMtu))
|
||||
}
|
||||
err = ipif.Set()
|
||||
if err != nil && errAcc == nil {
|
||||
errAcc = err
|
||||
}
|
||||
|
||||
ipif, err = iface.GetIpInterface(winipcfg.AF_INET6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err != nil && errAcc == nil {
|
||||
errAcc = err
|
||||
}
|
||||
if foundDefault6 {
|
||||
ipif.UseAutomaticMetric = false
|
||||
ipif.Metric = 0
|
||||
}
|
||||
if mtu > 0 {
|
||||
ipif.NlMtu = uint32(mtu)
|
||||
}
|
||||
ipif.DadTransmits = 0
|
||||
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
|
||||
err = ipif.Set()
|
||||
if err != nil && errAcc == nil {
|
||||
errAcc = err
|
||||
}
|
||||
|
||||
return errAcc
|
||||
}
|
815
wgengine/magicsock/magicsock.go
Normal file
815
wgengine/magicsock/magicsock.go
Normal file
@@ -0,0 +1,815 @@
|
||||
// Copyright 2019 Tailscale & AUTHORS. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package magicsock implements a socket that can change its communication path while
|
||||
// in use, actively searching for the best way to communicate.
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/stun"
|
||||
"tailscale.com/stunner"
|
||||
)
|
||||
|
||||
// A Conn routes UDP packets and actively manages a list of its endpoints.
|
||||
// It implements wireguard/device.Bind.
|
||||
type Conn struct {
|
||||
pconn *RebindingUDPConn
|
||||
pconnPort uint16
|
||||
stunServers []string
|
||||
derpServer string
|
||||
startEpUpdate chan struct{} // send to trigger endpoint update
|
||||
epUpdateCancel func()
|
||||
epFunc func(endpoints []string)
|
||||
logf func(format string, args ...interface{})
|
||||
|
||||
// indexedAddrs is a map of every remote ip:port to a priority
|
||||
// list of endpoint addresses for a peer.
|
||||
// The priority list is provided by wgengine configuration.
|
||||
//
|
||||
// Given a wgcfg describing:
|
||||
// machineA: 10.0.0.1:1, 10.0.0.2:2
|
||||
// machineB: 10.0.0.3:3
|
||||
// the indexedAddrs map contains:
|
||||
// 10.0.0.1:1 -> [10.0.0.1:1, 10.0.0.2:2], index:0
|
||||
// 10.0.0.2:2 -> [10.0.0.1:1, 10.0.0.2:2], index:1
|
||||
// 10.0.0.3:3 -> [10.0.0.3:3], index:0
|
||||
indexedAddrsMu sync.Mutex
|
||||
indexedAddrs map[udpAddr]indexedAddrSet
|
||||
|
||||
stunReceiveMu sync.Mutex
|
||||
stunReceive func(p []byte, fromAddr *net.UDPAddr)
|
||||
|
||||
derpMu sync.Mutex
|
||||
derp *derphttp.Client
|
||||
}
|
||||
|
||||
// udpAddr is the key in the indexedAddrs map.
|
||||
// It maps an ip:port onto an indexedAddr.
|
||||
type udpAddr struct {
|
||||
ip wgcfg.IP
|
||||
port uint16
|
||||
}
|
||||
|
||||
// indexedAddrSet is an AddrSet (a priority list of ip:ports for a peer and the
|
||||
// current favored ip:port for communicating with the peer) and an index
|
||||
// number saying which element of the priority list is this map entry.
|
||||
type indexedAddrSet struct {
|
||||
addr *AddrSet
|
||||
index int // index of map key in addr.Addrs
|
||||
}
|
||||
|
||||
const DefaultPort = 0
|
||||
|
||||
const DefaultDERP = "https://derp.tailscale.com/derp"
|
||||
|
||||
var DefaultSTUN = []string{
|
||||
"stun.l.google.com:19302",
|
||||
"stun3.l.google.com:19302",
|
||||
}
|
||||
|
||||
// Options contains options for Listen.
|
||||
type Options struct {
|
||||
// Port is the port to listen on.
|
||||
// Zero means to pick one automatically.
|
||||
Port uint16
|
||||
|
||||
STUN []string
|
||||
DERP string
|
||||
|
||||
// EndpointsFunc optionally provides a func to be called when
|
||||
// endpoints change. The called func does not own the slice.
|
||||
EndpointsFunc func(endpoint []string)
|
||||
}
|
||||
|
||||
func (o *Options) endpointsFunc() func([]string) {
|
||||
if o == nil || o.EndpointsFunc == nil {
|
||||
return func([]string) {}
|
||||
}
|
||||
return o.EndpointsFunc
|
||||
}
|
||||
|
||||
// Listen creates a magic Conn listening on opts.Port.
|
||||
// As the set of possible endpoints for a Conn changes, the
|
||||
// callback opts.EndpointsFunc is called.
|
||||
func Listen(opts Options) (*Conn, error) {
|
||||
var packetConn net.PacketConn
|
||||
var err error
|
||||
if opts.Port == 0 {
|
||||
// Our choice of port. Start with DefaultPort.
|
||||
// If unavailable, pick any port.
|
||||
want := fmt.Sprintf(":%d", DefaultPort)
|
||||
log.Printf("magicsock: bind: trying %v\n", want)
|
||||
packetConn, err = net.ListenPacket("udp4", want)
|
||||
if err != nil {
|
||||
want = ":0"
|
||||
log.Printf("magicsock: bind: falling back to %v (%v)\n", want, err)
|
||||
packetConn, err = net.ListenPacket("udp4", want)
|
||||
}
|
||||
} else {
|
||||
packetConn, err = net.ListenPacket("udp4", fmt.Sprintf(":%d", opts.Port))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("magicsock.Listen: %v", err)
|
||||
}
|
||||
|
||||
epUpdateCtx, epUpdateCancel := context.WithCancel(context.Background())
|
||||
c := &Conn{
|
||||
pconn: new(RebindingUDPConn),
|
||||
stunServers: append([]string{}, opts.STUN...),
|
||||
derpServer: opts.DERP,
|
||||
startEpUpdate: make(chan struct{}, 1),
|
||||
epUpdateCancel: epUpdateCancel,
|
||||
epFunc: opts.endpointsFunc(),
|
||||
logf: log.Printf,
|
||||
indexedAddrs: make(map[udpAddr]indexedAddrSet),
|
||||
}
|
||||
c.pconn.Reset(packetConn.(*net.UDPConn))
|
||||
c.startEpUpdate <- struct{}{} // STUN immediately on start
|
||||
go c.epUpdate(epUpdateCtx)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Conn) epUpdate(ctx context.Context) {
|
||||
var lastEndpoints []string
|
||||
var lastCancel func()
|
||||
var lastDone chan struct{}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if lastCancel != nil {
|
||||
lastCancel()
|
||||
}
|
||||
return
|
||||
case <-c.startEpUpdate:
|
||||
}
|
||||
|
||||
if lastCancel != nil {
|
||||
lastCancel()
|
||||
<-lastDone
|
||||
}
|
||||
var epCtx context.Context
|
||||
epCtx, lastCancel = context.WithCancel(ctx)
|
||||
lastDone = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(lastDone)
|
||||
endpoints, err := c.determineEndpoints(epCtx)
|
||||
if err != nil {
|
||||
c.logf("magicsock.Conn: endpoint update failed: %v", err)
|
||||
// TODO(crawshaw): are there any conditions under which
|
||||
// we should trigger a retry based on the error here?
|
||||
return
|
||||
}
|
||||
if stringsEqual(endpoints, lastEndpoints) {
|
||||
return
|
||||
}
|
||||
lastEndpoints = endpoints
|
||||
c.epFunc(endpoints)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) {
|
||||
var alreadyMu sync.Mutex
|
||||
already := make(map[string]struct{})
|
||||
var eps []string
|
||||
|
||||
addAddr := func(s, reason string) {
|
||||
log.Printf("magicsock: found local %s (%s)\n", s, reason)
|
||||
|
||||
alreadyMu.Lock()
|
||||
defer alreadyMu.Unlock()
|
||||
if _, ok := already[s]; !ok {
|
||||
already[s] = struct{}{}
|
||||
eps = append(eps, s)
|
||||
}
|
||||
}
|
||||
|
||||
s := &stunner.Stunner{
|
||||
Send: c.pconn.WriteTo,
|
||||
Endpoint: func(s string) { addAddr(s, "stun") },
|
||||
Servers: c.stunServers,
|
||||
Logf: c.logf,
|
||||
}
|
||||
|
||||
c.stunReceiveMu.Lock()
|
||||
c.stunReceive = s.Receive
|
||||
c.stunReceiveMu.Unlock()
|
||||
|
||||
if err := s.Run(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.stunReceiveMu.Lock()
|
||||
c.stunReceive = nil
|
||||
c.stunReceiveMu.Unlock()
|
||||
|
||||
if localAddr := c.pconn.LocalAddr(); localAddr.IP.IsUnspecified() {
|
||||
localPort := fmt.Sprintf("%d", localAddr.Port)
|
||||
loopbacks, err := localAddresses(localPort, func(s string) {
|
||||
addAddr(s, "localAddresses")
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(eps) == 0 {
|
||||
// Only include loopback addresses if we have no
|
||||
// interfaces at all to use as endpoints. This allows
|
||||
// for localhost testing when you're on a plane and
|
||||
// offline, for example.
|
||||
for _, s := range loopbacks {
|
||||
addAddr(s, "loopback")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Our local endpoint is bound to a particular address.
|
||||
// Do not offer addresses on other local interfaces.
|
||||
addAddr(localAddr.String(), "socket")
|
||||
}
|
||||
|
||||
// Note: the endpoints are intentionally returned in priority order,
|
||||
// from "farthest but most reliable" to "closest but least
|
||||
// reliable." Addresses returned from STUN should be globally
|
||||
// addressable, but might go farther on the network than necessary.
|
||||
// Local interface addresses might have lower latency, but not be
|
||||
// globally addressable.
|
||||
//
|
||||
// The STUN address(es) are always first so that legacy wireguard
|
||||
// can use eps[0] as its only known endpoint address (although that's
|
||||
// obviously non-ideal).
|
||||
return eps, nil
|
||||
}
|
||||
|
||||
func stringsEqual(x, y []string) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i := range x {
|
||||
if x[i] != y[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func localAddresses(localPort string, addAddr func(s string)) ([]string, error) {
|
||||
var loopback []string
|
||||
|
||||
// TODO(crawshaw): don't serve interface addresses that we are routing
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, i := range ifaces {
|
||||
if (i.Flags & net.FlagUp) == 0 {
|
||||
// Down interfaces don't count
|
||||
continue
|
||||
}
|
||||
ifcIsLoopback := (i.Flags & net.FlagLoopback) != 0
|
||||
|
||||
addrs, err := i.Addrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, a := range addrs {
|
||||
switch v := a.(type) {
|
||||
case *net.IPNet:
|
||||
// TODO(crawshaw): IPv6 support.
|
||||
// Easy to do here, but we need good endpoint ordering logic.
|
||||
ip := v.IP.To4()
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
// TODO(apenwarr): don't special case cgNAT.
|
||||
// In the general wireguard case, it might
|
||||
// very well be something we can route to
|
||||
// directly, because both nodes are
|
||||
// behind the same CGNAT router.
|
||||
if cgNAT.Contains(ip) {
|
||||
continue
|
||||
}
|
||||
if linkLocalIPv4.Contains(ip) {
|
||||
continue
|
||||
}
|
||||
ep := net.JoinHostPort(ip.String(), localPort)
|
||||
if ip.IsLoopback() || ifcIsLoopback {
|
||||
loopback = append(loopback, ep)
|
||||
continue
|
||||
}
|
||||
addAddr(ep)
|
||||
}
|
||||
}
|
||||
}
|
||||
return loopback, nil
|
||||
}
|
||||
|
||||
var cgNAT = func() *net.IPNet {
|
||||
_, ipNet, err := net.ParseCIDR("100.64.0.0/10")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ipNet
|
||||
}()
|
||||
|
||||
var linkLocalIPv4 = func() *net.IPNet {
|
||||
_, ipNet, err := net.ParseCIDR("169.254.0.0/16")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ipNet
|
||||
}()
|
||||
|
||||
func (c *Conn) LocalPort() uint16 {
|
||||
laddr := c.pconn.LocalAddr()
|
||||
return uint16(laddr.Port)
|
||||
}
|
||||
|
||||
func (c *Conn) Send(b []byte, ep device.Endpoint) error {
|
||||
a := ep.(*AddrSet)
|
||||
|
||||
msgType := binary.LittleEndian.Uint32(b[:4])
|
||||
switch msgType {
|
||||
case device.MessageInitiationType, device.MessageResponseType, device.MessageCookieReplyType:
|
||||
// Part of the wireguard handshake.
|
||||
// Send to every potential endpoint we have for a peer.
|
||||
a.mu.Lock()
|
||||
roamAddr := a.roamAddr
|
||||
a.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var success bool
|
||||
if roamAddr != nil {
|
||||
_, err = c.pconn.WriteTo(b, roamAddr)
|
||||
if err == nil {
|
||||
success = true
|
||||
}
|
||||
}
|
||||
for i := len(a.addrs) - 1; i >= 0; i-- {
|
||||
addr := &a.addrs[i]
|
||||
_, err = c.pconn.WriteTo(b, addr)
|
||||
if err == nil {
|
||||
success = true
|
||||
}
|
||||
}
|
||||
|
||||
if msgType == device.MessageInitiationType {
|
||||
// Send initial handshake messages via DERP.
|
||||
c.derpMu.Lock()
|
||||
derp := c.derp
|
||||
c.derpMu.Unlock()
|
||||
|
||||
if derp != nil {
|
||||
if err := derp.Send(a.publicKey, b); err != nil {
|
||||
log.Printf("derp send failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if success {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write to the highest-priority address we have seen so far.
|
||||
_, err := c.pconn.WriteTo(b, a.dst())
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) findIndexedAddrSet(addr *net.UDPAddr) (addrSet *AddrSet, index int) {
|
||||
var epAddr udpAddr
|
||||
copy(epAddr.ip.Addr[:], addr.IP.To16())
|
||||
epAddr.port = uint16(addr.Port)
|
||||
|
||||
c.indexedAddrsMu.Lock()
|
||||
defer c.indexedAddrsMu.Unlock()
|
||||
|
||||
indAddr := c.indexedAddrs[epAddr]
|
||||
if indAddr.addr == nil {
|
||||
return nil, 0
|
||||
}
|
||||
return indAddr.addr, indAddr.index
|
||||
}
|
||||
|
||||
func (c *Conn) ReceiveIPv4(b []byte) (n int, ep device.Endpoint, addr *net.UDPAddr, err error) {
|
||||
// Read a packet, and process any STUN packets before returning.
|
||||
for {
|
||||
var pAddr net.Addr
|
||||
n, pAddr, err = c.pconn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return n, nil, nil, err
|
||||
}
|
||||
addr = pAddr.(*net.UDPAddr)
|
||||
addr.IP = addr.IP.To4()
|
||||
|
||||
if !stun.Is(b[:n]) {
|
||||
break
|
||||
}
|
||||
c.stunReceiveMu.Lock()
|
||||
fn := c.stunReceive
|
||||
c.stunReceiveMu.Unlock()
|
||||
|
||||
if fn != nil {
|
||||
fn(b, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(crawshaw): remove all the indexed-addr logic
|
||||
addrSet, _ := c.findIndexedAddrSet(addr)
|
||||
if addrSet == nil {
|
||||
// The peer that sent this packet has roamed beyond the
|
||||
// knowledge provided by the control server.
|
||||
// If the packet is valid wireguard will call UpdateDst
|
||||
// on the original endpoint using this addr.
|
||||
return n, (*singleEndpoint)(addr), addr, nil
|
||||
}
|
||||
return n, addrSet, addr, nil
|
||||
}
|
||||
|
||||
func (c *Conn) ReceiveIPv6(buff []byte) (int, device.Endpoint, *net.UDPAddr, error) {
|
||||
// TODO(crawshaw): IPv6 support
|
||||
return 0, nil, nil, syscall.EAFNOSUPPORT
|
||||
}
|
||||
|
||||
func (c *Conn) SetPrivateKey(privateKey [32]byte) error {
|
||||
if c.derpServer == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
derp, err := derphttp.NewClient(privateKey, c.derpServer, log.Printf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
var b [1 << 16]byte
|
||||
for {
|
||||
n, err := derp.Recv(b[:])
|
||||
if err != nil {
|
||||
if err == derphttp.ErrClientClosed {
|
||||
return
|
||||
}
|
||||
log.Printf("%v", err)
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Trigger re-STUN.
|
||||
c.startEpUpdate <- struct{}{}
|
||||
|
||||
addr := c.pconn.LocalAddr()
|
||||
if _, err := c.pconn.WriteToUDP(b[:n], addr); err != nil {
|
||||
log.Printf("%v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
c.derpMu.Lock()
|
||||
if c.derp != nil {
|
||||
if err := c.derp.Close(); err != nil {
|
||||
log.Printf("derp.Close: %v", err)
|
||||
}
|
||||
}
|
||||
c.derp = derp
|
||||
c.derpMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) SetMark(value uint32) error { return nil }
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
c.epUpdateCancel()
|
||||
return c.pconn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) LinkChange() {
|
||||
defer func() {
|
||||
c.startEpUpdate <- struct{}{} // re-STUN
|
||||
}()
|
||||
|
||||
if c.pconnPort != 0 {
|
||||
c.pconn.mu.Lock()
|
||||
if err := c.pconn.pconn.Close(); err != nil {
|
||||
log.Printf("magicsock: link change close failed: %v", err)
|
||||
}
|
||||
packetConn, err := net.ListenPacket("udp4", fmt.Sprintf(":%d", c.pconnPort))
|
||||
if err == nil {
|
||||
log.Printf("magicsock: link change rebound port: %d", c.pconnPort)
|
||||
c.pconn.pconn = packetConn.(*net.UDPConn)
|
||||
c.pconn.mu.Unlock()
|
||||
return
|
||||
}
|
||||
log.Printf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.pconnPort, err)
|
||||
c.pconn.mu.Unlock()
|
||||
}
|
||||
|
||||
log.Printf("magicsock: link change, binding new port")
|
||||
packetConn, err := net.ListenPacket("udp4", ":0")
|
||||
if err != nil {
|
||||
log.Printf("magicsock: link change failed to bind new port: %v", err)
|
||||
return
|
||||
}
|
||||
c.pconn.Reset(packetConn.(*net.UDPConn))
|
||||
}
|
||||
|
||||
// AddrSet is a set of UDP addresses that implements wireguard/device.Endpoint.
|
||||
type AddrSet struct {
|
||||
publicKey [32]byte // peer public key used for DERP communication
|
||||
addrs []net.UDPAddr // ordered priority list provided by wgengine
|
||||
|
||||
mu sync.Mutex // guards roamAddr and curAddr
|
||||
roamAddr *net.UDPAddr // peer addr determined from incoming packets
|
||||
// curAddr is an index into addrs of the highest-priority
|
||||
// address a valid packet has been received from so far.
|
||||
// If no valid packet from addrs has been received, curAddr is -1.
|
||||
curAddr int
|
||||
}
|
||||
|
||||
var noAddr = &net.UDPAddr{
|
||||
IP: net.ParseIP("127.127.127.127"),
|
||||
Port: 127,
|
||||
}
|
||||
|
||||
func (a *AddrSet) dst() *net.UDPAddr {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if a.roamAddr != nil {
|
||||
return a.roamAddr
|
||||
}
|
||||
if len(a.addrs) == 0 {
|
||||
return noAddr
|
||||
}
|
||||
i := a.curAddr
|
||||
if i == -1 {
|
||||
i = 0
|
||||
}
|
||||
return &a.addrs[i]
|
||||
}
|
||||
|
||||
func (a *AddrSet) DstToBytes() []byte {
|
||||
dst := a.dst()
|
||||
b := append([]byte(nil), dst.IP.To4()...)
|
||||
if len(b) == 0 {
|
||||
b = append([]byte(nil), dst.IP...)
|
||||
}
|
||||
b = append(b, byte(dst.Port&0xff))
|
||||
b = append(b, byte((dst.Port>>8)&0xff))
|
||||
return b
|
||||
}
|
||||
func (a *AddrSet) DstToString() string {
|
||||
dst := a.dst()
|
||||
return dst.String()
|
||||
}
|
||||
func (a *AddrSet) DstIP() net.IP {
|
||||
return a.dst().IP
|
||||
}
|
||||
func (a *AddrSet) SrcIP() net.IP { return nil }
|
||||
func (a *AddrSet) SrcToString() string { return "" }
|
||||
func (a *AddrSet) ClearSrc() {}
|
||||
|
||||
func (a *AddrSet) UpdateDst(new *net.UDPAddr) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if a.roamAddr != nil {
|
||||
if equalUDPAddr(a.roamAddr, new) {
|
||||
// Packet from the current roaming address, no logging.
|
||||
// This is a hot path for established connections.
|
||||
return nil
|
||||
}
|
||||
} else if a.curAddr >= 0 && equalUDPAddr(new, &a.addrs[a.curAddr]) {
|
||||
// Packet from current-priority address, no logging.
|
||||
// This is a hot path for established connections.
|
||||
return nil
|
||||
}
|
||||
|
||||
index := -1
|
||||
for i := range a.addrs {
|
||||
if equalUDPAddr(new, &a.addrs[i]) {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
publicKey := wgcfg.Key(a.publicKey)
|
||||
pk := publicKey.ShortString()
|
||||
old := "<none>"
|
||||
if a.curAddr >= 0 {
|
||||
old = a.addrs[a.curAddr].String()
|
||||
}
|
||||
|
||||
switch {
|
||||
case index == -1:
|
||||
if a.roamAddr == nil {
|
||||
log.Printf("magicsock: rx %s from roaming address %s, set as new priority", pk, new)
|
||||
} else {
|
||||
log.Printf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr)
|
||||
}
|
||||
a.roamAddr = new
|
||||
|
||||
case a.roamAddr != nil:
|
||||
log.Printf("magicsock: rx %s from known %s (%d), replacs roaming address %s", pk, new, index, a.roamAddr)
|
||||
a.roamAddr = nil
|
||||
a.curAddr = index
|
||||
|
||||
case a.curAddr == -1:
|
||||
log.Printf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.addrs))
|
||||
a.curAddr = index
|
||||
|
||||
case index < a.curAddr:
|
||||
log.Printf("magicsock: rx %s from low-pri %s (%d), keeping current %s (%d)", pk, new, index, old, a.curAddr)
|
||||
|
||||
default: // index > a.curAddr
|
||||
log.Printf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.addrs), old)
|
||||
a.curAddr = index
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func equalUDPAddr(x, y *net.UDPAddr) bool {
|
||||
return x.Port == y.Port && x.IP.Equal(y.IP)
|
||||
}
|
||||
|
||||
func (a *AddrSet) String() string {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
buf := new(strings.Builder)
|
||||
buf.WriteByte('[')
|
||||
if a.roamAddr != nil {
|
||||
fmt.Fprintf(buf, "roam:%s:%d", a.roamAddr.IP, a.roamAddr.Port)
|
||||
}
|
||||
for i, addr := range a.addrs {
|
||||
if i > 0 || a.roamAddr != nil {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
fmt.Fprintf(buf, "%s:%d", addr.IP, addr.Port)
|
||||
if a.curAddr == i {
|
||||
buf.WriteByte('*')
|
||||
}
|
||||
}
|
||||
buf.WriteByte(']')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (c *Conn) CreateEndpoint(key [32]byte, s string) (device.Endpoint, error) {
|
||||
pk := wgcfg.Key(key)
|
||||
log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), s)
|
||||
a := &AddrSet{
|
||||
publicKey: key,
|
||||
curAddr: -1,
|
||||
}
|
||||
|
||||
if s != "" {
|
||||
for _, ep := range strings.Split(s, ",") {
|
||||
addr, err := net.ResolveUDPAddr("udp", ep)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ip4 := addr.IP.To4(); ip4 != nil {
|
||||
addr.IP = ip4
|
||||
}
|
||||
a.addrs = append(a.addrs, *addr)
|
||||
}
|
||||
}
|
||||
|
||||
c.indexedAddrsMu.Lock()
|
||||
for i, addr := range a.addrs {
|
||||
var epAddr udpAddr
|
||||
copy(epAddr.ip.Addr[:], addr.IP.To16())
|
||||
epAddr.port = uint16(addr.Port)
|
||||
c.indexedAddrs[epAddr] = indexedAddrSet{
|
||||
addr: a,
|
||||
index: i,
|
||||
}
|
||||
}
|
||||
c.indexedAddrsMu.Unlock()
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
type singleEndpoint net.UDPAddr
|
||||
|
||||
func (e *singleEndpoint) ClearSrc() {}
|
||||
func (e *singleEndpoint) DstIP() net.IP { return (*net.UDPAddr)(e).IP }
|
||||
func (e *singleEndpoint) SrcIP() net.IP { return nil }
|
||||
func (e *singleEndpoint) SrcToString() string { return "" }
|
||||
func (e *singleEndpoint) DstToString() string { return (*net.UDPAddr)(e).String() }
|
||||
func (e *singleEndpoint) DstToBytes() []byte {
|
||||
addr := (*net.UDPAddr)(e)
|
||||
out := addr.IP.To4()
|
||||
if out == nil {
|
||||
out = addr.IP
|
||||
}
|
||||
out = append(out, byte(addr.Port&0xff))
|
||||
out = append(out, byte((addr.Port>>8)&0xff))
|
||||
return out
|
||||
}
|
||||
func (e *singleEndpoint) UpdateDst(dst *net.UDPAddr) error {
|
||||
return fmt.Errorf("magicsock.singleEndpoint(%s).UpdateDst(%s): should never be called", (*net.UDPAddr)(e), dst)
|
||||
}
|
||||
|
||||
// RebindingUDPConn is a UDP socket that can be re-bound.
|
||||
// Unix has no notion of re-binding a socket, so we swap it out for a new one.
|
||||
type RebindingUDPConn struct {
|
||||
mu sync.Mutex
|
||||
pconn *net.UDPConn
|
||||
}
|
||||
|
||||
func (c *RebindingUDPConn) Reset(pconn *net.UDPConn) {
|
||||
c.mu.Lock()
|
||||
old := c.pconn
|
||||
c.pconn = pconn
|
||||
c.mu.Unlock()
|
||||
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
for {
|
||||
c.mu.Lock()
|
||||
pconn := c.pconn
|
||||
c.mu.Unlock()
|
||||
|
||||
n, addr, err := pconn.ReadFrom(b)
|
||||
if err != nil {
|
||||
c.mu.Lock()
|
||||
pconn2 := c.pconn
|
||||
c.mu.Unlock()
|
||||
|
||||
if pconn != pconn2 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return n, addr, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.pconn.LocalAddr().(*net.UDPAddr)
|
||||
}
|
||||
|
||||
func (c *RebindingUDPConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.pconn.Close()
|
||||
}
|
||||
|
||||
func (c *RebindingUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||
for {
|
||||
c.mu.Lock()
|
||||
pconn := c.pconn
|
||||
c.mu.Unlock()
|
||||
|
||||
n, err := pconn.WriteToUDP(b, addr)
|
||||
if err != nil {
|
||||
c.mu.Lock()
|
||||
pconn2 := c.pconn
|
||||
c.mu.Unlock()
|
||||
|
||||
if pconn != pconn2 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
for {
|
||||
c.mu.Lock()
|
||||
pconn := c.pconn
|
||||
c.mu.Unlock()
|
||||
|
||||
n, err := pconn.WriteTo(b, addr)
|
||||
if err != nil {
|
||||
c.mu.Lock()
|
||||
pconn2 := c.pconn
|
||||
c.mu.Unlock()
|
||||
|
||||
if pconn != pconn2 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
73
wgengine/magicsock/magicsock_test.go
Normal file
73
wgengine/magicsock/magicsock_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
epCh := make(chan string, 16)
|
||||
epFunc := func(endpoints []string) {
|
||||
for _, ep := range endpoints {
|
||||
epCh <- ep
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(crawshaw): break test dependency on the network
|
||||
// using "gortc.io/stun" (like stunner_test.go).
|
||||
stunServers := DefaultSTUN
|
||||
|
||||
port := pickPort(t)
|
||||
conn, err := Listen(Options{
|
||||
Port: port,
|
||||
STUN: stunServers,
|
||||
EndpointsFunc: epFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
go func() {
|
||||
var pkt [1 << 16]byte
|
||||
for {
|
||||
_, _, _, err := conn.ReceiveIPv4(pkt[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
timeout := time.After(10 * time.Second)
|
||||
var endpoints []string
|
||||
suffix := fmt.Sprintf(":%d", port)
|
||||
collectEndpoints:
|
||||
for {
|
||||
select {
|
||||
case ep := <-epCh:
|
||||
endpoints = append(endpoints, ep)
|
||||
if strings.HasSuffix(ep, suffix) {
|
||||
break collectEndpoints
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout with endpoints: %v", endpoints)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func pickPort(t *testing.T) uint16 {
|
||||
t.Helper()
|
||||
conn, err := net.ListenPacket("udp4", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
return uint16(conn.LocalAddr().(*net.UDPAddr).Port)
|
||||
}
|
363
wgengine/packet/packet.go
Normal file
363
wgengine/packet/packet.go
Normal file
@@ -0,0 +1,363 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package packet
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type IPProto int
|
||||
|
||||
const (
|
||||
Junk IPProto = iota
|
||||
Fragment
|
||||
ICMP
|
||||
UDP
|
||||
TCP
|
||||
)
|
||||
|
||||
// RFC1858: prevent overlapping fragment attacks.
|
||||
const MIN_FRAG = 60 + 20 // max IPv4 header + basic TCP header
|
||||
|
||||
func (p IPProto) String() string {
|
||||
switch p {
|
||||
case Fragment:
|
||||
return "Frag"
|
||||
case ICMP:
|
||||
return "ICMP"
|
||||
case UDP:
|
||||
return "UDP"
|
||||
case TCP:
|
||||
return "TCP"
|
||||
default:
|
||||
return "Junk"
|
||||
}
|
||||
}
|
||||
|
||||
type IP uint32
|
||||
|
||||
const IPAny = IP(0)
|
||||
|
||||
func NewIP(b net.IP) IP {
|
||||
b4 := b.To4()
|
||||
if b4 == nil {
|
||||
panic(fmt.Sprintf("To4(%v) failed", b))
|
||||
}
|
||||
return IP(binary.BigEndian.Uint32(b4))
|
||||
}
|
||||
|
||||
func (ip IP) String() string {
|
||||
if ip == 0 {
|
||||
return "*"
|
||||
}
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, uint32(ip))
|
||||
return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3])
|
||||
}
|
||||
|
||||
func (ipp *IP) MarshalJSON() ([]byte, error) {
|
||||
s := "\"" + (*ipp).String() + "\""
|
||||
return []byte(s), nil
|
||||
}
|
||||
|
||||
func (ipp *IP) UnmarshalJSON(b []byte) error {
|
||||
var hostp *string
|
||||
err := json.Unmarshal(b, &hostp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
host := *hostp
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.IsUnspecified() {
|
||||
// For clarity, reject 0.0.0.0 as an input
|
||||
return fmt.Errorf("Ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host)
|
||||
} else if ip == nil && host == "*" {
|
||||
// User explicitly requested wildcard dst ip
|
||||
*ipp = IPAny
|
||||
} else {
|
||||
if ip != nil {
|
||||
ip = ip.To4()
|
||||
}
|
||||
if ip == nil || len(ip) != 4 {
|
||||
return fmt.Errorf("Ports=%#v: invalid IPv4 address", host)
|
||||
}
|
||||
*ipp = NewIP(ip)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
EchoReply uint8 = 0x00
|
||||
EchoRequest uint8 = 0x08
|
||||
)
|
||||
|
||||
const (
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPAck uint8 = 0x10
|
||||
TCPSynAck uint8 = TCPSyn | TCPAck
|
||||
)
|
||||
|
||||
type QDecode struct {
|
||||
b []byte // Packet buffer that this decodes
|
||||
subofs int // byte offset of IP subprotocol
|
||||
|
||||
IPProto IPProto // IP subprotocol (UDP, TCP, etc)
|
||||
SrcIP IP // IP source address
|
||||
DstIP IP // IP destination address
|
||||
SrcPort uint16 // TCP/UDP source port
|
||||
DstPort uint16 // TCP/UDP destination port
|
||||
TCPFlags uint8 // TCP flags (SYN, ACK, etc)
|
||||
}
|
||||
|
||||
func (q QDecode) String() string {
|
||||
if q.IPProto == Junk {
|
||||
return "Junk{}"
|
||||
}
|
||||
srcip := make([]byte, 4)
|
||||
dstip := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(srcip, uint32(q.SrcIP))
|
||||
binary.BigEndian.PutUint32(dstip, uint32(q.DstIP))
|
||||
return fmt.Sprintf("%v{%d.%d.%d.%d:%d > %d.%d.%d.%d:%d}",
|
||||
q.IPProto,
|
||||
srcip[0], srcip[1], srcip[2], srcip[3], q.SrcPort,
|
||||
dstip[0], dstip[1], dstip[2], dstip[3], q.DstPort)
|
||||
}
|
||||
|
||||
// based on https://tools.ietf.org/html/rfc1071
|
||||
func ipChecksum(b []byte) uint16 {
|
||||
var ac uint32
|
||||
i := 0
|
||||
n := len(b)
|
||||
for n >= 2 {
|
||||
ac += uint32(binary.BigEndian.Uint16(b[i : i+2]))
|
||||
n -= 2
|
||||
i += 2
|
||||
}
|
||||
if n == 1 {
|
||||
ac += uint32(b[i]) << 8
|
||||
}
|
||||
for (ac >> 16) > 0 {
|
||||
ac = (ac >> 16) + (ac & 0xffff)
|
||||
}
|
||||
return uint16(^ac)
|
||||
}
|
||||
|
||||
func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType uint8, icmpCode uint8, payload []byte) []byte {
|
||||
if len(payload) < 4 {
|
||||
return nil
|
||||
}
|
||||
if len(payload) > 65535-24 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sz := 24 + len(payload)
|
||||
out := make([]byte, 24+len(payload))
|
||||
out[0] = 0x45 // IPv4, 20-byte header
|
||||
out[1] = 0x00 // DHCP, ECN
|
||||
binary.BigEndian.PutUint16(out[2:4], uint16(sz))
|
||||
binary.BigEndian.PutUint16(out[4:6], ipid)
|
||||
binary.BigEndian.PutUint16(out[6:8], 0) // flags, offset
|
||||
out[8] = 64 // TTL
|
||||
out[9] = 0x01 // ICMPv4
|
||||
// out[10:12] = 0x00 // blank IP header checksum
|
||||
binary.BigEndian.PutUint32(out[12:16], uint32(srcIP))
|
||||
binary.BigEndian.PutUint32(out[16:20], uint32(dstIP))
|
||||
|
||||
out[20] = icmpType
|
||||
out[21] = icmpCode
|
||||
//out[22:24] = 0x00 // blank ICMP checksum
|
||||
copy(out[24:len(out)], payload)
|
||||
|
||||
binary.BigEndian.PutUint16(out[10:12], ipChecksum(out[0:20]))
|
||||
binary.BigEndian.PutUint16(out[22:24], ipChecksum(out))
|
||||
return out
|
||||
}
|
||||
|
||||
// An extremely simple packet decoder for basic IPv4 packet types.
|
||||
// It extracts only the subprotocol id, IP addresses, and (if any) ports,
|
||||
// and shouldn't need any memory allocation.
|
||||
func (q *QDecode) Decode(b []byte) {
|
||||
q.b = nil
|
||||
|
||||
if len(b) < 20 {
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
// Check that it's IPv4.
|
||||
// TODO(apenwarr): consider IPv6 support
|
||||
if ((b[0] & 0xF0) >> 4) != 4 {
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
|
||||
n := int(binary.BigEndian.Uint16(b[2:4]))
|
||||
if len(b) < n {
|
||||
// Packet was cut off before full IPv4 length.
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
|
||||
// If it's valid IPv4, then the IP addresses are valid
|
||||
q.SrcIP = IP(binary.BigEndian.Uint32(b[12:16]))
|
||||
q.DstIP = IP(binary.BigEndian.Uint32(b[16:20]))
|
||||
|
||||
q.subofs = int((b[0] & 0x0F) * 4)
|
||||
sub := b[q.subofs:]
|
||||
|
||||
// We don't care much about IP fragmentation, except insofar as it's
|
||||
// used for firewall bypass attacks. The trick is make the first
|
||||
// fragment of a TCP or UDP packet so short that it doesn't fit
|
||||
// the TCP or UDP header, so we can't read the port, in hope that
|
||||
// it'll sneak past. Then subsequent fragments fill it in, but we're
|
||||
// missing the first part of the header, so we can't read that either.
|
||||
//
|
||||
// A "perfectly correct" implementation would have to reassemble
|
||||
// fragments before deciding what to do. But the truth is there's
|
||||
// zero reason to send such a short first fragment, so we can treat
|
||||
// it as Junk. We can also treat any subsequent fragment that starts
|
||||
// at such a low offset as Junk.
|
||||
fragFlags := binary.BigEndian.Uint16(b[6:8])
|
||||
moreFrags := (fragFlags & 0x20) != 0
|
||||
fragOfs := fragFlags & 0x1FFF
|
||||
if fragOfs == 0 {
|
||||
// This is the first fragment
|
||||
if moreFrags && len(sub) < MIN_FRAG {
|
||||
// Suspiciously short first fragment, dump it.
|
||||
log.Printf("junk1!\n")
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
// otherwise, this is either non-fragmented (the usual case)
|
||||
// or a big enough initial fragment that we can read the
|
||||
// whole subprotocol header.
|
||||
proto := b[9]
|
||||
switch proto {
|
||||
case 1: // ICMPv4
|
||||
if len(sub) < 8 {
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
q.IPProto = ICMP
|
||||
q.SrcPort = 0
|
||||
q.DstPort = 0
|
||||
q.b = b
|
||||
return
|
||||
case 6: // TCP
|
||||
if len(sub) < 20 {
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
q.IPProto = TCP
|
||||
q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
|
||||
q.DstPort = binary.BigEndian.Uint16(sub[2:4])
|
||||
q.TCPFlags = sub[13] & 0x3F
|
||||
q.b = b
|
||||
return
|
||||
case 17: // UDP
|
||||
if len(sub) < 8 {
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
q.IPProto = UDP
|
||||
q.SrcPort = binary.BigEndian.Uint16(sub[0:2])
|
||||
q.DstPort = binary.BigEndian.Uint16(sub[2:4])
|
||||
q.b = b
|
||||
return
|
||||
default:
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// This is a fragment other than the first one.
|
||||
if fragOfs < MIN_FRAG {
|
||||
// First frag was suspiciously short, so we can't
|
||||
// trust the followup either.
|
||||
q.IPProto = Junk
|
||||
return
|
||||
}
|
||||
// otherwise, we have to permit the fragment to slide through.
|
||||
// Second and later fragments don't have sub-headers.
|
||||
// Ideally, we would drop fragments that we can't identify,
|
||||
// but that would require statefulness. Anyway, receivers'
|
||||
// kernels know to drop fragments where the initial fragment
|
||||
// doesn't arrive.
|
||||
q.IPProto = Fragment
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a subset of the IP subprotocol section.
|
||||
func (q *QDecode) Sub(begin, n int) []byte {
|
||||
return q.b[q.subofs+begin : q.subofs+begin+n]
|
||||
}
|
||||
|
||||
// For a packet that is known to be IPv4, trim the buffer to its IPv4 length.
|
||||
// Sometimes packets arrive from an interface with extra bytes on the end.
|
||||
// This removes them.
|
||||
func (q *QDecode) Trim() []byte {
|
||||
n := binary.BigEndian.Uint16(q.b[2:4])
|
||||
return q.b[0:n]
|
||||
}
|
||||
|
||||
// For a decoded TCP packet, return true if it's a TCP SYN packet (ie. the
|
||||
// first packet in a new connection).
|
||||
func (q *QDecode) IsTCPSyn() bool {
|
||||
const Syn = 0x02
|
||||
const Ack = 0x10
|
||||
const SynAck = Syn | Ack
|
||||
return (q.TCPFlags & SynAck) == Syn
|
||||
}
|
||||
|
||||
// For a packet that has already been decoded, check if it's an IPv4 ICMP
|
||||
// Echo Request.
|
||||
func (q *QDecode) IsEchoRequest() bool {
|
||||
if q.IPProto == ICMP && len(q.b) >= q.subofs+8 {
|
||||
return q.b[q.subofs] == EchoRequest && q.b[q.subofs+1] == 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (q *QDecode) EchoRespond() []byte {
|
||||
icmpid := binary.BigEndian.Uint16(q.Sub(4, 2))
|
||||
b := q.Trim()
|
||||
return GenICMP(q.DstIP, q.SrcIP, icmpid, EchoReply, 0, b[q.subofs+4:])
|
||||
}
|
||||
|
||||
func Hexdump(b []byte) string {
|
||||
out := new(strings.Builder)
|
||||
for i := 0; i < len(b); i += 16 {
|
||||
if i > 0 {
|
||||
fmt.Fprintf(out, "\n")
|
||||
}
|
||||
fmt.Fprintf(out, " %04x ", i)
|
||||
j := 0
|
||||
for ; j < 16 && i+j < len(b); j++ {
|
||||
if j == 8 {
|
||||
fmt.Fprintf(out, " ")
|
||||
}
|
||||
fmt.Fprintf(out, "%02x ", b[i+j])
|
||||
}
|
||||
for ; j < 16; j++ {
|
||||
if j == 8 {
|
||||
fmt.Fprintf(out, " ")
|
||||
}
|
||||
fmt.Fprintf(out, " ")
|
||||
}
|
||||
fmt.Fprintf(out, " ")
|
||||
for j = 0; j < 16 && i+j < len(b); j++ {
|
||||
if b[i+j] >= 32 && b[i+j] < 128 {
|
||||
fmt.Fprintf(out, "%c", b[i+j])
|
||||
} else {
|
||||
fmt.Fprintf(out, ".")
|
||||
}
|
||||
}
|
||||
}
|
||||
return out.String()
|
||||
}
|
36
wgengine/router_darwin.go
Normal file
36
wgengine/router_darwin.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"tailscale.com/logger"
|
||||
)
|
||||
|
||||
type darwinRouter struct {
|
||||
tunname string
|
||||
}
|
||||
|
||||
func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
|
||||
r := darwinRouter{
|
||||
tunname: tunname,
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *darwinRouter) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *darwinRouter) SetRoutes(rs RouteSettings) error {
|
||||
if SetRoutesFunc != nil {
|
||||
return SetRoutesFunc(rs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *darwinRouter) Close() {
|
||||
}
|
17
wgengine/router_default.go
Normal file
17
wgengine/router_default.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !windows,!linux,!darwin
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"tailscale.com/logger"
|
||||
)
|
||||
|
||||
func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
|
||||
return NewFakeRouter(logf, tunname, dev, tuntap)
|
||||
}
|
38
wgengine/router_fake.go
Normal file
38
wgengine/router_fake.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"tailscale.com/logger"
|
||||
)
|
||||
|
||||
type fakeRouter struct {
|
||||
tunname string
|
||||
logf logger.Logf
|
||||
}
|
||||
|
||||
func NewFakeRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
|
||||
r := fakeRouter{
|
||||
logf: logf,
|
||||
tunname: tunname,
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *fakeRouter) Up() error {
|
||||
r.logf("Warning: fakeRouter.Up: not implemented.\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRouter) SetRoutes(rs RouteSettings) error {
|
||||
r.logf("Warning: fakeRouter.SetRoutes: not implemented.\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *fakeRouter) Close() {
|
||||
r.logf("Warning: fakeRouter.Close: not implemented.\n")
|
||||
}
|
267
wgengine/router_linux.go
Normal file
267
wgengine/router_linux.go
Normal file
@@ -0,0 +1,267 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"tailscale.com/atomicfile"
|
||||
"tailscale.com/logger"
|
||||
"tailscale.com/wgengine/rtnlmon"
|
||||
)
|
||||
|
||||
type linuxRouter struct {
|
||||
logf func(fmt string, args ...interface{})
|
||||
tunname string
|
||||
mon *rtnlmon.Mon
|
||||
netChanged func()
|
||||
local wgcfg.CIDR
|
||||
routes map[wgcfg.CIDR]struct{}
|
||||
}
|
||||
|
||||
func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
|
||||
mon, err := rtnlmon.New(logf, netChanged)
|
||||
if err != nil {
|
||||
log.Fatalf("rtnlmon.New() failed: %v", err)
|
||||
}
|
||||
|
||||
r := linuxRouter{
|
||||
logf: logf,
|
||||
tunname: tunname,
|
||||
mon: mon,
|
||||
netChanged: netChanged,
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func cmd(args ...string) *exec.Cmd {
|
||||
if len(args) == 0 {
|
||||
log.Fatalf("exec.Cmd(%#v) invalid; need argv[0]\n", args)
|
||||
}
|
||||
return exec.Command(args[0], args[1:]...)
|
||||
}
|
||||
|
||||
func (r *linuxRouter) Up() error {
|
||||
out, err := cmd("ip", "link", "set", r.tunname, "up").CombinedOutput()
|
||||
if err != nil {
|
||||
log.Fatalf("running ip link failed: %v\n%s", err, out)
|
||||
}
|
||||
|
||||
// TODO(apenwarr): This never cleans up after itself!
|
||||
out, err = cmd("iptables",
|
||||
"-A", "FORWARD",
|
||||
"-i", r.tunname,
|
||||
"-j", "ACCEPT").CombinedOutput()
|
||||
if err != nil {
|
||||
r.logf("iptables forward failed: %v\n%s", err, out)
|
||||
}
|
||||
// TODO(apenwarr): hardcoded eth0 interface is obviously not right.
|
||||
out, err = cmd("iptables",
|
||||
"-t", "nat",
|
||||
"-A", "POSTROUTING",
|
||||
"-o", "eth0",
|
||||
"-j", "MASQUERADE").CombinedOutput()
|
||||
if err != nil {
|
||||
r.logf("iptables nat failed: %v\n%s", err, out)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *linuxRouter) SetRoutes(rs RouteSettings) error {
|
||||
var errq error
|
||||
|
||||
if rs.LocalAddr != r.local {
|
||||
if r.local != (wgcfg.CIDR{}) {
|
||||
addrdel := []string{"ip", "addr",
|
||||
"del", r.local.String(),
|
||||
"dev", r.tunname}
|
||||
out, err := cmd(addrdel...).CombinedOutput()
|
||||
if err != nil {
|
||||
r.logf("addr del failed: %v: %v\n%s", addrdel, err, out)
|
||||
if errq == nil {
|
||||
errq = err
|
||||
}
|
||||
}
|
||||
}
|
||||
addradd := []string{"ip", "addr",
|
||||
"add", rs.LocalAddr.String(),
|
||||
"dev", r.tunname}
|
||||
out, err := cmd(addradd...).CombinedOutput()
|
||||
if err != nil {
|
||||
r.logf("addr add failed: %v: %v\n%s", addradd, err, out)
|
||||
if errq == nil {
|
||||
errq = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newRoutes := make(map[wgcfg.CIDR]struct{})
|
||||
for _, peer := range rs.Cfg.Peers {
|
||||
for _, route := range peer.AllowedIPs {
|
||||
newRoutes[route] = struct{}{}
|
||||
}
|
||||
}
|
||||
for route := range r.routes {
|
||||
if _, keep := newRoutes[route]; !keep {
|
||||
net := route.IPNet()
|
||||
nip := net.IP.Mask(net.Mask)
|
||||
nstr := fmt.Sprintf("%v/%d", nip, route.Mask)
|
||||
addrdel := []string{"ip", "route",
|
||||
"del", nstr,
|
||||
"via", r.local.IP.String(),
|
||||
"dev", r.tunname}
|
||||
out, err := cmd(addrdel...).CombinedOutput()
|
||||
if err != nil {
|
||||
r.logf("addr del failed: %v: %v\n%s", addrdel, err, out)
|
||||
if errq == nil {
|
||||
errq = err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for route := range newRoutes {
|
||||
if _, exists := r.routes[route]; !exists {
|
||||
net := route.IPNet()
|
||||
nip := net.IP.Mask(net.Mask)
|
||||
nstr := fmt.Sprintf("%v/%d", nip, route.Mask)
|
||||
addradd := []string{"ip", "route",
|
||||
"add", nstr,
|
||||
"via", rs.LocalAddr.IP.String(),
|
||||
"dev", r.tunname}
|
||||
out, err := cmd(addradd...).CombinedOutput()
|
||||
if err != nil {
|
||||
r.logf("addr add failed: %v: %v\n%s", addradd, err, out)
|
||||
if errq == nil {
|
||||
errq = err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.local = rs.LocalAddr
|
||||
r.routes = newRoutes
|
||||
|
||||
if false {
|
||||
if err := r.replaceResolvConf(rs.DNS, rs.DNSDomains); err != nil {
|
||||
errq = fmt.Errorf("replacing resolv.conf failed: %v", err)
|
||||
}
|
||||
}
|
||||
return errq
|
||||
}
|
||||
|
||||
func (r *linuxRouter) Close() {
|
||||
r.mon.Close()
|
||||
if err := r.restoreResolvConf(); err != nil {
|
||||
r.logf("failed to restore system resolv.conf: %v", err)
|
||||
}
|
||||
// TODO(apenwarr): clean up iptables etc.
|
||||
}
|
||||
|
||||
const (
|
||||
tsConf = "/etc/resolv.tailscale.conf"
|
||||
backupConf = "/etc/resolv.pre-tailscale-backup.conf"
|
||||
resolvConf = "/etc/resolv.conf"
|
||||
)
|
||||
|
||||
func (r *linuxRouter) replaceResolvConf(servers []net.IP, domains []string) error {
|
||||
if len(servers) == 0 {
|
||||
return r.restoreResolvConf()
|
||||
}
|
||||
|
||||
// First write the tsConf file.
|
||||
buf := new(bytes.Buffer)
|
||||
fmt.Fprintf(buf, "# resolv.conf(5) file generated by tailscale\n")
|
||||
fmt.Fprintf(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n")
|
||||
for _, ns := range servers {
|
||||
fmt.Fprintf(buf, "nameserver %s\n", ns)
|
||||
}
|
||||
if len(domains) > 0 {
|
||||
fmt.Fprintf(buf, "search "+strings.Join(domains, " ")+"\n")
|
||||
}
|
||||
f, err := ioutil.TempFile(filepath.Dir(tsConf), filepath.Base(tsConf)+".*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Close()
|
||||
if err := atomicfile.WriteFile(f.Name(), buf.Bytes(), 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
os.Chmod(f.Name(), 0644) // ioutil.TempFile creates the file with 0600
|
||||
if err := os.Rename(f.Name(), tsConf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if linkPath, err := os.Readlink(resolvConf); err != nil {
|
||||
// Remove any old backup that may exist.
|
||||
os.Remove(backupConf)
|
||||
|
||||
// Backup the existing /etc/resolv.conf file.
|
||||
contents, err := ioutil.ReadFile(resolvConf)
|
||||
if os.IsNotExist(err) {
|
||||
// No existing /etc/resolve.conf file to backup.
|
||||
// Nothing to do.
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := atomicfile.WriteFile(backupConf, contents, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if linkPath != tsConf {
|
||||
// Backup the existing symlink.
|
||||
os.Remove(backupConf)
|
||||
if err := os.Symlink(linkPath, backupConf); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Nothing to do, resolvConf already points to tsConf.
|
||||
return nil
|
||||
}
|
||||
|
||||
os.Remove(resolvConf)
|
||||
if err := os.Symlink(tsConf, resolvConf); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
out, _ := exec.Command("service", "systemd-resolved", "restart").CombinedOutput()
|
||||
if len(out) > 0 {
|
||||
r.logf("service systemd-resolved restart: %s", out)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *linuxRouter) restoreResolvConf() error {
|
||||
if _, err := os.Stat(backupConf); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // no backup resolve.conf to restore
|
||||
}
|
||||
return err
|
||||
}
|
||||
if ln, err := os.Readlink(resolvConf); err != nil {
|
||||
return err
|
||||
} else if ln != tsConf {
|
||||
return fmt.Errorf("resolve.conf is not a symlink to %s", tsConf)
|
||||
}
|
||||
if err := os.Rename(backupConf, resolvConf); err != nil {
|
||||
return err
|
||||
}
|
||||
os.Remove(tsConf) // best effort removal of tsConf file
|
||||
out, _ := exec.Command("service", "systemd-resolved", "restart").CombinedOutput()
|
||||
if len(out) > 0 {
|
||||
r.logf("service systemd-resolved restart: %s", out)
|
||||
}
|
||||
return nil
|
||||
}
|
58
wgengine/router_windows.go
Normal file
58
wgengine/router_windows.go
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"golang.zx2c4.com/winipcfg"
|
||||
"tailscale.com/logger"
|
||||
)
|
||||
|
||||
type winRouter struct {
|
||||
logf func(fmt string, args ...interface{})
|
||||
tunname string
|
||||
dev *device.Device
|
||||
nativeTun *tun.NativeTun
|
||||
routeChangeCallback *winipcfg.RouteChangeCallback
|
||||
}
|
||||
|
||||
func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
|
||||
r := winRouter{
|
||||
logf: logf,
|
||||
tunname: tunname,
|
||||
dev: dev,
|
||||
nativeTun: tuntap.(*tun.NativeTun),
|
||||
}
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *winRouter) Up() error {
|
||||
// MonitorDefaultRoutes handles making sure our wireguard UDP
|
||||
// traffic goes through the old route, not recursively through the VPN.
|
||||
var err error
|
||||
r.routeChangeCallback, err = MonitorDefaultRoutes(r.dev, true, r.nativeTun)
|
||||
if err != nil {
|
||||
log.Fatalf("MonitorDefaultRoutes: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *winRouter) SetRoutes(rs RouteSettings) error {
|
||||
err := ConfigureInterface(&rs.Cfg, r.nativeTun, rs.DNS, rs.DNSDomains)
|
||||
if err != nil {
|
||||
r.logf("ConfigureInterface: %v\n", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *winRouter) Close() {
|
||||
if r.routeChangeCallback != nil {
|
||||
r.routeChangeCallback.Unregister()
|
||||
}
|
||||
}
|
114
wgengine/rtnlmon/mon.go
Normal file
114
wgengine/rtnlmon/mon.go
Normal file
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package rtnlmon watches for "interesting" changes to the network
|
||||
// stack and fires a callback.
|
||||
package rtnlmon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mdlayher/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"tailscale.com/logger"
|
||||
)
|
||||
|
||||
// Netlink is not a great protocol for *knowing* things. The protocol
|
||||
// design makes it impossible to track changes precisely. You can see
|
||||
// this by looking at things like Quagga or Bird, which all include
|
||||
// keeping a local impression of what they think is in the kernel, and
|
||||
// periodically doing a full state dump to find errors. They do use
|
||||
// events, but explicitly only as an optimization, because they can't
|
||||
// be trusted.
|
||||
//
|
||||
// Fortunately, we don't really need to know what exactly changed. We
|
||||
// just want to know that network conditions may have changed, and we
|
||||
// should re-explore connectivity. This is why we subscribe to events,
|
||||
// and then blindly fire our callback without looking at the content
|
||||
// of the notifications.
|
||||
|
||||
type ChangeFunc func()
|
||||
|
||||
type Mon struct {
|
||||
logf logger.Logf
|
||||
cb ChangeFunc
|
||||
nl *netlink.Conn
|
||||
change chan struct{}
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
func New(logf logger.Logf, callback ChangeFunc) (*Mon, error) {
|
||||
conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
|
||||
// IPv4 address and route changes. Routes get us most of the
|
||||
// events of interest, but we need address as well to cover
|
||||
// things like DHCP deciding to give us a new address upon
|
||||
// renewal - routing wouldn't change, but all reachability
|
||||
// would.
|
||||
//
|
||||
// Why magic numbers? These aren't exposed in x/sys/unix
|
||||
// yet. The values come from rtnetlink.h, RTMGRP_IPV4_IFADDR
|
||||
// and RTMGRP_IPV4_ROUTE.
|
||||
Groups: 0x10 | 0x40,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing netlink socket: %v", err)
|
||||
}
|
||||
|
||||
ret := &Mon{
|
||||
logf: logf,
|
||||
cb: callback,
|
||||
nl: conn,
|
||||
change: make(chan struct{}, 1),
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
go ret.pump()
|
||||
go ret.debounce()
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (m *Mon) Close() error {
|
||||
close(m.stop)
|
||||
return m.nl.Close()
|
||||
}
|
||||
|
||||
func (m *Mon) pump() {
|
||||
for {
|
||||
_, err := m.nl.Receive()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-m.stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
// Keep retrying while we're not closed.
|
||||
m.logf("Error receiving from netlink: %v", err)
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case m.change <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mon) debounce() {
|
||||
for {
|
||||
select {
|
||||
case <-m.stop:
|
||||
return
|
||||
case <-m.change:
|
||||
}
|
||||
|
||||
m.cb()
|
||||
|
||||
select {
|
||||
case <-m.stop:
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
21
wgengine/rusage.go
Normal file
21
wgengine/rusage.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func RusagePrefixLog(logf func(f string, argv ...interface{})) func(f string, argv ...interface{}) {
|
||||
return func(f string, argv ...interface{}) {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20)
|
||||
maxRSS := rusageMaxRSS()
|
||||
pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f)
|
||||
logf(pf, argv...)
|
||||
}
|
||||
}
|
29
wgengine/rusage_nowindows.go
Normal file
29
wgengine/rusage_nowindows.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !windows
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func rusageMaxRSS() float64 {
|
||||
var ru syscall.Rusage
|
||||
err := syscall.Getrusage(syscall.RUSAGE_SELF, &ru)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
rss := float64(ru.Maxrss)
|
||||
if runtime.GOOS == "darwin" {
|
||||
rss /= 1 << 20 // ru_maxrss is bytes on darwin
|
||||
} else {
|
||||
// ru_maxrss is kilobytes elsewhere (linux, openbsd, etc)
|
||||
rss /= 1024
|
||||
}
|
||||
return rss
|
||||
}
|
10
wgengine/rusage_windows.go
Normal file
10
wgengine/rusage_windows.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
func rusageMaxRSS() float64 {
|
||||
// TODO(apenwarr): Substitute Windows equivalent of Getrusage() here.
|
||||
return 0
|
||||
}
|
477
wgengine/userspace.go
Normal file
477
wgengine/userspace.go
Normal file
@@ -0,0 +1,477 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"tailscale.com/logger"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/magicsock"
|
||||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
type userspaceEngine struct {
|
||||
logf logger.Logf
|
||||
statusCallback StatusCallback
|
||||
reqCh chan struct{}
|
||||
waitCh chan struct{}
|
||||
tuntap tun.Device
|
||||
wgdev *device.Device
|
||||
router Router
|
||||
magicConn *magicsock.Conn
|
||||
|
||||
wgLock sync.Mutex // serializes all wgdev operations
|
||||
lastReconfig string
|
||||
lastRoutes string
|
||||
|
||||
mu sync.Mutex
|
||||
peerSequence []wgcfg.Key
|
||||
endpoints []string
|
||||
}
|
||||
|
||||
type Loggify struct {
|
||||
f logger.Logf
|
||||
}
|
||||
|
||||
func (l *Loggify) Write(b []byte) (int, error) {
|
||||
l.f(string(b))
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16, derp bool) (Engine, error) {
|
||||
logf("Starting userspace wireguard engine (FAKE tuntap device).")
|
||||
tun := NewFakeTun()
|
||||
return NewUserspaceEngineAdvanced(logf, tun, NewFakeRouter, listenPort, derp)
|
||||
}
|
||||
|
||||
func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16, derp bool) (Engine, error) {
|
||||
logf("Starting userspace wireguard engine.")
|
||||
logf("external packet routing via --tun=%s enabled", tunname)
|
||||
|
||||
if tunname == "" {
|
||||
return nil, fmt.Errorf("--tun name must not be blank.")
|
||||
}
|
||||
|
||||
tuntap, err := tun.CreateTUN(tunname, device.DefaultMTU)
|
||||
if err != nil {
|
||||
log.Printf("CreateTUN: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("CreateTUN ok.\n")
|
||||
|
||||
e, err := NewUserspaceEngineAdvanced(logf, tuntap, NewUserspaceRouter, listenPort, derp)
|
||||
if err != nil {
|
||||
log.Printf("NewUserspaceEngineAdv: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
return e, err
|
||||
}
|
||||
|
||||
type RouterGen func(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netStateChanged func()) Router
|
||||
|
||||
func NewUserspaceEngineAdvanced(logf logger.Logf, tuntap tun.Device, routerGen RouterGen, listenPort uint16, derp bool) (Engine, error) {
|
||||
e := &userspaceEngine{
|
||||
logf: logf,
|
||||
reqCh: make(chan struct{}, 1),
|
||||
waitCh: make(chan struct{}),
|
||||
tuntap: tuntap,
|
||||
}
|
||||
|
||||
tunname, err := tuntap.Name()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endpointsFn := func(endpoints []string) {
|
||||
e.mu.Lock()
|
||||
if e.endpoints != nil {
|
||||
e.endpoints = e.endpoints[:0]
|
||||
}
|
||||
e.endpoints = append(e.endpoints, endpoints...)
|
||||
e.mu.Unlock()
|
||||
|
||||
e.RequestStatus()
|
||||
}
|
||||
magicsockOpts := magicsock.Options{
|
||||
Port: listenPort,
|
||||
STUN: magicsock.DefaultSTUN,
|
||||
// TODO(crawshaw): DERP: magicsock.DefaultDERP,
|
||||
EndpointsFunc: endpointsFn,
|
||||
}
|
||||
if derp {
|
||||
magicsockOpts.DERP = magicsock.DefaultDERP
|
||||
}
|
||||
e.magicConn, err = magicsock.Listen(magicsockOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("wgengine: %v", err)
|
||||
}
|
||||
|
||||
// flags==0 because logf is already nested in another logger.
|
||||
// The outer one can display the preferred log prefixes, etc.
|
||||
dlog := log.New(&Loggify{logf}, "", 0)
|
||||
logger := device.Logger{
|
||||
Debug: dlog,
|
||||
Info: dlog,
|
||||
Error: dlog,
|
||||
}
|
||||
nofilter := func(b []byte) device.FilterResult {
|
||||
// for safety, default to dropping all packets
|
||||
logf("Warning: you forgot to use wgengine.SetFilterInOut()! Packet dropped.\n")
|
||||
return device.FilterDrop
|
||||
}
|
||||
|
||||
opts := &device.DeviceOptions{
|
||||
Logger: &logger,
|
||||
FilterIn: nofilter,
|
||||
FilterOut: nofilter,
|
||||
HandshakeDone: func() {
|
||||
// Send an unsolicited status event every time a
|
||||
// handshake completes. This makes sure our UI can
|
||||
// update quickly as soon as it connects to a peer.
|
||||
//
|
||||
// We use a goroutine here to avoid deadlocking
|
||||
// wireguard, since RequestStatus() will call back
|
||||
// into it, and wireguard is what called us to get
|
||||
// here.
|
||||
go e.RequestStatus()
|
||||
},
|
||||
CreateBind: func(uint16) (device.Bind, uint16, error) {
|
||||
return e.magicConn, e.magicConn.LocalPort(), nil
|
||||
},
|
||||
CreateEndpoint: e.magicConn.CreateEndpoint,
|
||||
SkipBindUpdate: true,
|
||||
}
|
||||
|
||||
e.wgdev = device.NewDevice(e.tuntap, opts)
|
||||
|
||||
go func() {
|
||||
up := false
|
||||
for event := range e.tuntap.Events() {
|
||||
if event&tun.EventMTUUpdate != 0 {
|
||||
mtu, err := e.tuntap.MTU()
|
||||
e.logf("external route MTU: %d (%v)", mtu, err)
|
||||
}
|
||||
if event&tun.EventUp != 0 && !up {
|
||||
e.logf("external route: up")
|
||||
e.RequestStatus()
|
||||
up = true
|
||||
}
|
||||
if event&tun.EventDown != 0 && up {
|
||||
e.logf("external route: down")
|
||||
e.RequestStatus()
|
||||
up = false
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
e.router = routerGen(logf, tunname, e.wgdev, e.tuntap, func() { e.LinkChange(false) })
|
||||
e.wgdev.Up()
|
||||
if err := e.router.Up(); err != nil {
|
||||
e.wgdev.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := e.router.SetRoutes(RouteSettings{}); err != nil {
|
||||
e.wgdev.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// TODO(apenwarr): dnsDomains really ought to be in wgcfg.Config.
|
||||
// However, we don't actually ever provide it to wireguard and it's not in
|
||||
// the traditional wireguard config format. On the other hand, wireguard
|
||||
// itself doesn't use the traditional 'dns =' setting either.
|
||||
func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, dnsDomains []string) error {
|
||||
e.logf("Reconfig(): configuring userspace wireguard engine.\n")
|
||||
e.wgLock.Lock()
|
||||
defer e.wgLock.Unlock()
|
||||
|
||||
e.peerSequence = make([]wgcfg.Key, len(cfg.Peers))
|
||||
for i, p := range cfg.Peers {
|
||||
e.peerSequence[i] = p.PublicKey
|
||||
}
|
||||
|
||||
// TODO(apenwarr): get rid of silly uapi stuff for in-process comms
|
||||
uapi, err := cfg.ToUAPI()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rc := uapi + "\x00" + strings.Join(dnsDomains, "\x00")
|
||||
if rc == e.lastReconfig {
|
||||
e.logf("...unchanged config, skipping.\n")
|
||||
return nil
|
||||
}
|
||||
e.lastReconfig = rc
|
||||
|
||||
r := bufio.NewReader(strings.NewReader(uapi))
|
||||
if err = e.wgdev.IpcSetOperation(r); err != nil {
|
||||
e.logf("IpcSetOperation: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := e.magicConn.SetPrivateKey(cfg.Interface.PrivateKey); err != nil {
|
||||
e.logf("magicsock: %v\n", err)
|
||||
}
|
||||
|
||||
// TODO(apenwarr): only handling the first local address.
|
||||
// Currently we never use more than one anyway.
|
||||
var cidr wgcfg.CIDR
|
||||
if len(cfg.Interface.Addresses) > 0 {
|
||||
cidr = cfg.Interface.Addresses[0]
|
||||
// TODO(apenwarr): this shouldn't be hardcoded in the client
|
||||
cidr.Mask = 10 // route the whole cgnat range
|
||||
}
|
||||
|
||||
rs := RouteSettings{
|
||||
LocalAddr: cidr,
|
||||
Cfg: *cfg,
|
||||
DNS: cfg.Interface.Dns,
|
||||
DNSDomains: dnsDomains,
|
||||
}
|
||||
e.logf("Reconfiguring router. la=%v dns=%v dom=%v\n",
|
||||
rs.LocalAddr, rs.DNS, rs.DNSDomains)
|
||||
|
||||
// TODO(apenwarr): all the parts of RouteSettings should be "relevant."
|
||||
// We're checking only the "relevant" parts to see if they have
|
||||
// changed, and if not, skipping SetRoutes(). But if SetRoutes()
|
||||
// is getting the non-relevant parts of Cfg, it might act on them,
|
||||
// and this optimization is unsafe. Probably we should not pass
|
||||
// a whole Cfg object as part of RouteSettings; instead, trim it to
|
||||
// just what's absolutely needed (the set of actual routes).
|
||||
rss := rs.OnlyRelevantParts()
|
||||
e.logf("New routes: %v\n", rss)
|
||||
if rss == e.lastRoutes {
|
||||
e.logf("...unchanged routes, skipping.\n")
|
||||
return nil
|
||||
}
|
||||
e.lastRoutes = rss
|
||||
err = e.router.SetRoutes(rs)
|
||||
e.logf("Reconfig() done.\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) SetFilter(filt *filter.Filter) {
|
||||
var filtin, filtout func(b []byte) device.FilterResult
|
||||
if filt == nil {
|
||||
e.logf("wgengine: nil filter provided; no access restrictions.\n")
|
||||
} else {
|
||||
ft, ft_ok := e.tuntap.(*fakeTun)
|
||||
filtin = func(b []byte) device.FilterResult {
|
||||
runf := filter.LogDrops
|
||||
//runf |= filter.HexdumpDrops
|
||||
runf |= filter.LogAccepts
|
||||
//runf |= filter.HexdumpAccepts
|
||||
q := &packet.QDecode{}
|
||||
if filt.RunIn(b, q, runf) == filter.Accept {
|
||||
// Only in fake mode, answer any incoming pings
|
||||
if ft_ok && q.IsEchoRequest() {
|
||||
pb := q.EchoRespond()
|
||||
ft.InsertRead(pb)
|
||||
// We already handled it, stop.
|
||||
return device.FilterDrop
|
||||
}
|
||||
return device.FilterAccept
|
||||
}
|
||||
return device.FilterDrop
|
||||
}
|
||||
|
||||
filtout = func(b []byte) device.FilterResult {
|
||||
runf := filter.LogDrops
|
||||
//runf |= filter.HexdumpDrops
|
||||
runf |= filter.LogAccepts
|
||||
//runf |= filter.HexdumpAccepts
|
||||
q := &packet.QDecode{}
|
||||
if filt.RunOut(b, q, runf) == filter.Accept {
|
||||
return device.FilterAccept
|
||||
}
|
||||
return device.FilterDrop
|
||||
}
|
||||
}
|
||||
|
||||
e.wgLock.Lock()
|
||||
defer e.wgLock.Unlock()
|
||||
|
||||
e.wgdev.SetFilterInOut(filtin, filtout)
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
|
||||
e.statusCallback = cb
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) getStatus() (*Status, error) {
|
||||
e.wgLock.Lock()
|
||||
defer e.wgLock.Unlock()
|
||||
|
||||
if e.wgdev == nil {
|
||||
// RequestStatus was invoked before the wgengine has
|
||||
// finished initializing. This can happen when wgegine
|
||||
// provides a callback to magicsock for endpoint
|
||||
// updates that calls RequestStatus.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO(apenwarr): get rid of silly uapi stuff for in-process comms
|
||||
// FIXME: get notified of status changes instead of polling.
|
||||
var bb strings.Builder
|
||||
bio := bufio.NewWriter(&bb)
|
||||
ipcErr := e.wgdev.IpcGetOperation(bio)
|
||||
if ipcErr != nil {
|
||||
log.Fatalf("IpcGetOperation: %v\n", ipcErr)
|
||||
}
|
||||
bio.Flush()
|
||||
|
||||
s := Status{}
|
||||
pp := make(map[wgcfg.Key]*PeerStatus)
|
||||
var p *PeerStatus = &PeerStatus{}
|
||||
bbs := bb.String()
|
||||
lines := strings.Split(bbs, "\n")
|
||||
var hst1, hst2, n int64
|
||||
var err error
|
||||
for _, line := range lines {
|
||||
kv := strings.SplitN(line, "=", 2)
|
||||
var k, v string
|
||||
k = kv[0]
|
||||
if len(kv) > 1 {
|
||||
v = kv[1]
|
||||
}
|
||||
switch k {
|
||||
case "public_key":
|
||||
pk, err := wgcfg.ParseHexKey(v)
|
||||
if err != nil {
|
||||
log.Fatalf("IpcGetOperation: invalid key %#v\n", v)
|
||||
}
|
||||
p = &PeerStatus{}
|
||||
pp[*pk] = p
|
||||
|
||||
key := tailcfg.NodeKey(*pk)
|
||||
p.NodeKey = key
|
||||
case "rx_bytes":
|
||||
n, err = strconv.ParseInt(v, 10, 64)
|
||||
p.RxBytes = ByteCount(n)
|
||||
if err != nil {
|
||||
log.Fatalf("IpcGetOperation: rx_bytes invalid: %#v\n", line)
|
||||
}
|
||||
case "tx_bytes":
|
||||
n, err = strconv.ParseInt(v, 10, 64)
|
||||
p.TxBytes = ByteCount(n)
|
||||
if err != nil {
|
||||
log.Fatalf("IpcGetOperation: tx_bytes invalid: %#v\n", line)
|
||||
}
|
||||
case "last_handshake_time_sec":
|
||||
hst1, err = strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
log.Fatalf("IpcGetOperation: hst1 invalid: %#v\n", line)
|
||||
}
|
||||
case "last_handshake_time_nsec":
|
||||
hst2, err = strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
log.Fatalf("IpcGetOperation: hst2 invalid: %#v\n", line)
|
||||
}
|
||||
if hst1 != 0 || hst2 != 0 {
|
||||
p.LastHandshake = time.Unix(hst1, hst2)
|
||||
} // else leave at time.IsZero()
|
||||
}
|
||||
}
|
||||
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
var peers []PeerStatus
|
||||
for _, pk := range e.peerSequence {
|
||||
p := pp[pk]
|
||||
if p == nil {
|
||||
p = &PeerStatus{}
|
||||
}
|
||||
peers = append(peers, *p)
|
||||
}
|
||||
|
||||
if len(pp) != len(e.peerSequence) {
|
||||
e.logf("wg status returned %v peers, expected %v\n", len(s.Peers), len(e.peerSequence))
|
||||
}
|
||||
|
||||
return &Status{
|
||||
LocalAddrs: append([]string(nil), e.endpoints...),
|
||||
Peers: peers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) RequestStatus() {
|
||||
// This is slightly tricky. e.getStatus() can theoretically get
|
||||
// blocked inside wireguard for a while, and RequestStatus() is
|
||||
// sometimes called from a goroutine, so we don't want a lot of
|
||||
// them hanging around. On the other hand, requesting multiple
|
||||
// status updates simultaneously is pointless anyway; they will
|
||||
// all say the same thing.
|
||||
|
||||
// Enqueue at most one request. If one is in progress already, this
|
||||
// adds one more to the queue. If one has been requested but not
|
||||
// started, it is a no-op.
|
||||
select {
|
||||
case e.reqCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
// Dequeue at most one request. Another thread may have already
|
||||
// dequeued the request we enqueued above, which is fine, since the
|
||||
// information is guaranteed to be at least as recent as the current
|
||||
// call to RequestStatus().
|
||||
select {
|
||||
case <-e.reqCh:
|
||||
s, err := e.getStatus()
|
||||
if s == nil && err == nil {
|
||||
e.logf("RequestStatus: weird: both s and err are nil\n")
|
||||
return
|
||||
}
|
||||
if e.statusCallback != nil {
|
||||
e.statusCallback(s, err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) Close() {
|
||||
e.Reconfig(&wgcfg.Config{}, nil)
|
||||
e.router.Close()
|
||||
e.magicConn.Close()
|
||||
close(e.waitCh)
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) Wait() {
|
||||
<-e.waitCh
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) LinkChange(isExpensive bool) {
|
||||
e.logf("LinkChange(isExpensive=%v): rebinding socket", isExpensive)
|
||||
e.wgLock.Lock()
|
||||
defer e.wgLock.Unlock()
|
||||
|
||||
// TODO(crawshaw): use isExpensive=true to switch into "client mode" on macOS?
|
||||
e.magicConn.LinkChange()
|
||||
|
||||
// TODO(crawshaw): when we have an incremental notion of reconfig,
|
||||
// be gentler here. No need to smash in-progress connections,
|
||||
// we just need to handshake again.
|
||||
if e.lastReconfig == "" {
|
||||
return
|
||||
}
|
||||
uapi := e.lastReconfig[:strings.Index(e.lastReconfig, "\x00")]
|
||||
r := bufio.NewReader(strings.NewReader(uapi))
|
||||
if err := e.wgdev.IpcSetOperation(r); err != nil {
|
||||
e.logf("IpcSetOperation: %v\n", err)
|
||||
}
|
||||
}
|
83
wgengine/watchdog.go
Normal file
83
wgengine/watchdog.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"runtime/pprof"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
// NewWatchdog wraps an Engine and makes sure that all methods complete
|
||||
// within a reasonable amount of time.
|
||||
//
|
||||
// If they do not, the watchdog crashes the process.
|
||||
func NewWatchdog(e Engine) Engine {
|
||||
return &watchdogEngine{
|
||||
wrap: e,
|
||||
logf: log.Printf,
|
||||
fatalf: log.Fatalf,
|
||||
maxWait: 45 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
type watchdogEngine struct {
|
||||
wrap Engine
|
||||
logf func(format string, args ...interface{})
|
||||
fatalf func(format string, args ...interface{})
|
||||
maxWait time.Duration
|
||||
}
|
||||
|
||||
func (e *watchdogEngine) watchdogErr(name string, fn func() error) error {
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- fn()
|
||||
}()
|
||||
t := time.NewTimer(e.maxWait)
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Stop()
|
||||
return err
|
||||
case <-t.C:
|
||||
buf := new(bytes.Buffer)
|
||||
pprof.Lookup("goroutine").WriteTo(buf, 1)
|
||||
e.logf("wgengine watchdog stacks:\n%s", buf.String())
|
||||
e.fatalf("wgengine: watchdog timeout on %s", name)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *watchdogEngine) watchdog(name string, fn func()) {
|
||||
e.watchdogErr(name, func() error {
|
||||
fn()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, dnsDomains []string) error {
|
||||
return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, dnsDomains) })
|
||||
}
|
||||
func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
|
||||
e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
|
||||
}
|
||||
func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
|
||||
e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
|
||||
}
|
||||
func (e *watchdogEngine) RequestStatus() {
|
||||
e.watchdog("RequestStatus", func() { e.wrap.RequestStatus() })
|
||||
}
|
||||
func (e *watchdogEngine) LinkChange(isExpensive bool) {
|
||||
e.watchdog("LinkChange", func() { e.wrap.LinkChange(isExpensive) })
|
||||
}
|
||||
func (e *watchdogEngine) Close() {
|
||||
e.watchdog("Close", e.wrap.Close)
|
||||
}
|
||||
func (e *watchdogEngine) Wait() {
|
||||
e.wrap.Wait()
|
||||
}
|
71
wgengine/watchdog_test.go
Normal file
71
wgengine/watchdog_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWatchdog(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("default watchdog does not fire", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tun := NewFakeTun()
|
||||
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, NewFakeRouter, 0, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
e = NewWatchdog(e)
|
||||
e.(*watchdogEngine).maxWait = 150 * time.Millisecond
|
||||
|
||||
e.RequestStatus()
|
||||
e.RequestStatus()
|
||||
e.RequestStatus()
|
||||
e.Close()
|
||||
})
|
||||
|
||||
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tun := NewFakeTun()
|
||||
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, NewFakeRouter, 0, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
usEngine := e.(*userspaceEngine)
|
||||
e = NewWatchdog(e)
|
||||
wdEngine := e.(*watchdogEngine)
|
||||
wdEngine.maxWait = 100 * time.Millisecond
|
||||
|
||||
logBuf := new(bytes.Buffer)
|
||||
fatalCalled := make(chan struct{})
|
||||
wdEngine.logf = func(format string, args ...interface{}) {
|
||||
fmt.Fprintf(logBuf, format+"\n", args...)
|
||||
}
|
||||
wdEngine.fatalf = func(format string, args ...interface{}) {
|
||||
t.Logf("FATAL: %s", fmt.Sprintf(format, args...))
|
||||
fatalCalled <- struct{}{}
|
||||
}
|
||||
|
||||
usEngine.wgLock.Lock() // blocks getStatus so the watchdog will fire
|
||||
|
||||
go e.RequestStatus()
|
||||
|
||||
select {
|
||||
case <-fatalCalled:
|
||||
if !strings.Contains(logBuf.String(), "goroutine profile: total ") {
|
||||
t.Errorf("fatal called without watchdog stacks, got: %s", logBuf.String())
|
||||
}
|
||||
// expected
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatalf("watchdog failed to fire")
|
||||
}
|
||||
})
|
||||
}
|
79
wgengine/wgengine.go
Normal file
79
wgengine/wgengine.go
Normal file
@@ -0,0 +1,79 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgengine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/wgcfg"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
|
||||
type ByteCount int64
|
||||
|
||||
type PeerStatus struct {
|
||||
TxBytes, RxBytes ByteCount
|
||||
LastHandshake time.Time
|
||||
NodeKey tailcfg.NodeKey
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
Peers []PeerStatus
|
||||
LocalAddrs []string // TODO(crawshaw): []wgcfg.Endpoint?
|
||||
}
|
||||
|
||||
type StatusCallback func(s *Status, err error)
|
||||
|
||||
type RouteSettings struct {
|
||||
LocalAddr wgcfg.CIDR
|
||||
DNS []net.IP
|
||||
DNSDomains []string
|
||||
Cfg wgcfg.Config
|
||||
}
|
||||
|
||||
// Only used on darwin for now
|
||||
// TODO(apenwarr): This probably belongs in the darwinRouter struct.
|
||||
var SetRoutesFunc func(rs RouteSettings) error
|
||||
|
||||
func (rs *RouteSettings) OnlyRelevantParts() string {
|
||||
var peers [][]wgcfg.CIDR
|
||||
for _, p := range rs.Cfg.Peers {
|
||||
peers = append(peers, p.AllowedIPs)
|
||||
}
|
||||
return fmt.Sprintf("%v %v %v %v",
|
||||
rs.LocalAddr, rs.DNS, rs.DNSDomains, peers)
|
||||
}
|
||||
|
||||
type Router interface {
|
||||
Up() error
|
||||
SetRoutes(rs RouteSettings) error
|
||||
Close()
|
||||
}
|
||||
|
||||
type Engine interface {
|
||||
// Reconfigure wireguard and make sure it's running.
|
||||
// This also handles setting up any kernel routes.
|
||||
Reconfig(cfg *wgcfg.Config, dnsDomains []string) error
|
||||
// Update the packet filter.
|
||||
SetFilter(filt *filter.Filter)
|
||||
// Set the function to call when wireguard status changes.
|
||||
SetStatusCallback(cb StatusCallback)
|
||||
// Request a wireguard status update right away, sent to the callback.
|
||||
RequestStatus()
|
||||
// Shut down this wireguard instance, remove any routes it added, etc.
|
||||
// To bring it up again later, you'll need a new Engine.
|
||||
Close()
|
||||
// Wait until the Engine is .Close()ed or aborts with an error.
|
||||
// You don't have to call this.
|
||||
Wait()
|
||||
// LinkChange informs the engine that the system network
|
||||
// link has changed. The isExpensive parameter is set on links
|
||||
// where sending packets uses substantial power or dollars
|
||||
// (such as LTE on a phone).
|
||||
LinkChange(isExpensive bool)
|
||||
}
|
153
wgengine/winnet/winnet.go
Normal file
153
wgengine/winnet/winnet.go
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package winnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-ole/go-ole"
|
||||
"github.com/go-ole/go-ole/oleutil"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const CLSID_NetworkListManager = "{DCB00C01-570F-4A9B-8D69-199FDBA5723B}"
|
||||
|
||||
var IID_INetwork = ole.NewGUID("{8A40A45D-055C-4B62-ABD7-6D613E2CEAEC}")
|
||||
var IID_INetworkConnection = ole.NewGUID("{DCB00005-570F-4A9B-8D69-199FDBA5723B}")
|
||||
|
||||
type NetworkListManager struct {
|
||||
d *ole.Dispatch
|
||||
}
|
||||
|
||||
type INetworkConnection struct {
|
||||
ole.IDispatch
|
||||
}
|
||||
|
||||
type ConnectionList []*INetworkConnection
|
||||
|
||||
type INetworkConnectionVtbl struct {
|
||||
ole.IDispatchVtbl
|
||||
GetNetwork uintptr
|
||||
Get_IsConnectedToInternet uintptr
|
||||
Get_IsConnected uintptr
|
||||
GetConnectivity uintptr
|
||||
GetConnectionId uintptr
|
||||
GetAdapterId uintptr
|
||||
GetDomainType uintptr
|
||||
}
|
||||
|
||||
type INetwork struct {
|
||||
ole.IDispatch
|
||||
}
|
||||
|
||||
func NewNetworkListManager(c *ole.Connection) (*NetworkListManager, error) {
|
||||
err := c.Create(CLSID_NetworkListManager)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer c.Release()
|
||||
|
||||
d, err := c.Dispatch()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &NetworkListManager{
|
||||
d: d,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *NetworkListManager) Release() {
|
||||
m.d.Release()
|
||||
}
|
||||
|
||||
func (cl ConnectionList) Release() {
|
||||
for _, v := range cl {
|
||||
v.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func asIID(u ole.UnknownLike, iid *ole.GUID) (*ole.IDispatch, error) {
|
||||
if u == nil {
|
||||
return nil, fmt.Errorf("asIID: nil UnknownLike")
|
||||
}
|
||||
|
||||
d, err := u.QueryInterface(iid)
|
||||
u.Release()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (m *NetworkListManager) GetNetworkConnections() (ConnectionList, error) {
|
||||
ncraw, err := m.d.Call("GetNetworkConnections")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nli := ncraw.ToIDispatch()
|
||||
if nli == nil {
|
||||
return nil, fmt.Errorf("GetNetworkConnections: not IDispatch")
|
||||
}
|
||||
|
||||
cl := ConnectionList{}
|
||||
|
||||
err = oleutil.ForEach(nli, func(v *ole.VARIANT) error {
|
||||
nc, err := asIID(v.ToIUnknown(), IID_INetworkConnection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nco := (*INetworkConnection)(unsafe.Pointer(nc))
|
||||
cl = append(cl, nco)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
cl.Release()
|
||||
return nil, err
|
||||
}
|
||||
return cl, nil
|
||||
}
|
||||
|
||||
func (n *INetwork) GetName() (string, error) {
|
||||
v, err := n.CallMethod("GetName")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return v.ToString(), err
|
||||
}
|
||||
|
||||
func (n *INetwork) GetCategory() (int32, error) {
|
||||
v, err := n.CallMethod("GetCategory")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return v.Value().(int32), err
|
||||
}
|
||||
|
||||
func (n *INetwork) SetCategory(v uint32) error {
|
||||
_, err := n.CallMethod("SetCategory", v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (v *INetworkConnection) VTable() *INetworkConnectionVtbl {
|
||||
return (*INetworkConnectionVtbl)(unsafe.Pointer(v.RawVTable))
|
||||
}
|
||||
|
||||
func (v *INetworkConnection) GetNetwork() (*INetwork, error) {
|
||||
nraw, err := v.CallMethod("GetNetwork")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n := nraw.ToIDispatch()
|
||||
if n == nil {
|
||||
return nil, fmt.Errorf("GetNetwork: nil IDispatch")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return (*INetwork)(unsafe.Pointer(n)), nil
|
||||
}
|
26
wgengine/winnet/winnet_windows.go
Normal file
26
wgengine/winnet/winnet_windows.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package winnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-ole/go-ole"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func (v *INetworkConnection) GetAdapterId() (string, error) {
|
||||
buf := ole.GUID{}
|
||||
hr, _, _ := syscall.Syscall(
|
||||
v.VTable().GetAdapterId,
|
||||
2,
|
||||
uintptr(unsafe.Pointer(v)),
|
||||
uintptr(unsafe.Pointer(&buf)),
|
||||
0)
|
||||
if hr != 0 {
|
||||
return "", fmt.Errorf("GetAdapterId failed: %08x", hr)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
Reference in New Issue
Block a user