tailscale/net/netns/netns_darwin.go
Andrew Dunham 2703d6916f net/netns: add functionality to bind outgoing sockets based on route table
When turned on via environment variable (off by default), this will use
the BSD routing APIs to query what interface index a socket should be
bound to, rather than binding to the default interface in all cases.

Updates #5719
Updates #5940

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: Ib4c919471f377b7a08cd3413f8e8caacb29fee0b
2023-01-26 20:58:58 -05:00

220 lines
5.7 KiB
Go

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin
package netns
import (
"errors"
"fmt"
"log"
"net"
"net/netip"
"os"
"strings"
"syscall"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"tailscale.com/envknob"
"tailscale.com/net/interfaces"
"tailscale.com/types/logger"
)
// control builds a socket-control hook whose signature matches
// net.Dialer.Control and net.ListenConfig.Control. The returned function
// forwards every call to controlLogf with logf bound as the logger.
func control(logf logger.Logf) func(network, address string, c syscall.RawConn) error {
	hook := func(network, address string, c syscall.RawConn) error {
		return controlLogf(logf, network, address, c)
	}
	return hook
}
// bindToInterfaceByRouteEnv reports the boolean value of the
// TS_BIND_TO_INTERFACE_BY_ROUTE environment knob. When it is true,
// getInterfaceIndex picks the interface to bind to by querying the routing
// table instead of always using the default-route interface.
var bindToInterfaceByRouteEnv = envknob.RegisterBool("TS_BIND_TO_INTERFACE_BY_ROUTE")
// errInterfaceIndexInvalid is a sentinel error for an invalid interface index.
// NOTE(review): it is not referenced in the part of the file visible here;
// confirm where it is used.
var errInterfaceIndexInvalid = errors.New("interface index invalid")
// controlLogf marks c as necessary to dial in a separate network namespace
// by binding it to a specific network interface.
//
// Apart from the leading logf argument, its signature intentionally mirrors
// net.Dialer.Control and net.ListenConfig.Control.
//
// Connections to an address that isLocalhost accepts are left unbound. If no
// interface index can be determined the socket is also left unbound and nil
// is returned, so the dial proceeds anyway.
func controlLogf(logf logger.Logf, network, address string, c syscall.RawConn) error {
	// Don't bind to an interface for localhost connections.
	if isLocalhost(address) {
		return nil
	}

	ifIndex, lookupErr := getInterfaceIndex(logf, address)
	if lookupErr == nil {
		return bindConnToInterface(c, network, address, ifIndex, logf)
	}

	// getInterfaceIndex already logged the failure; continue without binding.
	return nil
}
// getInterfaceIndex returns the index of the network interface that a socket
// connecting to address should be bound to.
//
// Unless route-based binding is enabled (through bindToInterfaceByRoute or the
// TS_BIND_TO_INTERFACE_BY_ROUTE environment knob), the default-route interface
// is always used. With route-based binding enabled, the routing table is
// queried for address via interfaceIndexFor; the default-route interface is
// used as a fallback when:
//   - address does not contain a parseable IP,
//   - the route lookup fails, or
//   - the lookup selected the Tailscale interface itself.
//
// A non-nil error is returned only when the default-route interface cannot be
// determined; that failure is logged via logf before returning, and the
// returned index is -1.
func getInterfaceIndex(logf logger.Logf, address string) (int, error) {
	// Helper so we can log errors.
	defaultIdx := func() (int, error) {
		idx, err := interfaces.DefaultRouteInterfaceIndex()
		if err != nil {
			logf("[unexpected] netns: DefaultRouteInterfaceIndex: %v", err)
			return -1, err
		}
		return idx, nil
	}

	useRoute := bindToInterfaceByRoute.Load() || bindToInterfaceByRouteEnv()
	if !useRoute {
		return defaultIdx()
	}

	host, _, err := net.SplitHostPort(address)
	if err != nil {
		// No port number; use the string directly.
		host = address
	}

	// If the address doesn't parse, use the default index.
	addr, err := netip.ParseAddr(host)
	if err != nil {
		logf("[unexpected] netns: error parsing address %q: %v", host, err)
		return defaultIdx()
	}

	idx, err := interfaceIndexFor(addr, true /* canRecurse */)
	if err != nil {
		logf("netns: error in interfaceIndexFor: %v", err)
		return defaultIdx()
	}

	// Verify that we didn't just choose the Tailscale interface;
	// if so, we fall back to binding from the default.
	//
	// interfaces.Tailscale is documented to return all zero values (a nil
	// interface and a nil error) when no Tailscale interface exists, so tsif
	// must be nil-checked before it is dereferenced.
	_, tsif, tsErr := interfaces.Tailscale()
	if tsErr == nil && tsif != nil && tsif.Index == idx {
		logf("[unexpected] netns: interfaceIndexFor returned Tailscale interface")
		return defaultIdx()
	}

	return idx, nil
}
// interfaceIndexFor returns the interface index that we should bind to in
// order to send traffic to the provided address.
//
// It opens an AF_ROUTE socket, writes a single RTM_GET request whose
// destination is addr, and parses the reply. If the reply's gateway entry is
// a link-layer address, that link's index is returned. If the gateway is an
// IP address and canRecurse is true, the lookup is repeated once for the
// gateway IP (with canRecurse set to false, so recursion depth is bounded).
//
// An error is returned if the socket cannot be created, the request cannot be
// marshaled or written, the reply cannot be read or parsed, or no usable
// gateway entry is found.
//
// NOTE(review): only one read is performed and replies are not matched
// against the request's ID/Seq; confirm that is acceptable for callers.
func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) {
	fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
	if err != nil {
		return 0, fmt.Errorf("creating AF_ROUTE socket: %w", err)
	}
	defer unix.Close(fd)

	// Encode the destination with the address family matching addr.
	var routeAddr route.Addr
	if addr.Is4() {
		routeAddr = &route.Inet4Addr{IP: addr.As4()}
	} else {
		routeAddr = &route.Inet6Addr{IP: addr.As16()}
	}

	rm := route.RouteMessage{
		Version: unix.RTM_VERSION,
		Type:    unix.RTM_GET,
		Flags:   unix.RTF_UP,
		ID:      uintptr(os.Getpid()),
		Seq:     1,
		Addrs: []route.Addr{
			unix.RTAX_DST: routeAddr,
		},
	}
	b, err := rm.Marshal()
	if err != nil {
		return 0, fmt.Errorf("marshaling RouteMessage: %w", err)
	}
	if _, err := unix.Write(fd, b); err != nil {
		return 0, fmt.Errorf("writing message: %w", err)
	}

	var buf [2048]byte
	n, err := unix.Read(fd, buf[:])
	if err != nil {
		return 0, fmt.Errorf("reading message: %w", err)
	}
	msgs, err := route.ParseRIB(route.RIBTypeRoute, buf[:n])
	if err != nil {
		return 0, fmt.Errorf("route.ParseRIB: %w", err)
	}
	if len(msgs) == 0 {
		return 0, errors.New("no messages")
	}

	for _, msg := range msgs {
		reply, ok := msg.(*route.RouteMessage)
		if !ok {
			continue
		}
		// Only consider RTM_GET replies with a version in the 3..5 range.
		if reply.Version < 3 || reply.Version > 5 || reply.Type != unix.RTM_GET {
			continue
		}
		// The gateway lives at index RTAX_GATEWAY, so the slice must hold
		// at least RTAX_GATEWAY+1 entries for the access below to be safe.
		if len(reply.Addrs) <= unix.RTAX_GATEWAY {
			continue
		}

		switch gw := reply.Addrs[unix.RTAX_GATEWAY].(type) {
		case *route.LinkAddr:
			return gw.Index, nil
		case *route.Inet4Addr:
			// We can get a gateway IP; recursively call ourselves
			// (exactly once) to get the link (and thus index) for
			// the gateway IP.
			if canRecurse {
				return interfaceIndexFor(netip.AddrFrom4(gw.IP), false)
			}
		case *route.Inet6Addr:
			// As above.
			if canRecurse {
				return interfaceIndexFor(netip.AddrFrom16(gw.IP), false)
			}
		default:
			// Unknown type; skip it
			continue
		}
	}

	return 0, errors.New("no valid address found")
}
// SetListenConfigInterfaceIndex installs a Control hook on lc that binds
// every socket it creates to the interface with index ifIndex.
//
// It returns an error, leaving lc untouched, if lc is nil or if lc.Control
// has already been set. Bind failures inside the hook are logged with
// log.Printf.
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {
	switch {
	case lc == nil:
		return errors.New("nil ListenConfig")
	case lc.Control != nil:
		return errors.New("ListenConfig.Control already set")
	}

	bindHook := func(network, address string, c syscall.RawConn) error {
		return bindConnToInterface(c, network, address, ifIndex, log.Printf)
	}
	lc.Control = bindHook
	return nil
}
// bindConnToInterface binds the socket behind c to the interface with index
// ifIndex by setting IP_BOUND_IF, or IPV6_BOUND_IF when the connection looks
// like IPv6.
//
// A failure of the setsockopt call is logged via logf. If c.Control itself
// fails, that error is returned wrapped; otherwise the setsockopt result
// (nil on success) is returned.
func bindConnToInterface(c syscall.RawConn, network, address string, ifIndex int, logf logger.Logf) error {
	// Hacky IPv6 detection: the network name ends in "6" or the address
	// contains "]:" (a bracketed host followed by a port).
	isV6 := strings.HasSuffix(network, "6") || strings.Contains(address, "]:")

	level, option := unix.IPPROTO_IP, unix.IP_BOUND_IF
	if isV6 {
		level, option = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
	}

	var setErr error
	ctrlErr := c.Control(func(fd uintptr) {
		setErr = unix.SetsockoptInt(int(fd), level, option, ifIndex)
	})

	if setErr != nil {
		logf("[unexpected] netns: bindConnToInterface(%q, %q), v6=%v, index=%v: %v", network, address, isV6, ifIndex, setErr)
	}
	if ctrlErr != nil {
		return fmt.Errorf("RawConn.Control on %T: %w", c, ctrlErr)
	}
	return setErr
}