net/netns: add functionality to bind outgoing sockets based on route table

When turned on via environment variable (off by default), this will use
the BSD routing APIs to query what interface index a socket should be
bound to, rather than binding to the default interface in all cases.

Updates #5719
Updates #5940

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: Ib4c919471f377b7a08cd3413f8e8caacb29fee0b
This commit is contained in:
Andrew Dunham
2023-01-25 13:17:40 -05:00
parent 7439bc7ba6
commit 2703d6916f
5 changed files with 247 additions and 3 deletions

View File

@@ -11,10 +11,14 @@ import (
"fmt"
"log"
"net"
"net/netip"
"os"
"strings"
"syscall"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"tailscale.com/envknob"
"tailscale.com/net/interfaces"
"tailscale.com/types/logger"
)
@@ -25,6 +29,10 @@ func control(logf logger.Logf) func(network, address string, c syscall.RawConn)
}
}
var bindToInterfaceByRouteEnv = envknob.RegisterBool("TS_BIND_TO_INTERFACE_BY_ROUTE")
var errInterfaceIndexInvalid = errors.New("interface index invalid")
// controlLogf marks c as necessary to dial in a separate network namespace.
//
// It's intentionally the same signature as net.Dialer.Control
@@ -34,15 +42,145 @@ func controlLogf(logf logger.Logf, network, address string, c syscall.RawConn) e
// Don't bind to an interface for localhost connections.
return nil
}
idx, err := interfaces.DefaultRouteInterfaceIndex()
idx, err := getInterfaceIndex(logf, address)
if err != nil {
logf("[unexpected] netns: DefaultRouteInterfaceIndex: %v", err)
// callee logged
return nil
}
return bindConnToInterface(c, network, address, idx, logf)
}
func getInterfaceIndex(logf logger.Logf, address string) (int, error) {
// Helper so we can log errors.
defaultIdx := func() (int, error) {
idx, err := interfaces.DefaultRouteInterfaceIndex()
if err != nil {
logf("[unexpected] netns: DefaultRouteInterfaceIndex: %v", err)
return -1, err
}
return idx, nil
}
useRoute := bindToInterfaceByRoute.Load() || bindToInterfaceByRouteEnv()
if !useRoute {
return defaultIdx()
}
host, _, err := net.SplitHostPort(address)
if err != nil {
// No port number; use the string directly.
host = address
}
// If the address doesn't parse, use the default index.
addr, err := netip.ParseAddr(host)
if err != nil {
logf("[unexpected] netns: error parsing address %q: %v", host, err)
return defaultIdx()
}
idx, err := interfaceIndexFor(addr, true /* canRecurse */)
if err != nil {
logf("netns: error in interfaceIndexFor: %v", err)
return defaultIdx()
}
// Verify that we didn't just choose the Tailscale interface;
// if so, we fall back to binding from the default.
_, tsif, err2 := interfaces.Tailscale()
if err2 == nil && tsif.Index == idx {
logf("[unexpected] netns: interfaceIndexFor returned Tailscale interface")
return defaultIdx()
}
return idx, err
}
// interfaceIndexFor returns the interface index that we should bind to in
// order to send traffic to the provided address.
func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) {
fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return 0, fmt.Errorf("creating AF_ROUTE socket: %w", err)
}
defer unix.Close(fd)
var routeAddr route.Addr
if addr.Is4() {
routeAddr = &route.Inet4Addr{IP: addr.As4()}
} else {
routeAddr = &route.Inet6Addr{IP: addr.As16()}
}
rm := route.RouteMessage{
Version: unix.RTM_VERSION,
Type: unix.RTM_GET,
Flags: unix.RTF_UP,
ID: uintptr(os.Getpid()),
Seq: 1,
Addrs: []route.Addr{
unix.RTAX_DST: routeAddr,
},
}
b, err := rm.Marshal()
if err != nil {
return 0, fmt.Errorf("marshaling RouteMessage: %w", err)
}
_, err = unix.Write(fd, b)
if err != nil {
return 0, fmt.Errorf("writing message: %w", err)
}
var buf [2048]byte
n, err := unix.Read(fd, buf[:])
if err != nil {
return 0, fmt.Errorf("reading message: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf[:n])
if err != nil {
return 0, fmt.Errorf("route.ParseRIB: %w", err)
}
if len(msgs) == 0 {
return 0, fmt.Errorf("no messages")
}
for _, msg := range msgs {
rm, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rm.Version < 3 || rm.Version > 5 || rm.Type != unix.RTM_GET {
continue
}
if len(rm.Addrs) < unix.RTAX_GATEWAY {
continue
}
switch addr := rm.Addrs[unix.RTAX_GATEWAY].(type) {
case *route.LinkAddr:
return addr.Index, nil
case *route.Inet4Addr:
// We can get a gateway IP; recursively call ourselves
// (exactly once) to get the link (and thus index) for
// the gateway IP.
if canRecurse {
return interfaceIndexFor(netip.AddrFrom4(addr.IP), false)
}
case *route.Inet6Addr:
// As above.
if canRecurse {
return interfaceIndexFor(netip.AddrFrom16(addr.IP), false)
}
default:
// Unknown type; skip it
continue
}
}
return 0, fmt.Errorf("no valid address found")
}
// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound
// to the provided interface index.
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {