mirror of
synced 2025-03-23 09:40:59 +00:00
416 lines
9.8 KiB
416 lines
9.8 KiB
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
 */
package wgengine
import (
ole "github.com/go-ole/go-ole"
winipcfg "github.com/tailscale/winipcfg-go"
// Socket option numbers for choosing the interface a socket sends
// unicast traffic on.
// NOTE(review): presumably the Windows IP_UNICAST_IF and IPV6_UNICAST_IF
// option values; neither constant is referenced in the visible part of
// this file — confirm against the Windows SDK headers.
const (
	sockoptIP_UNICAST_IF   = 31
	sockoptIPV6_UNICAST_IF = 31
)
// htonl converts val from host byte order to network (big-endian) byte
// order. The returned uint32, viewed as raw memory on this machine, holds
// the bytes of val most-significant first. On a big-endian host it is the
// identity; on a little-endian host it reverses the four bytes.
func htonl(val uint32) uint32 {
	var out uint32
	// Write the big-endian encoding straight into out's own storage.
	// Viewing a *uint32 as *[4]byte only relaxes alignment, so the
	// conversion is safe, and no scratch slice has to be allocated.
	binary.BigEndian.PutUint32((*[4]byte)(unsafe.Pointer(&out))[:], val)
	return out
}
func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLuid uint64, lastLuid *uint64) error {
routes, err := winipcfg.GetRoutes(family)
if err != nil {
return err
lowestMetric := ^uint32(0)
index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
for _, route := range routes {
if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLuid == ourLuid {
if route.Metric < lowestMetric {
lowestMetric = route.Metric
index = route.InterfaceIndex
luid = route.InterfaceLuid
if luid == *lastLuid {
return nil
*lastLuid = luid
if false {
// TODO(apenwarr): doesn't work with magic socket yet.
if family == winipcfg.AF_INET {
return device.BindSocketToInterface4(index, false)
} else if family == winipcfg.AF_INET6 {
return device.BindSocketToInterface6(index, false)
} else {
log.Printf("WARNING: skipping windows socket binding.\n")
return nil
func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
guid := tun.GUID()
ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid)
lastLuid4 := uint64(0)
lastLuid6 := uint64(0)
lastMtu := uint32(0)
if err != nil {
return nil, err
doIt := func() error {
err = bindSocketRoute(winipcfg.AF_INET, device, ourLuid, &lastLuid4)
if err != nil {
return err
err = bindSocketRoute(winipcfg.AF_INET6, device, ourLuid, &lastLuid6)
if err != nil {
return err
if !autoMTU {
return nil
mtu := uint32(0)
if lastLuid4 != 0 {
iface, err := winipcfg.InterfaceFromLUID(lastLuid4)
if err != nil {
return err
if iface.Mtu > 0 {
mtu = iface.Mtu
if lastLuid6 != 0 {
iface, err := winipcfg.InterfaceFromLUID(lastLuid6)
if err != nil {
return err
if iface.Mtu > 0 && iface.Mtu < mtu {
mtu = iface.Mtu
if mtu > 0 && (lastMtu == 0 || lastMtu != mtu) {
iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET)
if err != nil {
return err
iface.NlMtu = mtu - 80
if iface.NlMtu < 576 {
iface.NlMtu = 576
err = iface.Set()
if err != nil {
return err
tun.ForceMTU(int(iface.NlMtu)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now.
iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6)
if err != nil {
return err
iface.NlMtu = mtu - 80
if iface.NlMtu < 1280 {
iface.NlMtu = 1280
err = iface.Set()
if err != nil {
return err
lastMtu = mtu
return nil
err = doIt()
if err != nil {
return nil, err
cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) {
//fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix)
if route.DestinationPrefix.PrefixLength == 0 {
_ = doIt()
if err != nil {
return nil, err
return cb, nil
func setDNSDomains(g windows.GUID, dnsDomains []string) {
gs := g.String()
log.Printf("setDNSDomains(%v) guid=%v\n", dnsDomains, gs)
p := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + gs
key, err := registry.OpenKey(registry.LOCAL_MACHINE, p, registry.READ|registry.SET_VALUE)
if err != nil {
log.Printf("setDNSDomains(%v): open: %v\n", p, err)
defer key.Close()
// Windows only supports a single per-interface DNS domain.
dom := ""
if len(dnsDomains) > 0 {
dom = dnsDomains[0]
err = key.SetStringValue("Domain", dom)
if err != nil {
log.Printf("setDNSDomains(%v): SetStringValue: %v\n", p, err)
func setFirewall(ifcGUID *windows.GUID) (bool, error) {
c := ole.Connection{}
err := c.Initialize()
if err != nil {
defer c.Uninitialize()
m, err := winnet.NewNetworkListManager(&c)
if err != nil {
defer m.Release()
cl, err := m.GetNetworkConnections()
if err != nil {
defer cl.Release()
for _, nco := range cl {
aid, err := nco.GetAdapterId()
if err != nil {
if aid != ifcGUID.String() {
log.Printf("skipping adapter id: %v\n", aid)
log.Printf("found! adapter id: %v\n", aid)
n, err := nco.GetNetwork()
if err != nil {
return false, fmt.Errorf("GetNetwork: %v", err)
defer n.Release()
cat, err := n.GetCategory()
if err != nil {
return false, fmt.Errorf("GetCategory: %v", err)
if cat == 0 {
err = n.SetCategory(1)
if err != nil {
return false, fmt.Errorf("SetCategory: %v", err)
} else {
log.Printf("setFirewall: already category %v\n", cat)
return true, nil
return false, nil
func configureInterface(m *wgcfg.Config, tun *tun.NativeTun, dns []wgcfg.IP, dnsDomains []string) error {
const mtu = 0
guid := tun.GUID()
log.Printf("wintun GUID is %v\n", guid)
iface, err := winipcfg.InterfaceFromGUID(&guid)
if err != nil {
return err
go func() {
// It takes a weirdly long time for Windows to notice the
// new interface has come up. Poll periodically until it
// does.
for i := 0; i < 20; i++ {
found, err := setFirewall(&guid)
if err != nil {
log.Printf("setFirewall: %v\n", err)
// fall through anyway, this isn't fatal.
if found {
time.Sleep(1 * time.Second)
setDNSDomains(guid, dnsDomains)
routes := []winipcfg.RouteData{}
var firstGateway4 *net.IP
var firstGateway6 *net.IP
addresses := make([]*net.IPNet, len(m.Addresses))
for i, addr := range m.Addresses {
ipnet := addr.IPNet()
addresses[i] = ipnet
gateway := ipnet.IP
if addr.IP.Is4() && firstGateway4 == nil {
firstGateway4 = &gateway
} else if addr.IP.Is6() && firstGateway6 == nil {
firstGateway6 = &gateway
foundDefault4 := false
foundDefault6 := false
for _, peer := range m.Peers {
for _, allowedip := range peer.AllowedIPs {
if (allowedip.IP.Is4() && firstGateway4 == nil) || (allowedip.IP.Is6() && firstGateway6 == nil) {
return errors.New("Due to a Windows limitation, one cannot have interface routes without an interface address")
ipn := allowedip.IPNet()
var gateway net.IP
if allowedip.IP.Is4() {
gateway = *firstGateway4
} else if allowedip.IP.Is6() {
gateway = *firstGateway6
r := winipcfg.RouteData{
Destination: net.IPNet{
IP: ipn.IP.Mask(ipn.Mask),
Mask: ipn.Mask,
NextHop: gateway,
Metric: 0,
if bytes.Compare(r.Destination.IP, gateway) == 0 {
// no need to add a route for the interface's
// own IP. The kernel does that for us.
// If we try to replace it, we'll fail to
// add the route unless NextHop is set, but
// then the interface's IP won't be pingable.
if allowedip.IP.Is4() {
if allowedip.Mask == 0 {
foundDefault4 = true
r.NextHop = *firstGateway4
} else if allowedip.IP.Is6() {
if allowedip.Mask == 0 {
foundDefault6 = true
r.NextHop = *firstGateway6
routes = append(routes, r)
err = iface.SetAddresses(addresses)
if err != nil {
return err
sort.Slice(routes, func(i, j int) bool {
return (bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 ||
// Narrower masks first
bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == 1 ||
// No nexthop before non-empty nexthop
bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 ||
// Lower metrics first
routes[i].Metric < routes[j].Metric)
deduplicatedRoutes := []*winipcfg.RouteData{}
for i := 0; i < len(routes); i++ {
// There's only one way to get to a given IP+Mask, so delete
// all matches after the first.
if i > 0 &&
bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
log.Printf("routes: %v\n", routes)
var errAcc error
err = iface.SetRoutes(deduplicatedRoutes)
if err != nil && errAcc == nil {
log.Printf("setroutes: %v\n", err)
errAcc = err
var dnsIPs []net.IP
for _, ip := range dns {
dnsIPs = append(dnsIPs, ip.IP())
err = iface.SetDNS(dnsIPs)
if err != nil && errAcc == nil {
log.Printf("setdns: %v\n", err)
errAcc = err
ipif, err := iface.GetIpInterface(winipcfg.AF_INET)
if err != nil {
log.Printf("getipif: %v\n", err)
return err
log.Printf("foundDefault4: %v\n", foundDefault4)
if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
if mtu > 0 {
ipif.NlMtu = uint32(mtu)
err = ipif.Set()
if err != nil && errAcc == nil {
errAcc = err
ipif, err = iface.GetIpInterface(winipcfg.AF_INET6)
if err != nil {
return err
if err != nil && errAcc == nil {
errAcc = err
if foundDefault6 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
if mtu > 0 {
ipif.NlMtu = uint32(mtu)
ipif.DadTransmits = 0
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
err = ipif.Set()
if err != nil && errAcc == nil {
errAcc = err
return errAcc