mirror of
https://github.com/tailscale/tailscale.git
synced 2025-03-13 17:03:52 +00:00
wgengine/netstack: implement netstack loopback (#13301)
When the TS_DEBUG_NETSTACK_LOOPBACK_PORT environment variable is set, netstack will loop back (dnat to addressFamilyLoopback:loopbackPort) TCP & UDP flows originally destined to localServicesIP:loopbackPort. localServicesIP is quad-100 or the IPv6 equivalent. Updates tailscale/corp#22713 Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
parent
80b2b45d60
commit
d21ebc28af
@ -23,6 +23,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -1204,13 +1205,157 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) {
|
|||||||
d1.MustCleanShutdown(t)
|
d1.MustCleanShutdown(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestNetstackTCPLoopback tests netstack loopback of a TCP stream, in both
|
||||||
|
// directions.
|
||||||
|
// TODO(jwhited): do the same for UDP
|
||||||
|
func TestNetstackTCPLoopback(t *testing.T) {
|
||||||
|
tstest.Shard(t)
|
||||||
|
if os.Getuid() != 0 {
|
||||||
|
t.Skip("skipping when not root")
|
||||||
|
}
|
||||||
|
|
||||||
|
env := newTestEnv(t)
|
||||||
|
env.tunMode = true
|
||||||
|
loopbackPort := uint16(5201)
|
||||||
|
env.loopbackPort = &loopbackPort
|
||||||
|
loopbackPortStr := strconv.Itoa(int(loopbackPort))
|
||||||
|
n1 := newTestNode(t, env)
|
||||||
|
d1 := n1.StartDaemon()
|
||||||
|
|
||||||
|
n1.AwaitResponding()
|
||||||
|
n1.MustUp()
|
||||||
|
|
||||||
|
n1.AwaitIP4()
|
||||||
|
n1.AwaitRunning()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
lisAddr string
|
||||||
|
network string
|
||||||
|
dialAddr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
lisAddr: net.JoinHostPort("127.0.0.1", loopbackPortStr),
|
||||||
|
network: "tcp4",
|
||||||
|
dialAddr: net.JoinHostPort(tsaddr.TailscaleServiceIPString, loopbackPortStr),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
lisAddr: net.JoinHostPort("::1", loopbackPortStr),
|
||||||
|
network: "tcp6",
|
||||||
|
dialAddr: net.JoinHostPort(tsaddr.TailscaleServiceIPv6String, loopbackPortStr),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
writeBufSize := 128 << 10 // 128KiB, exercise GSO if enabled
|
||||||
|
writeBufIterations := 100 // allow TCP send window to open up
|
||||||
|
wantTotal := writeBufSize * writeBufIterations
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
lis, err := net.Listen(c.network, c.lisAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer lis.Close()
|
||||||
|
|
||||||
|
writeFn := func(conn net.Conn) error {
|
||||||
|
for i := 0; i < writeBufIterations; i++ {
|
||||||
|
toWrite := make([]byte, writeBufSize)
|
||||||
|
var wrote int
|
||||||
|
for {
|
||||||
|
n, err := conn.Write(toWrite)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
wrote += n
|
||||||
|
if wrote == len(toWrite) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
readFn := func(conn net.Conn) error {
|
||||||
|
var read int
|
||||||
|
for {
|
||||||
|
b := make([]byte, writeBufSize)
|
||||||
|
n, err := conn.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
read += n
|
||||||
|
if read == wantTotal {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lisStepCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
conn, err := lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
lisStepCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lisStepCh <- readFn(conn)
|
||||||
|
lisStepCh <- writeFn(conn)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var conn net.Conn
|
||||||
|
err = tstest.WaitFor(time.Second*5, func() error {
|
||||||
|
conn, err = net.DialTimeout(c.network, c.dialAddr, time.Second*1)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
dialerStepCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
dialerStepCh <- writeFn(conn)
|
||||||
|
dialerStepCh <- readFn(conn)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var (
|
||||||
|
dialerSteps int
|
||||||
|
lisSteps int
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case lisErr := <-lisStepCh:
|
||||||
|
if lisErr != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
lisSteps++
|
||||||
|
if dialerSteps == 2 && lisSteps == 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case dialerErr := <-dialerStepCh:
|
||||||
|
if dialerErr != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
dialerSteps++
|
||||||
|
if dialerSteps == 2 && lisSteps == 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d1.MustCleanShutdown(t)
|
||||||
|
}
|
||||||
|
|
||||||
// testEnv contains the test environment (set of servers) used by one
|
// testEnv contains the test environment (set of servers) used by one
|
||||||
// or more nodes.
|
// or more nodes.
|
||||||
type testEnv struct {
|
type testEnv struct {
|
||||||
t testing.TB
|
t testing.TB
|
||||||
tunMode bool
|
tunMode bool
|
||||||
cli string
|
cli string
|
||||||
daemon string
|
daemon string
|
||||||
|
loopbackPort *uint16
|
||||||
|
|
||||||
LogCatcher *LogCatcher
|
LogCatcher *LogCatcher
|
||||||
LogCatcherServer *httptest.Server
|
LogCatcherServer *httptest.Server
|
||||||
@ -1511,6 +1656,9 @@ func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon {
|
|||||||
"TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost
|
"TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost
|
||||||
"TS_DEBUG_LOG_RATE=all",
|
"TS_DEBUG_LOG_RATE=all",
|
||||||
)
|
)
|
||||||
|
if n.env.loopbackPort != nil {
|
||||||
|
cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(int(*n.env.loopbackPort)))
|
||||||
|
}
|
||||||
if version.IsRace() {
|
if version.IsRace() {
|
||||||
cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1")
|
cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1")
|
||||||
}
|
}
|
||||||
|
@ -187,6 +187,11 @@ type Impl struct {
|
|||||||
dns *dns.Manager
|
dns *dns.Manager
|
||||||
driveForLocal drive.FileSystemForLocal // or nil
|
driveForLocal drive.FileSystemForLocal // or nil
|
||||||
|
|
||||||
|
// loopbackPort, if non-nil, will enable Impl to loop back (dnat to
|
||||||
|
// <address-family-loopback>:loopbackPort) TCP & UDP flows originally
|
||||||
|
// destined to serviceIP{v6}:loopbackPort.
|
||||||
|
loopbackPort *int
|
||||||
|
|
||||||
peerapiPort4Atomic atomic.Uint32 // uint16 port number for IPv4 peerapi
|
peerapiPort4Atomic atomic.Uint32 // uint16 port number for IPv4 peerapi
|
||||||
peerapiPort6Atomic atomic.Uint32 // uint16 port number for IPv6 peerapi
|
peerapiPort6Atomic atomic.Uint32 // uint16 port number for IPv6 peerapi
|
||||||
|
|
||||||
@ -378,6 +383,10 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
|
|||||||
dns: dns,
|
dns: dns,
|
||||||
driveForLocal: driveForLocal,
|
driveForLocal: driveForLocal,
|
||||||
}
|
}
|
||||||
|
loopbackPort, ok := envknob.LookupInt("TS_DEBUG_NETSTACK_LOOPBACK_PORT")
|
||||||
|
if ok && loopbackPort >= 0 && loopbackPort <= math.MaxUint16 {
|
||||||
|
ns.loopbackPort = &loopbackPort
|
||||||
|
}
|
||||||
ns.ctx, ns.ctxCancel = context.WithCancel(context.Background())
|
ns.ctx, ns.ctxCancel = context.WithCancel(context.Background())
|
||||||
ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc())
|
ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc())
|
||||||
ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound
|
ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound
|
||||||
@ -706,6 +715,13 @@ func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ns *Impl) isLoopbackPort(port uint16) bool {
|
||||||
|
if ns.loopbackPort != nil && int(port) == *ns.loopbackPort {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// handleLocalPackets is hooked into the tun datapath for packets leaving
|
// handleLocalPackets is hooked into the tun datapath for packets leaving
|
||||||
// the host and arriving at tailscaled. This method returns filter.DropSilently
|
// the host and arriving at tailscaled. This method returns filter.DropSilently
|
||||||
// to intercept a packet for handling, for instance traffic to quad-100.
|
// to intercept a packet for handling, for instance traffic to quad-100.
|
||||||
@ -724,11 +740,11 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Re
|
|||||||
// 80, and 8080.
|
// 80, and 8080.
|
||||||
switch p.IPProto {
|
switch p.IPProto {
|
||||||
case ipproto.TCP:
|
case ipproto.TCP:
|
||||||
if port := p.Dst.Port(); port != 53 && port != 80 && port != 8080 {
|
if port := p.Dst.Port(); port != 53 && port != 80 && port != 8080 && !ns.isLoopbackPort(port) {
|
||||||
return filter.Accept
|
return filter.Accept
|
||||||
}
|
}
|
||||||
case ipproto.UDP:
|
case ipproto.UDP:
|
||||||
if port := p.Dst.Port(); port != 53 {
|
if port := p.Dst.Port(); port != 53 && !ns.isLoopbackPort(port) {
|
||||||
return filter.Accept
|
return filter.Accept
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1169,6 +1185,11 @@ func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr {
|
|||||||
return netip.Addr{}
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
|
||||||
|
ipv6Loopback = netip.MustParseAddr("::1")
|
||||||
|
)
|
||||||
|
|
||||||
func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
|
func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
|
||||||
reqDetails := r.ID()
|
reqDetails := r.ID()
|
||||||
if debugNetstack() {
|
if debugNetstack() {
|
||||||
@ -1305,8 +1326,15 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if isTailscaleIP {
|
switch {
|
||||||
dialIP = netaddr.IPv4(127, 0, 0, 1)
|
case hittingServiceIP && ns.isLoopbackPort(reqDetails.LocalPort):
|
||||||
|
if dialIP == serviceIPv6 {
|
||||||
|
dialIP = ipv6Loopback
|
||||||
|
} else {
|
||||||
|
dialIP = ipv4Loopback
|
||||||
|
}
|
||||||
|
case isTailscaleIP:
|
||||||
|
dialIP = ipv4Loopback
|
||||||
}
|
}
|
||||||
dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort))
|
dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort))
|
||||||
|
|
||||||
@ -1457,16 +1485,23 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle magicDNS traffic (via UDP) here.
|
// Handle magicDNS and loopback traffic (via UDP) here.
|
||||||
if dst := dstAddr.Addr(); dst == serviceIP || dst == serviceIPv6 {
|
if dst := dstAddr.Addr(); dst == serviceIP || dst == serviceIPv6 {
|
||||||
if dstAddr.Port() != 53 {
|
switch {
|
||||||
|
case dstAddr.Port() == 53:
|
||||||
|
c := gonet.NewUDPConn(&wq, ep)
|
||||||
|
go ns.handleMagicDNSUDP(srcAddr, c)
|
||||||
|
return
|
||||||
|
case ns.isLoopbackPort(dstAddr.Port()):
|
||||||
|
if dst == serviceIPv6 {
|
||||||
|
dstAddr = netip.AddrPortFrom(ipv6Loopback, dstAddr.Port())
|
||||||
|
} else {
|
||||||
|
dstAddr = netip.AddrPortFrom(ipv4Loopback, dstAddr.Port())
|
||||||
|
}
|
||||||
|
default:
|
||||||
ep.Close()
|
ep.Close()
|
||||||
return // Only MagicDNS traffic runs on the service IPs for now.
|
return // Only MagicDNS and loopback traffic runs on the service IPs for now.
|
||||||
}
|
}
|
||||||
|
|
||||||
c := gonet.NewUDPConn(&wq, ep)
|
|
||||||
go ns.handleMagicDNSUDP(srcAddr, c)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if get := ns.GetUDPHandlerForFlow; get != nil {
|
if get := ns.GetUDPHandlerForFlow; get != nil {
|
||||||
@ -1545,9 +1580,17 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr
|
|||||||
var backendListenAddr *net.UDPAddr
|
var backendListenAddr *net.UDPAddr
|
||||||
var backendRemoteAddr *net.UDPAddr
|
var backendRemoteAddr *net.UDPAddr
|
||||||
isLocal := ns.isLocalIP(dstAddr.Addr())
|
isLocal := ns.isLocalIP(dstAddr.Addr())
|
||||||
|
isLoopback := dstAddr.Addr() == ipv4Loopback || dstAddr.Addr() == ipv6Loopback
|
||||||
if isLocal {
|
if isLocal {
|
||||||
backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
|
backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
|
||||||
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)}
|
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)}
|
||||||
|
} else if isLoopback {
|
||||||
|
ip := net.IP(ipv4Loopback.AsSlice())
|
||||||
|
if dstAddr.Addr() == ipv6Loopback {
|
||||||
|
ip = ipv6Loopback.AsSlice()
|
||||||
|
}
|
||||||
|
backendRemoteAddr = &net.UDPAddr{IP: ip, Port: int(port)}
|
||||||
|
backendListenAddr = &net.UDPAddr{IP: ip, Port: int(srcPort)}
|
||||||
} else {
|
} else {
|
||||||
if dstIP := dstAddr.Addr(); viaRange.Contains(dstIP) {
|
if dstIP := dstAddr.Addr(); viaRange.Contains(dstIP) {
|
||||||
dstAddr = netip.AddrPortFrom(tsaddr.UnmapVia(dstIP), dstAddr.Port())
|
dstAddr = netip.AddrPortFrom(tsaddr.UnmapVia(dstIP), dstAddr.Port())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user