net/netns: add Windows support for bind-to-interface-by-route

This is implemented via GetBestInterfaceEx. Should we encounter errors
or fail to resolve a valid, non-Tailscale interface, we fall back to
returning the index for the default interface instead.

Fixes #12551

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz 2024-05-21 14:38:53 -06:00
parent 591979b95f
commit 7dd76c3411
10 changed files with 313 additions and 28 deletions

View File

@ -99,7 +99,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
tailscale.com/net/netaddr from tailscale.com/ipn+ tailscale.com/net/netaddr from tailscale.com/ipn+
tailscale.com/net/netknob from tailscale.com/net/netns tailscale.com/net/netknob from tailscale.com/net/netns
💣 tailscale.com/net/netmon from tailscale.com/derp/derphttp+ 💣 tailscale.com/net/netmon from tailscale.com/derp/derphttp+
tailscale.com/net/netns from tailscale.com/derp/derphttp 💣 tailscale.com/net/netns from tailscale.com/derp/derphttp
tailscale.com/net/netutil from tailscale.com/client/tailscale tailscale.com/net/netutil from tailscale.com/client/tailscale
tailscale.com/net/sockstats from tailscale.com/derp/derphttp tailscale.com/net/sockstats from tailscale.com/derp/derphttp
tailscale.com/net/stun from tailscale.com/net/stunserver tailscale.com/net/stun from tailscale.com/net/stunserver
@ -114,7 +114,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
tailscale.com/syncs from tailscale.com/cmd/derper+ tailscale.com/syncs from tailscale.com/cmd/derper+
tailscale.com/tailcfg from tailscale.com/client/tailscale+ tailscale.com/tailcfg from tailscale.com/client/tailscale+
tailscale.com/tka from tailscale.com/client/tailscale+ tailscale.com/tka from tailscale.com/client/tailscale+
W tailscale.com/tsconst from tailscale.com/net/netmon W tailscale.com/tsconst from tailscale.com/net/netmon+
tailscale.com/tstime from tailscale.com/derp+ tailscale.com/tstime from tailscale.com/derp+
tailscale.com/tstime/mono from tailscale.com/tstime/rate tailscale.com/tstime/mono from tailscale.com/tstime/rate
tailscale.com/tstime/rate from tailscale.com/derp tailscale.com/tstime/rate from tailscale.com/derp

View File

@ -103,7 +103,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/net/neterror from tailscale.com/net/netcheck+ tailscale.com/net/neterror from tailscale.com/net/netcheck+
tailscale.com/net/netknob from tailscale.com/net/netns tailscale.com/net/netknob from tailscale.com/net/netns
💣 tailscale.com/net/netmon from tailscale.com/cmd/tailscale/cli+ 💣 tailscale.com/net/netmon from tailscale.com/cmd/tailscale/cli+
tailscale.com/net/netns from tailscale.com/derp/derphttp+ 💣 tailscale.com/net/netns from tailscale.com/derp/derphttp+
tailscale.com/net/netutil from tailscale.com/client/tailscale+ tailscale.com/net/netutil from tailscale.com/client/tailscale+
tailscale.com/net/packet from tailscale.com/wgengine/capture tailscale.com/net/packet from tailscale.com/wgengine/capture
tailscale.com/net/ping from tailscale.com/net/netcheck tailscale.com/net/ping from tailscale.com/net/netcheck
@ -121,7 +121,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/tailcfg from tailscale.com/client/tailscale+ tailscale.com/tailcfg from tailscale.com/client/tailscale+
tailscale.com/tempfork/spf13/cobra from tailscale.com/cmd/tailscale/cli/ffcomplete+ tailscale.com/tempfork/spf13/cobra from tailscale.com/cmd/tailscale/cli/ffcomplete+
tailscale.com/tka from tailscale.com/client/tailscale+ tailscale.com/tka from tailscale.com/client/tailscale+
W tailscale.com/tsconst from tailscale.com/net/netmon W tailscale.com/tsconst from tailscale.com/net/netmon+
tailscale.com/tstime from tailscale.com/control/controlhttp+ tailscale.com/tstime from tailscale.com/control/controlhttp+
tailscale.com/tstime/mono from tailscale.com/tstime/rate tailscale.com/tstime/mono from tailscale.com/tstime/rate
tailscale.com/tstime/rate from tailscale.com/cmd/tailscale/cli+ tailscale.com/tstime/rate from tailscale.com/cmd/tailscale/cli+

