tailscale/net/netns/netns_darwin.go

220 lines
5.7 KiB
Go
Raw Normal View History

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin
package netns
import (
"errors"
"fmt"
"log"
"net"
"net/netip"
"os"
"strings"
"syscall"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"tailscale.com/envknob"
"tailscale.com/net/interfaces"
"tailscale.com/types/logger"
)
// control returns a hook with the same signature as net.Dialer.Control and
// net.ListenConfig.Control. Every invocation of the hook is forwarded to
// controlLogf with logf supplied as the logger.
func control(logf logger.Logf) func(network, address string, c syscall.RawConn) error {
	hook := func(network, address string, c syscall.RawConn) error {
		return controlLogf(logf, network, address, c)
	}
	return hook
}
var (
	// bindToInterfaceByRouteEnv is the boolean registered with envknob under
	// the name TS_BIND_TO_INTERFACE_BY_ROUTE. When it reports true,
	// getInterfaceIndex chooses the interface by looking up the route to the
	// destination instead of always using the default route interface.
	bindToInterfaceByRouteEnv = envknob.RegisterBool("TS_BIND_TO_INTERFACE_BY_ROUTE")

	// errInterfaceIndexInvalid is a sentinel error describing an unusable
	// interface index.
	errInterfaceIndexInvalid = errors.New("interface index invalid")
)
// controlLogf binds the socket behind c to the interface chosen by
// getInterfaceIndex for address, logging through logf.
//
// Apart from the leading logf parameter it deliberately has the same
// signature as net.Dialer.Control and net.ListenConfig.Control.
//
// Localhost addresses are left unbound. If no interface index can be
// determined, the connection is also left unbound and nil is returned, so
// that a lookup failure never prevents the dial or listen from proceeding.
func controlLogf(logf logger.Logf, network, address string, c syscall.RawConn) error {
	// Don't bind to an interface for localhost connections.
	if isLocalhost(address) {
		return nil
	}

	ifIndex, lookupErr := getInterfaceIndex(logf, address)
	if lookupErr == nil {
		return bindConnToInterface(c, network, address, ifIndex, logf)
	}

	// The callee already logged the failure; carry on without binding.
	return nil
}
// getInterfaceIndex returns the index of the network interface that a
// connection to address should be bound to.
//
// Unless route-based binding is enabled (through bindToInterfaceByRoute or
// the TS_BIND_TO_INTERFACE_BY_ROUTE envknob), the result is the default
// route interface. With route-based binding enabled, the interface is looked
// up from the route to address; if address cannot be parsed, the lookup
// fails, or the lookup selects the Tailscale interface itself, the default
// route interface is used instead.
//
// address may be either "host:port" or a bare host. A non-nil error is
// returned only when the default route interface cannot be determined, and
// that error has already been reported through logf.
func getInterfaceIndex(logf logger.Logf, address string) (int, error) {
	// Helper so we can log errors.
	defaultIdx := func() (int, error) {
		idx, err := interfaces.DefaultRouteInterfaceIndex()
		if err != nil {
			logf("[unexpected] netns: DefaultRouteInterfaceIndex: %v", err)
			return -1, err
		}
		return idx, nil
	}

	useRoute := bindToInterfaceByRoute.Load() || bindToInterfaceByRouteEnv()
	if !useRoute {
		return defaultIdx()
	}

	host, _, err := net.SplitHostPort(address)
	if err != nil {
		// No port number; use the string directly.
		host = address
	}

	// If the address doesn't parse, use the default index.
	addr, err := netip.ParseAddr(host)
	if err != nil {
		logf("[unexpected] netns: error parsing address %q: %v", host, err)
		return defaultIdx()
	}

	idx, err := interfaceIndexFor(addr, true /* canRecurse */)
	if err != nil {
		logf("netns: error in interfaceIndexFor: %v", err)
		return defaultIdx()
	}

	// Verify that we didn't just choose the Tailscale interface;
	// if so, we fall back to binding from the default.
	//
	// interfaces.Tailscale is documented to return a nil interface together
	// with a nil error when no Tailscale interface exists, so tsif must be
	// checked for nil before its Index is read.
	_, tsif, err2 := interfaces.Tailscale()
	if err2 == nil && tsif != nil && tsif.Index == idx {
		logf("[unexpected] netns: interfaceIndexFor returned Tailscale interface")
		return defaultIdx()
	}

	return idx, nil
}
// interfaceIndexFor returns the interface index that we should bind to in
// order to send traffic to the provided address.
//
// It opens an AF_ROUTE socket, writes an RTM_GET request whose destination is
// addr, reads a single reply buffer and inspects the gateway slot of each
// route message in it:
//
//   - a link-layer gateway yields its interface index directly;
//   - an IPv4 or IPv6 gateway is resolved by calling interfaceIndexFor once
//     more on the gateway address, but only when canRecurse is true, so the
//     recursion depth is at most one.
//
// An error is returned if the socket cannot be created or used, the reply
// cannot be parsed, or no message yields a usable gateway.
func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) {
	fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
	if err != nil {
		return 0, fmt.Errorf("creating AF_ROUTE socket: %w", err)
	}
	defer unix.Close(fd)

	// Pick the route address family that matches the destination.
	var routeAddr route.Addr
	if addr.Is4() {
		routeAddr = &route.Inet4Addr{IP: addr.As4()}
	} else {
		routeAddr = &route.Inet6Addr{IP: addr.As16()}
	}

	rm := route.RouteMessage{
		Version: unix.RTM_VERSION,
		Type:    unix.RTM_GET,
		Flags:   unix.RTF_UP,
		ID:      uintptr(os.Getpid()),
		Seq:     1,
		Addrs: []route.Addr{
			unix.RTAX_DST: routeAddr,
		},
	}
	b, err := rm.Marshal()
	if err != nil {
		return 0, fmt.Errorf("marshaling RouteMessage: %w", err)
	}
	if _, err := unix.Write(fd, b); err != nil {
		return 0, fmt.Errorf("writing message: %w", err)
	}

	var buf [2048]byte
	n, err := unix.Read(fd, buf[:])
	if err != nil {
		return 0, fmt.Errorf("reading message: %w", err)
	}
	msgs, err := route.ParseRIB(route.RIBTypeRoute, buf[:n])
	if err != nil {
		return 0, fmt.Errorf("route.ParseRIB: %w", err)
	}
	if len(msgs) == 0 {
		return 0, errors.New("no messages")
	}

	for _, msg := range msgs {
		reply, ok := msg.(*route.RouteMessage)
		if !ok {
			continue
		}
		if reply.Version < 3 || reply.Version > 5 || reply.Type != unix.RTM_GET {
			continue
		}
		// The gateway lives at index RTAX_GATEWAY, so the slice must hold
		// at least RTAX_GATEWAY+1 entries before it can be indexed.
		if len(reply.Addrs) <= unix.RTAX_GATEWAY {
			continue
		}

		switch gw := reply.Addrs[unix.RTAX_GATEWAY].(type) {
		case *route.LinkAddr:
			return gw.Index, nil
		case *route.Inet4Addr:
			// We can get a gateway IP; recursively call ourselves
			// (exactly once) to get the link (and thus index) for
			// the gateway IP.
			if canRecurse {
				return interfaceIndexFor(netip.AddrFrom4(gw.IP), false)
			}
		case *route.Inet6Addr:
			// As above.
			if canRecurse {
				return interfaceIndexFor(netip.AddrFrom16(gw.IP), false)
			}
		default:
			// Unknown type; skip it
			continue
		}
	}

	return 0, errors.New("no valid address found")
}
// SetListenConfigInterfaceIndex installs a Control hook on lc that binds
// each socket to the interface identified by ifIndex. Failures inside the
// hook are logged with log.Printf.
//
// It returns an error, leaving lc untouched, if lc is nil or if lc.Control
// has already been set.
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {
	switch {
	case lc == nil:
		return errors.New("nil ListenConfig")
	case lc.Control != nil:
		return errors.New("ListenConfig.Control already set")
	}

	bind := func(network, address string, c syscall.RawConn) error {
		return bindConnToInterface(c, network, address, ifIndex, log.Printf)
	}
	lc.Control = bind
	return nil
}
// bindConnToInterface sets the bound-interface socket option on the socket
// behind c so that it uses the interface with index ifIndex.
//
// The IPv6 option (IPPROTO_IPV6/IPV6_BOUND_IF) is used when network ends in
// "6" or address contains "]:"; otherwise the IPv4 option
// (IPPROTO_IP/IP_BOUND_IF) is used.
//
// A setsockopt failure is logged through logf and returned. A failure of
// c.Control itself takes precedence and is returned wrapped.
func bindConnToInterface(c syscall.RawConn, network, address string, ifIndex int, logf logger.Logf) error {
	// Hacky test for v6: an explicit "…6" network or a bracketed host.
	isV6 := strings.HasSuffix(network, "6") || strings.Contains(address, "]:")

	level, name := unix.IPPROTO_IP, unix.IP_BOUND_IF
	if isV6 {
		level, name = unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF
	}

	// setErr carries the setsockopt result out of the Control callback.
	var setErr error
	ctlErr := c.Control(func(fd uintptr) {
		setErr = unix.SetsockoptInt(int(fd), level, name, ifIndex)
	})

	if setErr != nil {
		logf("[unexpected] netns: bindConnToInterface(%q, %q), v6=%v, index=%v: %v", network, address, isV6, ifIndex, setErr)
	}
	if ctlErr != nil {
		return fmt.Errorf("RawConn.Control on %T: %w", c, ctlErr)
	}
	return setErr
}