mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 13:05:46 +00:00
8296c934ac
OLE calls sometimes unexpectedly fail, but retries can succeed. Change panic() to return errors. This way ConfigureInterface() retries can succeed.
416 lines
10 KiB
Go
416 lines
10 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package wgengine
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"sort"
|
|
"time"
|
|
"unsafe"
|
|
|
|
ole "github.com/go-ole/go-ole"
|
|
winipcfg "github.com/tailscale/winipcfg-go"
|
|
"github.com/tailscale/wireguard-go/device"
|
|
"github.com/tailscale/wireguard-go/tun"
|
|
"github.com/tailscale/wireguard-go/wgcfg"
|
|
"golang.org/x/sys/windows"
|
|
"golang.org/x/sys/windows/registry"
|
|
"tailscale.com/wgengine/winnet"
|
|
)
|
|
|
|
const (
|
|
sockoptIP_UNICAST_IF = 31
|
|
sockoptIPV6_UNICAST_IF = 31
|
|
)
|
|
|
|
func htonl(val uint32) uint32 {
|
|
bytes := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(bytes, val)
|
|
return *(*uint32)(unsafe.Pointer(&bytes[0]))
|
|
}
|
|
|
|
func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLuid uint64, lastLuid *uint64) error {
|
|
routes, err := winipcfg.GetRoutes(family)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
lowestMetric := ^uint32(0)
|
|
index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
|
|
luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
|
|
for _, route := range routes {
|
|
if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLuid == ourLuid {
|
|
continue
|
|
}
|
|
if route.Metric < lowestMetric {
|
|
lowestMetric = route.Metric
|
|
index = route.InterfaceIndex
|
|
luid = route.InterfaceLuid
|
|
}
|
|
}
|
|
if luid == *lastLuid {
|
|
return nil
|
|
}
|
|
*lastLuid = luid
|
|
if false {
|
|
// TODO(apenwarr): doesn't work with magic socket yet.
|
|
if family == winipcfg.AF_INET {
|
|
return device.BindSocketToInterface4(index, false)
|
|
} else if family == winipcfg.AF_INET6 {
|
|
return device.BindSocketToInterface6(index, false)
|
|
}
|
|
} else {
|
|
log.Printf("WARNING: skipping windows socket binding.\n")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
|
|
guid := tun.GUID()
|
|
ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid)
|
|
lastLuid4 := uint64(0)
|
|
lastLuid6 := uint64(0)
|
|
lastMtu := uint32(0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
doIt := func() error {
|
|
err = bindSocketRoute(winipcfg.AF_INET, device, ourLuid, &lastLuid4)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = bindSocketRoute(winipcfg.AF_INET6, device, ourLuid, &lastLuid6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !autoMTU {
|
|
return nil
|
|
}
|
|
mtu := uint32(0)
|
|
if lastLuid4 != 0 {
|
|
iface, err := winipcfg.InterfaceFromLUID(lastLuid4)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if iface.Mtu > 0 {
|
|
mtu = iface.Mtu
|
|
}
|
|
}
|
|
if lastLuid6 != 0 {
|
|
iface, err := winipcfg.InterfaceFromLUID(lastLuid6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if iface.Mtu > 0 && iface.Mtu < mtu {
|
|
mtu = iface.Mtu
|
|
}
|
|
}
|
|
if mtu > 0 && (lastMtu == 0 || lastMtu != mtu) {
|
|
iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
iface.NlMtu = mtu - 80
|
|
if iface.NlMtu < 576 {
|
|
iface.NlMtu = 576
|
|
}
|
|
err = iface.Set()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
tun.ForceMTU(int(iface.NlMtu)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now.
|
|
iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
iface.NlMtu = mtu - 80
|
|
if iface.NlMtu < 1280 {
|
|
iface.NlMtu = 1280
|
|
}
|
|
err = iface.Set()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
lastMtu = mtu
|
|
}
|
|
return nil
|
|
}
|
|
err = doIt()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) {
|
|
//fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix)
|
|
if route.DestinationPrefix.PrefixLength == 0 {
|
|
_ = doIt()
|
|
}
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return cb, nil
|
|
}
|
|
|
|
func setDNSDomains(g windows.GUID, dnsDomains []string) {
|
|
gs := g.String()
|
|
log.Printf("setDNSDomains(%v) guid=%v\n", dnsDomains, gs)
|
|
p := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + gs
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, p, registry.READ|registry.SET_VALUE)
|
|
if err != nil {
|
|
log.Printf("setDNSDomains(%v): open: %v\n", p, err)
|
|
return
|
|
}
|
|
defer key.Close()
|
|
|
|
// Windows only supports a single per-interface DNS domain.
|
|
dom := ""
|
|
if len(dnsDomains) > 0 {
|
|
dom = dnsDomains[0]
|
|
}
|
|
err = key.SetStringValue("Domain", dom)
|
|
if err != nil {
|
|
log.Printf("setDNSDomains(%v): SetStringValue: %v\n", p, err)
|
|
}
|
|
}
|
|
|
|
func setFirewall(ifcGUID *windows.GUID) (bool, error) {
|
|
c := ole.Connection{}
|
|
err := c.Initialize()
|
|
if err != nil {
|
|
return false, fmt.Errorf("c.Initialize: %v", err)
|
|
}
|
|
defer c.Uninitialize()
|
|
|
|
m, err := winnet.NewNetworkListManager(&c)
|
|
if err != nil {
|
|
return false, fmt.Errorf("winnet.NewNetworkListManager: %v", err)
|
|
}
|
|
defer m.Release()
|
|
|
|
cl, err := m.GetNetworkConnections()
|
|
if err != nil {
|
|
return false, fmt.Errorf("m.GetNetworkConnections: %v", err)
|
|
}
|
|
defer cl.Release()
|
|
|
|
for _, nco := range cl {
|
|
aid, err := nco.GetAdapterId()
|
|
if err != nil {
|
|
return false, fmt.Errorf("nco.GetAdapterId: %v", err)
|
|
}
|
|
if aid != ifcGUID.String() {
|
|
log.Printf("skipping adapter id: %v\n", aid)
|
|
continue
|
|
}
|
|
log.Printf("found! adapter id: %v\n", aid)
|
|
|
|
n, err := nco.GetNetwork()
|
|
if err != nil {
|
|
return false, fmt.Errorf("GetNetwork: %v", err)
|
|
}
|
|
defer n.Release()
|
|
|
|
cat, err := n.GetCategory()
|
|
if err != nil {
|
|
return false, fmt.Errorf("GetCategory: %v", err)
|
|
}
|
|
|
|
if cat == 0 {
|
|
err = n.SetCategory(1)
|
|
if err != nil {
|
|
return false, fmt.Errorf("SetCategory: %v", err)
|
|
}
|
|
} else {
|
|
log.Printf("setFirewall: already category %v\n", cat)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func configureInterface(m *wgcfg.Config, tun *tun.NativeTun, dns []wgcfg.IP, dnsDomains []string) error {
|
|
const mtu = 0
|
|
guid := tun.GUID()
|
|
log.Printf("wintun GUID is %v\n", guid)
|
|
iface, err := winipcfg.InterfaceFromGUID(&guid)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
// It takes a weirdly long time for Windows to notice the
|
|
// new interface has come up. Poll periodically until it
|
|
// does.
|
|
for i := 0; i < 20; i++ {
|
|
found, err := setFirewall(&guid)
|
|
if err != nil {
|
|
log.Printf("setFirewall: %v\n", err)
|
|
// fall through anyway, this isn't fatal.
|
|
}
|
|
if found {
|
|
break
|
|
}
|
|
time.Sleep(1 * time.Second)
|
|
}
|
|
}()
|
|
|
|
setDNSDomains(guid, dnsDomains)
|
|
|
|
routes := []winipcfg.RouteData{}
|
|
var firstGateway4 *net.IP
|
|
var firstGateway6 *net.IP
|
|
addresses := make([]*net.IPNet, len(m.Addresses))
|
|
for i, addr := range m.Addresses {
|
|
ipnet := addr.IPNet()
|
|
addresses[i] = ipnet
|
|
gateway := ipnet.IP
|
|
if addr.IP.Is4() && firstGateway4 == nil {
|
|
firstGateway4 = &gateway
|
|
} else if addr.IP.Is6() && firstGateway6 == nil {
|
|
firstGateway6 = &gateway
|
|
}
|
|
}
|
|
|
|
foundDefault4 := false
|
|
foundDefault6 := false
|
|
for _, peer := range m.Peers {
|
|
for _, allowedip := range peer.AllowedIPs {
|
|
if (allowedip.IP.Is4() && firstGateway4 == nil) || (allowedip.IP.Is6() && firstGateway6 == nil) {
|
|
return errors.New("Due to a Windows limitation, one cannot have interface routes without an interface address")
|
|
}
|
|
|
|
ipn := allowedip.IPNet()
|
|
var gateway net.IP
|
|
if allowedip.IP.Is4() {
|
|
gateway = *firstGateway4
|
|
} else if allowedip.IP.Is6() {
|
|
gateway = *firstGateway6
|
|
}
|
|
r := winipcfg.RouteData{
|
|
Destination: net.IPNet{
|
|
IP: ipn.IP.Mask(ipn.Mask),
|
|
Mask: ipn.Mask,
|
|
},
|
|
NextHop: gateway,
|
|
Metric: 0,
|
|
}
|
|
if bytes.Compare(r.Destination.IP, gateway) == 0 {
|
|
// no need to add a route for the interface's
|
|
// own IP. The kernel does that for us.
|
|
// If we try to replace it, we'll fail to
|
|
// add the route unless NextHop is set, but
|
|
// then the interface's IP won't be pingable.
|
|
continue
|
|
}
|
|
if allowedip.IP.Is4() {
|
|
if allowedip.Mask == 0 {
|
|
foundDefault4 = true
|
|
}
|
|
r.NextHop = *firstGateway4
|
|
} else if allowedip.IP.Is6() {
|
|
if allowedip.Mask == 0 {
|
|
foundDefault6 = true
|
|
}
|
|
r.NextHop = *firstGateway6
|
|
}
|
|
routes = append(routes, r)
|
|
}
|
|
}
|
|
|
|
err = iface.SyncAddresses(addresses)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sort.Slice(routes, func(i, j int) bool {
|
|
return (bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 ||
|
|
// Narrower masks first
|
|
bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == 1 ||
|
|
// No nexthop before non-empty nexthop
|
|
bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 ||
|
|
// Lower metrics first
|
|
routes[i].Metric < routes[j].Metric)
|
|
})
|
|
|
|
deduplicatedRoutes := []*winipcfg.RouteData{}
|
|
for i := 0; i < len(routes); i++ {
|
|
// There's only one way to get to a given IP+Mask, so delete
|
|
// all matches after the first.
|
|
if i > 0 &&
|
|
bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
|
|
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
|
|
continue
|
|
}
|
|
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
|
|
}
|
|
log.Printf("routes: %v\n", routes)
|
|
|
|
var errAcc error
|
|
err = iface.SyncRoutes(deduplicatedRoutes)
|
|
if err != nil && errAcc == nil {
|
|
log.Printf("setroutes: %v\n", err)
|
|
errAcc = err
|
|
}
|
|
|
|
var dnsIPs []net.IP
|
|
for _, ip := range dns {
|
|
dnsIPs = append(dnsIPs, ip.IP())
|
|
}
|
|
err = iface.SetDNS(dnsIPs)
|
|
if err != nil && errAcc == nil {
|
|
log.Printf("setdns: %v\n", err)
|
|
errAcc = err
|
|
}
|
|
|
|
ipif, err := iface.GetIpInterface(winipcfg.AF_INET)
|
|
if err != nil {
|
|
log.Printf("getipif: %v\n", err)
|
|
return err
|
|
}
|
|
log.Printf("foundDefault4: %v\n", foundDefault4)
|
|
if foundDefault4 {
|
|
ipif.UseAutomaticMetric = false
|
|
ipif.Metric = 0
|
|
}
|
|
if mtu > 0 {
|
|
ipif.NlMtu = uint32(mtu)
|
|
tun.ForceMTU(int(ipif.NlMtu))
|
|
}
|
|
err = ipif.Set()
|
|
if err != nil && errAcc == nil {
|
|
errAcc = err
|
|
}
|
|
|
|
ipif, err = iface.GetIpInterface(winipcfg.AF_INET6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err != nil && errAcc == nil {
|
|
errAcc = err
|
|
}
|
|
if foundDefault6 {
|
|
ipif.UseAutomaticMetric = false
|
|
ipif.Metric = 0
|
|
}
|
|
if mtu > 0 {
|
|
ipif.NlMtu = uint32(mtu)
|
|
}
|
|
ipif.DadTransmits = 0
|
|
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
|
|
err = ipif.Set()
|
|
if err != nil && errAcc == nil {
|
|
errAcc = err
|
|
}
|
|
|
|
return errAcc
|
|
}
|