View File

@ -303,7 +303,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/net/netkernelconf from tailscale.com/ipn/ipnlocal tailscale.com/net/netkernelconf from tailscale.com/ipn/ipnlocal
tailscale.com/net/netknob from tailscale.com/logpolicy+ tailscale.com/net/netknob from tailscale.com/logpolicy+
💣 tailscale.com/net/netmon from tailscale.com/cmd/tailscaled+ 💣 tailscale.com/net/netmon from tailscale.com/cmd/tailscaled+
tailscale.com/net/netns from tailscale.com/cmd/tailscaled+ 💣 tailscale.com/net/netns from tailscale.com/cmd/tailscaled+
W 💣 tailscale.com/net/netstat from tailscale.com/portlist W 💣 tailscale.com/net/netstat from tailscale.com/portlist
tailscale.com/net/netutil from tailscale.com/client/tailscale+ tailscale.com/net/netutil from tailscale.com/client/tailscale+
tailscale.com/net/packet from tailscale.com/net/connstats+ tailscale.com/net/packet from tailscale.com/net/connstats+
@ -335,7 +335,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh
tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock
tailscale.com/tka from tailscale.com/client/tailscale+ tailscale.com/tka from tailscale.com/client/tailscale+
W tailscale.com/tsconst from tailscale.com/net/netmon W tailscale.com/tsconst from tailscale.com/net/netmon+
tailscale.com/tsd from tailscale.com/cmd/tailscaled+ tailscale.com/tsd from tailscale.com/cmd/tailscaled+
tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime from tailscale.com/control/controlclient+
tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/mono from tailscale.com/net/tstun+

9
net/netns/mksyscall.go Normal file
View File

@ -0,0 +1,9 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package netns
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
//sys getBestInterfaceEx(sockaddr *winipcfg.RawSockaddrInet, bestIfaceIndex *uint32) (ret error) = iphlpapi.GetBestInterfaceEx

View File

@ -38,7 +38,7 @@ func SetEnabled(on bool) {
// route information to bind to a particular interface. It is the same as // route information to bind to a particular interface. It is the same as
// setting the TS_BIND_TO_INTERFACE_BY_ROUTE. // setting the TS_BIND_TO_INTERFACE_BY_ROUTE.
// //
// Currently, this only changes the behaviour on macOS. // Currently, this only changes the behaviour on macOS and Windows.
func SetBindToInterfaceByRoute(v bool) { func SetBindToInterfaceByRoute(v bool) {
bindToInterfaceByRoute.Store(v) bindToInterfaceByRoute.Store(v)
} }

View File

@ -89,16 +89,10 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
return defaultIdx() return defaultIdx()
} }
host, _, err := net.SplitHostPort(address)
if err != nil {
// No port number; use the string directly.
host = address
}
// If the address doesn't parse, use the default index. // If the address doesn't parse, use the default index.
addr, err := netip.ParseAddr(host) addr, err := parseAddress(address)
if err != nil { if err != nil {
logf("[unexpected] netns: error parsing address %q: %v", host, err) logf("[unexpected] netns: error parsing address %q: %v", address, err)
return defaultIdx() return defaultIdx()
} }

21
net/netns/netns_dw.go Normal file
View File

@ -0,0 +1,21 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build darwin || windows
package netns
import (
"net"
"net/netip"
)
func parseAddress(address string) (addr netip.Addr, err error) {
host, _, err := net.SplitHostPort(address)
if err != nil {
// error means the string didn't contain a port number, so use the string directly
host = address
}
return netip.ParseAddr(host)
}

View File

