mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-25 10:09:17 +00:00 
			
		
		
		
	 1ac570def7
			
		
	
	1ac570def7
	
	
	
		
			
			The router implementations are logically separate, with their own API. Signed-off-by: David Anderson <danderson@tailscale.com>
		
			
				
	
	
		
			416 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			416 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /* SPDX-License-Identifier: MIT
 | |
|  *
 | |
|  * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
 | |
|  */
 | |
| 
 | |
| package router
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"encoding/binary"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"log"
 | |
| 	"net"
 | |
| 	"sort"
 | |
| 	"time"
 | |
| 	"unsafe"
 | |
| 
 | |
| 	ole "github.com/go-ole/go-ole"
 | |
| 	winipcfg "github.com/tailscale/winipcfg-go"
 | |
| 	"github.com/tailscale/wireguard-go/device"
 | |
| 	"github.com/tailscale/wireguard-go/tun"
 | |
| 	"github.com/tailscale/wireguard-go/wgcfg"
 | |
| 	"golang.org/x/sys/windows"
 | |
| 	"golang.org/x/sys/windows/registry"
 | |
| 	"tailscale.com/wgengine/winnet"
 | |
| )
 | |
| 
 | |
| const (
 | |
| 	sockoptIP_UNICAST_IF   = 31 // numeric value for the IPv4 IP_UNICAST_IF socket option (presumably taken from the Windows SDK headers; TODO confirm)
 | |
| 	sockoptIPV6_UNICAST_IF = 31 // numeric value for the IPv6 IPV6_UNICAST_IF socket option (presumably taken from the Windows SDK headers; TODO confirm)
 | |
| )
 | |
| 
 | |
// htonl converts val from host byte order to network (big-endian) byte
// order. The returned uint32, when viewed as raw memory, holds the four
// bytes of val's big-endian encoding.
func htonl(val uint32) uint32 {
	// Named buf rather than bytes so the imported bytes package is not
	// shadowed inside this function.
	buf := make([]byte, 4)
	binary.BigEndian.PutUint32(buf, val)
	// Reinterpret the four encoded bytes as a native uint32.
	encoded := (*uint32)(unsafe.Pointer(&buf[0]))
	return *encoded
}
 | |
| 
 | |
| func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLuid uint64, lastLuid *uint64) error {
 | |
| 	routes, err := winipcfg.GetRoutes(family)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	lowestMetric := ^uint32(0)
 | |
| 	index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
 | |
| 	luid := uint64(0)  // Hopefully luid zero is unspecified, but hard to find docs saying so.
 | |
| 	for _, route := range routes {
 | |
| 		if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLuid == ourLuid {
 | |
| 			continue
 | |
| 		}
 | |
| 		if route.Metric < lowestMetric {
 | |
| 			lowestMetric = route.Metric
 | |
| 			index = route.InterfaceIndex
 | |
| 			luid = route.InterfaceLuid
 | |
| 		}
 | |
| 	}
 | |
| 	if luid == *lastLuid {
 | |
| 		return nil
 | |
| 	}
 | |
| 	*lastLuid = luid
 | |
| 	if false {
 | |
| 		// TODO(apenwarr): doesn't work with magic socket yet.
 | |
| 		if family == winipcfg.AF_INET {
 | |
| 			return device.BindSocketToInterface4(index, false)
 | |
| 		} else if family == winipcfg.AF_INET6 {
 | |
| 			return device.BindSocketToInterface6(index, false)
 | |
| 		}
 | |
| 	} else {
 | |
| 		log.Printf("WARNING: skipping windows socket binding.\n")
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
 | |
| 	guid := tun.GUID()
 | |
| 	ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid)
 | |
| 	lastLuid4 := uint64(0)
 | |
| 	lastLuid6 := uint64(0)
 | |
| 	lastMtu := uint32(0)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	doIt := func() error {
 | |
| 		err = bindSocketRoute(winipcfg.AF_INET, device, ourLuid, &lastLuid4)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		err = bindSocketRoute(winipcfg.AF_INET6, device, ourLuid, &lastLuid6)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		if !autoMTU {
 | |
| 			return nil
 | |
| 		}
 | |
| 		mtu := uint32(0)
 | |
| 		if lastLuid4 != 0 {
 | |
| 			iface, err := winipcfg.InterfaceFromLUID(lastLuid4)
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			if iface.Mtu > 0 {
 | |
| 				mtu = iface.Mtu
 | |
| 			}
 | |
| 		}
 | |
| 		if lastLuid6 != 0 {
 | |
| 			iface, err := winipcfg.InterfaceFromLUID(lastLuid6)
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			if iface.Mtu > 0 && iface.Mtu < mtu {
 | |
| 				mtu = iface.Mtu
 | |
| 			}
 | |
| 		}
 | |
| 		if mtu > 0 && (lastMtu == 0 || lastMtu != mtu) {
 | |
| 			iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET)
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			iface.NlMtu = mtu - 80
 | |
| 			if iface.NlMtu < 576 {
 | |
| 				iface.NlMtu = 576
 | |
| 			}
 | |
| 			err = iface.Set()
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			tun.ForceMTU(int(iface.NlMtu)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now.
 | |
| 			iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6)
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			iface.NlMtu = mtu - 80
 | |
| 			if iface.NlMtu < 1280 {
 | |
| 				iface.NlMtu = 1280
 | |
| 			}
 | |
| 			err = iface.Set()
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			lastMtu = mtu
 | |
| 		}
 | |
| 		return nil
 | |
| 	}
 | |
