wgengine/router: unfork winipcfg-go package, use upstream

Use golang.zx2c4.com/wireguard/windows/tunnel/winipcfg
instead of github.com/tailscale/winipcfg-go package.

Updates #760

Signed-off-by: Alex Brainman <alex.brainman@gmail.com>
This commit is contained in:
Alex Brainman
2020-09-26 12:11:05 +10:00
committed by Brad Fitzpatrick
parent 515866d7c6
commit f2ce64f0c6
9 changed files with 160 additions and 119 deletions

View File

@@ -11,15 +11,14 @@ import (
"fmt"
"log"
"net"
"os"
"sort"
"time"
"github.com/go-multierror/multierror"
ole "github.com/go-ole/go-ole"
winipcfg "github.com/tailscale/winipcfg-go"
"github.com/tailscale/wireguard-go/tun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/net/interfaces"
"tailscale.com/wgengine/winnet"
)
@@ -40,51 +39,51 @@ import (
// help with MTU issues compared to a static 1280B implementation.
func monitorDefaultRoutes(tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
guid := tun.GUID()
ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid)
ourLuid, err := winipcfg.LUIDFromGUID(&guid)
lastMtu := uint32(0)
if err != nil {
return nil, err
return nil, fmt.Errorf("error mapping GUID %v to LUID: %w", guid, err)
}
doIt := func() error {
mtu, err := getDefaultRouteMTU()
if err != nil {
return err
return fmt.Errorf("error getting default route MTU: %w", err)
}
if mtu > 0 && (lastMtu == 0 || lastMtu != mtu) {
iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET)
iface, err := ourLuid.IPInterface(windows.AF_INET)
if err != nil {
return err
return fmt.Errorf("error getting v4 interface: %w", err)
}
iface.NlMtu = mtu - 80
iface.NLMTU = mtu - 80
// If the TUN device was created with a smaller MTU,
// though, such as 1280, we don't want to go bigger than
// configured. (See the comment on minimalMTU in the
// wgengine package.)
if min, err := tun.MTU(); err == nil && min < int(iface.NlMtu) {
iface.NlMtu = uint32(min)
if min, err := tun.MTU(); err == nil && min < int(iface.NLMTU) {
iface.NLMTU = uint32(min)
}
if iface.NlMtu < 576 {
iface.NlMtu = 576
if iface.NLMTU < 576 {
iface.NLMTU = 576
}
err = iface.Set()
if err != nil {
return err
return fmt.Errorf("error setting v4 MTU: %w", err)
}
tun.ForceMTU(int(iface.NlMtu))
iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6)
tun.ForceMTU(int(iface.NLMTU))
iface, err = ourLuid.IPInterface(windows.AF_INET6)
if err != nil {
if !isMissingIPv6Err(err) {
return err
if !errors.Is(err, windows.ERROR_NOT_FOUND) {
return fmt.Errorf("error getting v6 interface: %w", err)
}
} else {
iface.NlMtu = mtu - 80
if iface.NlMtu < 1280 {
iface.NlMtu = 1280
iface.NLMTU = mtu - 80
if iface.NLMTU < 1280 {
iface.NLMTU = 1280
}
err = iface.Set()
if err != nil {
return err
return fmt.Errorf("error setting v6 MTU: %w", err)
}
}
lastMtu = mtu
@@ -95,7 +94,7 @@ func monitorDefaultRoutes(tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, er
if err != nil {
return nil, err
}
cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) {
cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
//fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix)
if route.DestinationPrefix.PrefixLength == 0 {
_ = doIt()
@@ -113,7 +112,7 @@ func getDefaultRouteMTU() (uint32, error) {
return 0, err
}
routes, err := winipcfg.GetRoutes(winipcfg.AF_INET)
routes, err := winipcfg.GetIPForwardTable2(windows.AF_INET)
if err != nil {
return 0, err
}
@@ -123,7 +122,7 @@ func getDefaultRouteMTU() (uint32, error) {
if route.DestinationPrefix.PrefixLength != 0 {
continue
}
routeMTU := mtus[route.InterfaceLuid]
routeMTU := mtus[route.InterfaceLUID]
if routeMTU == 0 {
continue
}
@@ -133,7 +132,7 @@ func getDefaultRouteMTU() (uint32, error) {
}
}
routes, err = winipcfg.GetRoutes(winipcfg.AF_INET6)
routes, err = winipcfg.GetIPForwardTable2(windows.AF_INET6)
if err != nil {
return 0, err
}
@@ -142,7 +141,7 @@ func getDefaultRouteMTU() (uint32, error) {
if route.DestinationPrefix.PrefixLength != 0 {
continue
}
routeMTU := mtus[route.InterfaceLuid]
routeMTU := mtus[route.InterfaceLUID]
if routeMTU == 0 {
continue
}
@@ -215,16 +214,34 @@ func setPrivateNetwork(ifcGUID *windows.GUID) (bool, error) {
return false, nil
}
// interfaceFromGUID returns IPAdapterAddresses with specified GUID.
func interfaceFromGUID(guid *windows.GUID, flags winipcfg.GAAFlags) (*winipcfg.IPAdapterAddresses, error) {
luid, err := winipcfg.LUIDFromGUID(guid)
if err != nil {
return nil, err
}
addresses, err := winipcfg.GetAdaptersAddresses(windows.AF_UNSPEC, flags)
if err != nil {
return nil, err
}
for _, addr := range addresses {
if addr.LUID == luid {
return addr, nil
}
}
return nil, fmt.Errorf("interfaceFromGUID: interface with LUID %v (from GUID %v) not found", luid, guid)
}
func configureInterface(cfg *Config, tun *tun.NativeTun) error {
const mtu = 0
guid := tun.GUID()
iface, err := winipcfg.InterfaceFromGUIDEx(&guid, &winipcfg.GetAdapterAddressesFlags{
iface, err := interfaceFromGUID(&guid,
// Issue 474: on early boot, when the network is still
// coming up, if the Tailscale service comes up first,
// the Tailscale adapter it finds might not have the
// IPv4 service available yet? Try this flag:
GAA_FLAG_INCLUDE_ALL_INTERFACES: true,
})
winipcfg.GAAFlagIncludeAllInterfaces,
)
if err != nil {
return err
}
@@ -327,6 +344,18 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) error {
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
// Re-read interface after syncAddresses.
iface, err = interfaceFromGUID(&guid,
// Issue 474: on early boot, when the network is still
// coming up, if the Tailscale service comes up first,
// the Tailscale adapter it finds might not have the
// IPv4 service available yet? Try this flag:
winipcfg.GAAFlagIncludeAllInterfaces,
)
if err != nil {
return err
}
var errAcc error
err = syncRoutes(iface, deduplicatedRoutes)
if err != nil && errAcc == nil {
@@ -334,7 +363,7 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) error {
errAcc = err
}
ipif, err := iface.GetIpInterface(winipcfg.AF_INET)
ipif, err := iface.LUID.IPInterface(windows.AF_INET)
if err != nil {
log.Printf("getipif: %v", err)
return err
@@ -344,17 +373,17 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) error {
ipif.Metric = 0
}
if mtu > 0 {
ipif.NlMtu = uint32(mtu)
tun.ForceMTU(int(ipif.NlMtu))
ipif.NLMTU = uint32(mtu)
tun.ForceMTU(int(ipif.NLMTU))
}
err = ipif.Set()
if err != nil && errAcc == nil {
errAcc = err
}
ipif, err = iface.GetIpInterface(winipcfg.AF_INET6)
ipif, err = iface.LUID.IPInterface(windows.AF_INET6)
if err != nil {
if !isMissingIPv6Err(err) {
if !errors.Is(err, windows.ERROR_NOT_FOUND) {
return err
}
} else {
@@ -363,7 +392,7 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) error {
ipif.Metric = 0
}
if mtu > 0 {
ipif.NlMtu = uint32(mtu)
ipif.NLMTU = uint32(mtu)
}
ipif.DadTransmits = 0
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
@@ -376,22 +405,6 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) error {
return errAcc
}
// isMissingIPv6Err reports whether err is due to IPv6 not being enabled on the machine.
//
// It only currently supports the errors returned by winipcfg.Interface.GetIpInterface.
func isMissingIPv6Err(err error) bool {
if se, ok := err.(*os.SyscallError); ok {
switch se.Syscall {
case "iphlpapi.GetIpInterfaceEntry":
// ERROR_NOT_FOUND from means the address family (IPv6) is not found.
// (ERROR_FILE_NOT_FOUND means that the interface doesn't exist.)
// https://docs.microsoft.com/en-us/windows/win32/api/netioapi/nf-netioapi-getipinterfaceentry
return se.Err == windows.ERROR_NOT_FOUND
}
}
return false
}
// routeLess reports whether ri should sort before rj.
// The actual sort order doesn't appear to matter. The caller just
// wants them sorted to be able to de-dup.
@@ -495,31 +508,53 @@ func excludeIPv6LinkLocal(in []*net.IPNet) (out []*net.IPNet) {
return out
}
// ipAdapterUnicastAddressToIPNet converts windows.IpAdapterUnicastAddress to net.IPNet.
func ipAdapterUnicastAddressToIPNet(u *windows.IpAdapterUnicastAddress) *net.IPNet {
ip := u.Address.IP()
w := 32
if ip.To4() == nil {
w = 128
}
return &net.IPNet{
IP: ip,
Mask: net.CIDRMask(int(u.OnLinkPrefixLength), w),
}
}
// unicastIPNets returns all unicast net.IPNet for ifc interface.
func unicastIPNets(ifc *winipcfg.IPAdapterAddresses) []*net.IPNet {
nets := make([]*net.IPNet, 0)
for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next {
nets = append(nets, ipAdapterUnicastAddressToIPNet(addr))
}
return nets
}
// syncAddresses incrementally sets the interface's unicast IP addresses,
// doing the minimum number of AddAddresses & DeleteAddress calls.
// This avoids the full FlushAddresses.
//
// Any IPv6 link-local addresses are not deleted.
func syncAddresses(ifc *winipcfg.Interface, want []*net.IPNet) error {
func syncAddresses(ifc *winipcfg.IPAdapterAddresses, want []*net.IPNet) error {
var erracc error
got := ifc.UnicastIPNets
got := unicastIPNets(ifc)
add, del := deltaNets(got, want)
del = excludeIPv6LinkLocal(del)
for _, a := range del {
err := ifc.DeleteAddress(&a.IP)
err := ifc.LUID.DeleteIPAddress(*a)
if err != nil {
erracc = err
}
}
err := ifc.AddAddresses(add)
if err != nil {
erracc = err
for _, a := range add {
err := ifc.LUID.AddIPAddress(*a)
if err != nil {
erracc = err
}
}
ifc.UnicastIPNets = make([]*net.IPNet, len(want))
copy(ifc.UnicastIPNets, want)
return erracc
}
@@ -588,28 +623,46 @@ func deltaRouteData(a, b []*winipcfg.RouteData) (add, del []*winipcfg.RouteData)
return
}
// getInterfaceRoutes returns all the interface's routes.
// Corresponds to GetIpForwardTable2 function, but filtered by interface.
func getInterfaceRoutes(ifc *winipcfg.IPAdapterAddresses, family winipcfg.AddressFamily) ([]*winipcfg.MibIPforwardRow2, error) {
routes, err := winipcfg.GetIPForwardTable2(family)
if err != nil {
return nil, err
}
matches := make([]*winipcfg.MibIPforwardRow2, len(routes))
i := 0
for i := range routes {
if routes[i].InterfaceLUID == ifc.LUID {
matches[i] = &routes[i]
i++
}
}
return matches[:i], nil
}
// syncRoutes incrementally sets multiples routes on an interface.
// This avoids a full ifc.FlushRoutes call.
func syncRoutes(ifc *winipcfg.Interface, want []*winipcfg.RouteData) error {
routes, err := ifc.GetRoutes(windows.AF_INET)
func syncRoutes(ifc *winipcfg.IPAdapterAddresses, want []*winipcfg.RouteData) error {
routes, err := getInterfaceRoutes(ifc, windows.AF_INET)
if err != nil {
return err
}
got := make([]*winipcfg.RouteData, 0, len(routes))
for _, r := range routes {
v, err := r.ToRouteData()
if err != nil {
return err
}
got = append(got, v)
got = append(got, &winipcfg.RouteData{
Destination: r.DestinationPrefix.IPNet(),
NextHop: r.NextHop.IP(),
Metric: r.Metric,
})
}
add, del := deltaRouteData(got, want)
var errs []error
for _, a := range del {
err := ifc.DeleteRoute(&a.Destination, &a.NextHop)
err := ifc.LUID.DeleteRoute(a.Destination, a.NextHop)
if err != nil {
dstStr := a.Destination.String()
if dstStr == "169.254.255.255/32" {
@@ -622,7 +675,7 @@ func syncRoutes(ifc *winipcfg.Interface, want []*winipcfg.RouteData) error {
}
for _, a := range add {
err := ifc.AddRoute(a)
err := ifc.LUID.AddRoute(a.Destination, a.NextHop, a.Metric)
if err != nil {
errs = append(errs, fmt.Errorf("adding route %v: %w", &a.Destination, err))
}

View File

@@ -11,7 +11,7 @@ import (
"strings"
"testing"
winipcfg "github.com/tailscale/winipcfg-go"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"inet.af/netaddr"
)

View File

@@ -11,9 +11,9 @@ import (
"syscall"
"time"
winipcfg "github.com/tailscale/winipcfg-go"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/types/logger"
"tailscale.com/wgengine/router/dns"
)