wgengine/netstack: implement netstack loopback (#13301)

When the TS_DEBUG_NETSTACK_LOOPBACK_PORT environment variable is set,
netstack will loop back (dnat to addressFamilyLoopback:loopbackPort)
TCP & UDP flows originally destined to localServicesIP:loopbackPort.
localServicesIP is quad-100 or the IPv6 equivalent.

Updates tailscale/corp#22713

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2024-08-28 18:50:13 -07:00
committed by GitHub
parent 80b2b45d60
commit d21ebc28af
2 changed files with 206 additions and 15 deletions

View File

@@ -23,6 +23,7 @@ import (
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -1204,13 +1205,157 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) {
d1.MustCleanShutdown(t)
}
// TestNetstackTCPLoopback tests netstack loopback of a TCP stream, in both
// directions.
// TODO(jwhited): do the same for UDP
func TestNetstackTCPLoopback(t *testing.T) {
tstest.Shard(t)
if os.Getuid() != 0 {
t.Skip("skipping when not root")
}
env := newTestEnv(t)
env.tunMode = true
loopbackPort := uint16(5201)
env.loopbackPort = &loopbackPort
loopbackPortStr := strconv.Itoa(int(loopbackPort))
n1 := newTestNode(t, env)
d1 := n1.StartDaemon()
n1.AwaitResponding()
n1.MustUp()
n1.AwaitIP4()
n1.AwaitRunning()
cases := []struct {
lisAddr string
network string
dialAddr string
}{
{
lisAddr: net.JoinHostPort("127.0.0.1", loopbackPortStr),
network: "tcp4",
dialAddr: net.JoinHostPort(tsaddr.TailscaleServiceIPString, loopbackPortStr),
},
{
lisAddr: net.JoinHostPort("::1", loopbackPortStr),
network: "tcp6",
dialAddr: net.JoinHostPort(tsaddr.TailscaleServiceIPv6String, loopbackPortStr),
},
}
writeBufSize := 128 << 10 // 128KiB, exercise GSO if enabled
writeBufIterations := 100 // allow TCP send window to open up
wantTotal := writeBufSize * writeBufIterations
for _, c := range cases {
lis, err := net.Listen(c.network, c.lisAddr)
if err != nil {
t.Fatal(err)
}
defer lis.Close()
writeFn := func(conn net.Conn) error {
for i := 0; i < writeBufIterations; i++ {
toWrite := make([]byte, writeBufSize)
var wrote int
for {
n, err := conn.Write(toWrite)
if err != nil {
return err
}
wrote += n
if wrote == len(toWrite) {
break
}
}
}
return nil
}
readFn := func(conn net.Conn) error {
var read int
for {
b := make([]byte, writeBufSize)
n, err := conn.Read(b)
if err != nil {
return err
}
read += n
if read == wantTotal {
return nil
}
}
}
lisStepCh := make(chan error)
go func() {
conn, err := lis.Accept()
if err != nil {
lisStepCh <- err
return
}
lisStepCh <- readFn(conn)
lisStepCh <- writeFn(conn)
}()
var conn net.Conn
err = tstest.WaitFor(time.Second*5, func() error {
conn, err = net.DialTimeout(c.network, c.dialAddr, time.Second*1)
if err != nil {
return err
}
return nil
})
if err != nil {
t.Fatal(err)
}
defer conn.Close()
dialerStepCh := make(chan error)
go func() {
dialerStepCh <- writeFn(conn)
dialerStepCh <- readFn(conn)
}()
var (
dialerSteps int
lisSteps int
)
for {
select {
case lisErr := <-lisStepCh:
if lisErr != nil {
t.Fatal(err)
}
lisSteps++
if dialerSteps == 2 && lisSteps == 2 {
return
}
case dialerErr := <-dialerStepCh:
if dialerErr != nil {
t.Fatal(err)
}
dialerSteps++
if dialerSteps == 2 && lisSteps == 2 {
return
}
}
}
}
d1.MustCleanShutdown(t)
}
// testEnv contains the test environment (set of servers) used by one
// or more nodes.
type testEnv struct {
t testing.TB
tunMode bool
cli string
daemon string
t testing.TB
tunMode bool
cli string
daemon string
loopbackPort *uint16
LogCatcher *LogCatcher
LogCatcherServer *httptest.Server
@@ -1511,6 +1656,9 @@ func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon {
"TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost
"TS_DEBUG_LOG_RATE=all",
)
if n.env.loopbackPort != nil {
cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(int(*n.env.loopbackPort)))
}
if version.IsRace() {
cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1")
}