@ -4,14 +4,18 @@
package netns package netns
import ( import (
"fmt"
"math/bits" "math/bits"
"net/netip"
"strings" "strings"
"syscall" "syscall"
"golang.org/x/sys/cpu" "golang.org/x/sys/cpu"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/envknob"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/tsconst"
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
@ -26,20 +30,34 @@ func interfaceIndex(iface *winipcfg.IPAdapterAddresses) uint32 {
return iface.IfIndex return iface.IfIndex
} }
func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { func defaultInterfaceIndex(family winipcfg.AddressFamily) (uint32, error) {
return controlC iface, err := netmon.GetWindowsDefault(family)
if err != nil {
return 0, err
} }
return interfaceIndex(iface), nil
}
func control(logf logger.Logf, _ *netmon.Monitor) func(network, address string, c syscall.RawConn) error {
return func(network, address string, c syscall.RawConn) error {
return controlC(logf, network, address, c)
}
}
var bindToInterfaceByRouteEnv = envknob.RegisterBool("TS_BIND_TO_INTERFACE_BY_ROUTE")
// controlC binds c to the Windows interface that holds a default // controlC binds c to the Windows interface that holds a default
// route, and is not the Tailscale WinTun interface. // route, and is not the Tailscale WinTun interface.
func controlC(network, address string, c syscall.RawConn) error { func controlC(logf logger.Logf, network, address string, c syscall.RawConn) (err error) {
if strings.HasPrefix(address, "127.") { if isLocalhost(address) {
// Don't bind to an interface for localhost connections, // Don't bind to an interface for localhost connections,
// otherwise we get: // otherwise we get:
// connectex: The requested address is not valid in its context // connectex: The requested address is not valid in its context
// (The derphttp tests were failing) // (The derphttp tests were failing)
return nil return nil
} }
canV4, canV6 := false, false canV4, canV6 := false, false
switch network { switch network {
case "tcp", "udp": case "tcp", "udp":
@ -50,29 +68,107 @@ func controlC(network, address string, c syscall.RawConn) error {
canV6 = true canV6 = true
} }
var defIfaceIdxV4, defIfaceIdxV6 uint32
if canV4 { if canV4 {
iface, err := netmon.GetWindowsDefault(windows.AF_INET) defIfaceIdxV4, err = defaultInterfaceIndex(windows.AF_INET)
if err != nil { if err != nil {
return err return fmt.Errorf("defaultInterfaceIndex(AF_INET): %w", err)
}
if err := bindSocket4(c, interfaceIndex(iface)); err != nil {
return err
} }
} }
if canV6 { if canV6 {
iface, err := netmon.GetWindowsDefault(windows.AF_INET6) defIfaceIdxV6, err = defaultInterfaceIndex(windows.AF_INET6)
if err != nil { if err != nil {
return err return fmt.Errorf("defaultInterfaceIndex(AF_INET6): %w", err)
} }
if err := bindSocket6(c, interfaceIndex(iface)); err != nil { }
return err
var ifaceIdxV4, ifaceIdxV6 uint32
if useRoute := bindToInterfaceByRoute.Load() || bindToInterfaceByRouteEnv(); useRoute {
addr, err := parseAddress(address)
if err != nil {
return fmt.Errorf("parseAddress: %w", err)
}
if canV4 && (addr.Is4() || addr.Is4In6()) {
addrV4 := addr.Unmap()
ifaceIdxV4, err = getInterfaceIndex(logf, addrV4, defIfaceIdxV4)
if err != nil {
return fmt.Errorf("getInterfaceIndex(%v): %w", addrV4, err)
}
}
if canV6 && addr.Is6() {
ifaceIdxV6, err = getInterfaceIndex(logf, addr, defIfaceIdxV6)
if err != nil {
return fmt.Errorf("getInterfaceIndex(%v): %w", addr, err)
}
}
} else {
ifaceIdxV4, ifaceIdxV6 = defIfaceIdxV4, defIfaceIdxV6
}
if canV4 {
if err := bindSocket4(c, ifaceIdxV4); err != nil {
return fmt.Errorf("bindSocket4(%d): %w", ifaceIdxV4, err)
}
}
if canV6 {
if err := bindSocket6(c, ifaceIdxV6); err != nil {
return fmt.Errorf("bindSocket6(%d): %w", ifaceIdxV6, err)
} }
} }
return nil return nil
} }
func getInterfaceIndex(logf logger.Logf, addr netip.Addr, defaultIdx uint32) (idx uint32, err error) {
idx, err = interfaceIndexFor(addr)
if err != nil {
return defaultIdx, fmt.Errorf("interfaceIndexFor: %w", err)
}
isTS, err := isTailscaleInterface(idx)
if err != nil {
return defaultIdx, fmt.Errorf("isTailscaleInterface: %w", err)
}
if isTS {
return defaultIdx, nil
}
return idx, nil
}
func isTailscaleInterface(ifaceIdx uint32) (bool, error) {
ifaceLUID, err := winipcfg.LUIDFromIndex(ifaceIdx)
if err != nil {
return false, err
}
iface, err := ifaceLUID.Interface()
if err != nil {
return false, err
}
result := iface.Type == winipcfg.IfTypePropVirtual &&
strings.Contains(iface.Description(), tsconst.WintunInterfaceDesc)
return result, nil
}
func interfaceIndexFor(addr netip.Addr) (uint32, error) {
var sockaddr winipcfg.RawSockaddrInet
if err := sockaddr.SetAddr(addr); err != nil {
return 0, err
}
var idx uint32
if err := getBestInterfaceEx(&sockaddr, &idx); err != nil {
return 0, err
}
return idx, nil
}
// sockoptBoundInterface is the value of IP_UNICAST_IF and IPV6_UNICAST_IF. // sockoptBoundInterface is the value of IP_UNICAST_IF and IPV6_UNICAST_IF.
// //
// See https://docs.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options // See https://docs.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options

View File

@ -0,0 +1,112 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package netns
import (
"strings"
"testing"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/tsconst"
)
func TestGetInterfaceIndex(t *testing.T) {
oldVal := bindToInterfaceByRoute.Load()
t.Cleanup(func() { bindToInterfaceByRoute.Store(oldVal) })
bindToInterfaceByRoute.Store(true)
defIfaceIdxV4, err := defaultInterfaceIndex(windows.AF_INET)
if err != nil {
t.Fatalf("defaultInterfaceIndex(AF_INET) failed: %v", err)
}
tests := []struct {
name string
addr string
err string
}{
{
name: "IP_and_port",
addr: "8.8.8.8:53",
},
{
name: "bare_ip",
addr: "8.8.8.8",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
addr, err := parseAddress(tc.addr)
if err != nil {
t.Fatal(err)
}
idx, err := getInterfaceIndex(t.Logf, addr, defIfaceIdxV4)
if err != nil {
if tc.err == "" {
t.Fatalf("got unexpected error: %v", err)
}
if errstr := err.Error(); errstr != tc.err {
t.Errorf("expected error %q, got %q", errstr, tc.err)
}
} else {
t.Logf("getInterfaceIndex(%q) = %d", tc.addr, idx)
if tc.err != "" {
t.Fatalf("wanted error %q", tc.err)
}
}
})
}
t.Run("NoTailscale", func(t *testing.T) {
tsIdx, ok, err := tailscaleInterfaceIndex()
if err != nil {
t.Fatal(err)
}
if !ok {
t.Skip("no tailscale interface on this machine")
}
defaultIdx, err := defaultInterfaceIndex(windows.AF_INET)
if err != nil {
t.Fatalf("defaultInterfaceIndex(AF_INET) failed: %v", err)
}
addr, err := parseAddress("100.100.100.100:53")
if err != nil {
t.Fatal(err)
}
idx, err := getInterfaceIndex(t.Logf, addr, defaultIdx)
if err != nil {
t.Fatal(err)
}
t.Logf("tailscaleIdx=%d defaultIdx=%d idx=%d", tsIdx, defaultIdx, idx)
if idx == tsIdx {
t.Fatalf("got idx=%d; wanted not Tailscale interface", idx)
} else if idx != defaultIdx {
t.Fatalf("got idx=%d, want %d", idx, defaultIdx)
}
})
}
func tailscaleInterfaceIndex() (idx uint32, found bool, err error) {
ifs, err := winipcfg.GetAdaptersAddresses(windows.AF_INET, winipcfg.GAAFlagIncludeAllInterfaces)
if err != nil {
return idx, false, err
}
for _, iface := range ifs {
if iface.IfType != winipcfg.IfTypePropVirtual {
continue
}
if strings.Contains(iface.Description(), tsconst.WintunInterfaceDesc) {
return iface.IfIndex, true, nil
}
}
return idx, false, nil
}

View File

@ -0,0 +1,53 @@
// Code generated by 'go generate'; DO NOT EDIT.
package netns
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procGetBestInterfaceEx = modiphlpapi.NewProc("GetBestInterfaceEx")
)
func getBestInterfaceEx(sockaddr *winipcfg.RawSockaddrInet, bestIfaceIndex *uint32) (ret error) {
r0, _, _ := syscall.Syscall(procGetBestInterfaceEx.Addr(), 2, uintptr(unsafe.Pointer(sockaddr)), uintptr(unsafe.Pointer(bestIfaceIndex)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}