| 	err = doIt()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) {
 | |
| 		//fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix)
 | |
| 		if route.DestinationPrefix.PrefixLength == 0 {
 | |
| 			_ = doIt()
 | |
| 		}
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return cb, nil
 | |
| }
 | |
| 
 | |
| func setDNSDomains(g windows.GUID, dnsDomains []string) {
 | |
| 	gs := g.String()
 | |
| 	log.Printf("setDNSDomains(%v) guid=%v\n", dnsDomains, gs)
 | |
| 	p := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + gs
 | |
| 	key, err := registry.OpenKey(registry.LOCAL_MACHINE, p, registry.READ|registry.SET_VALUE)
 | |
| 	if err != nil {
 | |
| 		log.Printf("setDNSDomains(%v): open: %v\n", p, err)
 | |
| 		return
 | |
| 	}
 | |
| 	defer key.Close()
 | |
| 
 | |
| 	// Windows only supports a single per-interface DNS domain.
 | |
| 	dom := ""
 | |
| 	if len(dnsDomains) > 0 {
 | |
| 		dom = dnsDomains[0]
 | |
| 	}
 | |
| 	err = key.SetStringValue("Domain", dom)
 | |
| 	if err != nil {
 | |
| 		log.Printf("setDNSDomains(%v): SetStringValue: %v\n", p, err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func setFirewall(ifcGUID *windows.GUID) (bool, error) {
 | |
| 	c := ole.Connection{}
 | |
| 	err := c.Initialize()
 | |
| 	if err != nil {
 | |
| 		return false, fmt.Errorf("c.Initialize: %v", err)
 | |
| 	}
 | |
| 	defer c.Uninitialize()
 | |
| 
 | |
| 	m, err := winnet.NewNetworkListManager(&c)
 | |
| 	if err != nil {
 | |
| 		return false, fmt.Errorf("winnet.NewNetworkListManager: %v", err)
 | |
| 	}
 | |
| 	defer m.Release()
 | |
| 
 | |
| 	cl, err := m.GetNetworkConnections()
 | |
| 	if err != nil {
 | |
| 		return false, fmt.Errorf("m.GetNetworkConnections: %v", err)
 | |
| 	}
 | |
| 	defer cl.Release()
 | |
| 
 | |
| 	for _, nco := range cl {
 | |
| 		aid, err := nco.GetAdapterId()
 | |
| 		if err != nil {
 | |
| 			return false, fmt.Errorf("nco.GetAdapterId: %v", err)
 | |
| 		}
 | |
| 		if aid != ifcGUID.String() {
 | |
| 			log.Printf("skipping adapter id: %v\n", aid)
 | |
| 			continue
 | |
| 		}
 | |
| 		log.Printf("found! adapter id: %v\n", aid)
 | |
| 
 | |
| 		n, err := nco.GetNetwork()
 | |
| 		if err != nil {
 | |
| 			return false, fmt.Errorf("GetNetwork: %v", err)
 | |
| 		}
 | |
| 		defer n.Release()
 | |
| 
 | |
| 		cat, err := n.GetCategory()
 | |
| 		if err != nil {
 | |
| 			return false, fmt.Errorf("GetCategory: %v", err)
 | |
| 		}
 | |
| 
 | |
| 		if cat == 0 {
 | |
| 			err = n.SetCategory(1)
 | |
| 			if err != nil {
 | |
| 				return false, fmt.Errorf("SetCategory: %v", err)
 | |
| 			}
 | |
| 		} else {
 | |
| 			log.Printf("setFirewall: already category %v\n", cat)
 | |
| 		}
 | |
| 
 | |
| 		return true, nil
 | |
| 	}
 | |
| 
 | |
| 	return false, nil
 | |
| }
 | |
| 
 | |
| func configureInterface(m *wgcfg.Config, tun *tun.NativeTun, dns []wgcfg.IP, dnsDomains []string) error {
 | |
| 	const mtu = 0
 | |
| 	guid := tun.GUID()
 | |
| 	log.Printf("wintun GUID is %v\n", guid)
 | |
| 	iface, err := winipcfg.InterfaceFromGUID(&guid)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	go func() {
 | |
| 		// It takes a weirdly long time for Windows to notice the
 | |
| 		// new interface has come up. Poll periodically until it
 | |
| 		// does.
 | |
| 		for i := 0; i < 20; i++ {
 | |
| 			found, err := setFirewall(&guid)
 | |
| 			if err != nil {
 | |
| 				log.Printf("setFirewall: %v\n", err)
 | |
| 				// fall through anyway, this isn't fatal.
 | |
| 			}
 | |
| 			if found {
 | |
| 				break
 | |
| 			}
 | |
| 			time.Sleep(1 * time.Second)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	setDNSDomains(guid, dnsDomains)
 | |
| 
 | |
| 	routes := []winipcfg.RouteData{}
 | |
| 	var firstGateway4 *net.IP
 | |
| 	var firstGateway6 *net.IP
 | |
| 	addresses := make([]*net.IPNet, len(m.Addresses))
 | |
| 	for i, addr := range m.Addresses {
 | |
| 		ipnet := addr.IPNet()
 | |
| 		addresses[i] = ipnet
 | |
| 		gateway := ipnet.IP
 | |
| 		if addr.IP.Is4() && firstGateway4 == nil {
 | |
| 			firstGateway4 = &gateway
 | |
| 		} else if addr.IP.Is6() && firstGateway6 == nil {
 | |
| 			firstGateway6 = &gateway
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	foundDefault4 := false
 | |
| 	foundDefault6 := false
 | |
| 	for _, peer := range m.Peers {
 | |
| 		for _, allowedip := range peer.AllowedIPs {
 | |
| 			if (allowedip.IP.Is4() && firstGateway4 == nil) || (allowedip.IP.Is6() && firstGateway6 == nil) {
 | |
| 				return errors.New("Due to a Windows limitation, one cannot have interface routes without an interface address")
 | |
| 			}
 | |
| 
 | |
| 			ipn := allowedip.IPNet()
 | |
| 			var gateway net.IP
 | |
| 			if allowedip.IP.Is4() {
 | |
| 				gateway = *firstGateway4
 | |
| 			} else if allowedip.IP.Is6() {
 | |
| 				gateway = *firstGateway6
 | |
| 			}
 | |
| 			r := winipcfg.RouteData{
 | |
| 				Destination: net.IPNet{
 | |
| 					IP:   ipn.IP.Mask(ipn.Mask),
 | |
| 					Mask: ipn.Mask,
 | |
| 				},
 | |
| 				NextHop: gateway,
 | |
| 				Metric:  0,
 | |
| 			}
 | |
| 			if bytes.Compare(r.Destination.IP, gateway) == 0 {
 | |
| 				// no need to add a route for the interface's
 | |
| 				// own IP. The kernel does that for us.
 | |
| 				// If we try to replace it, we'll fail to
 | |
| 				// add the route unless NextHop is set, but
 | |
| 				// then the interface's IP won't be pingable.
 | |
| 				continue
 | |
| 			}
 | |
| 			if allowedip.IP.Is4() {
 | |
| 				if allowedip.Mask == 0 {
 | |
| 					foundDefault4 = true
 | |
| 				}
 | |
| 				r.NextHop = *firstGateway4
 | |
| 			} else if allowedip.IP.Is6() {
 | |
| 				if allowedip.Mask == 0 {
 | |
| 					foundDefault6 = true
 | |
| 				}
 | |
| 				r.NextHop = *firstGateway6
 | |
| 			}
 | |
| 			routes = append(routes, r)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	err = iface.SyncAddresses(addresses)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	sort.Slice(routes, func(i, j int) bool {
 | |
| 		return (bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 ||
 | |
| 			// Narrower masks first
 | |
| 			bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == 1 ||
 | |
| 			// No nexthop before non-empty nexthop
 | |
| 			bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 ||
 | |
| 			// Lower metrics first
 | |
| 			routes[i].Metric < routes[j].Metric)
 | |
| 	})
 | |
| 
 | |
| 	deduplicatedRoutes := []*winipcfg.RouteData{}
 | |
| 	for i := 0; i < len(routes); i++ {
 | |
| 		// There's only one way to get to a given IP+Mask, so delete
 | |
| 		// all matches after the first.
 | |
| 		if i > 0 &&
 | |
| 			bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
 | |
| 			bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
 | |
| 			continue
 | |
| 		}
 | |
| 		deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
 | |
| 	}
 | |
| 	log.Printf("routes: %v\n", routes)
 | |
| 
 | |
| 	var errAcc error
 | |
| 	err = iface.SyncRoutes(deduplicatedRoutes)
 | |
| 	if err != nil && errAcc == nil {
 | |
| 		log.Printf("setroutes: %v\n", err)
 | |
| 		errAcc = err
 | |
| 	}
 | |
| 
 | |
| 	var dnsIPs []net.IP
 | |
| 	for _, ip := range dns {
 | |
| 		dnsIPs = append(dnsIPs, ip.IP())
 | |
| 	}
 | |
| 	err = iface.SetDNS(dnsIPs)
 | |
| 	if err != nil && errAcc == nil {
 | |
| 		log.Printf("setdns: %v\n", err)
 | |
| 		errAcc = err
 | |
| 	}
 | |
| 
 | |
| 	ipif, err := iface.GetIpInterface(winipcfg.AF_INET)
 | |
| 	if err != nil {
 | |
| 		log.Printf("getipif: %v\n", err)
 | |
| 		return err
 | |
| 	}
 | |
| 	log.Printf("foundDefault4: %v\n", foundDefault4)
 | |
| 	if foundDefault4 {
 | |
| 		ipif.UseAutomaticMetric = false
 | |
| 		ipif.Metric = 0
 | |
| 	}
 | |
| 	if mtu > 0 {
 | |
| 		ipif.NlMtu = uint32(mtu)
 | |
| 		tun.ForceMTU(int(ipif.NlMtu))
 | |
| 	}
 | |
| 	err = ipif.Set()
 | |
| 	if err != nil && errAcc == nil {
 | |
| 		errAcc = err
 | |
| 	}
 | |
| 
 | |
| 	ipif, err = iface.GetIpInterface(winipcfg.AF_INET6)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if err != nil && errAcc == nil {
 | |
| 		errAcc = err
 | |
| 	}
 | |
| 	if foundDefault6 {
 | |
| 		ipif.UseAutomaticMetric = false
 | |
| 		ipif.Metric = 0
 | |
| 	}
 | |
| 	if mtu > 0 {
 | |
| 		ipif.NlMtu = uint32(mtu)
 | |
| 	}
 | |
| 	ipif.DadTransmits = 0
 | |
| 	ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
 | |
| 	err = ipif.Set()
 | |
| 	if err != nil && errAcc == nil {
 | |
| 		errAcc = err
 | |
| 	}
 | |
| 
 | |
| 	return errAcc
 | |
| }
 |