// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package nat import ( "bytes" "cmp" "context" "encoding/json" "errors" "flag" "fmt" "io" "log" "net" "net/http" "net/netip" "os" "os/exec" "path/filepath" "strings" "sync" "testing" "time" "golang.org/x/mod/modfile" "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale" "tailscale.com/ipn/ipnstate" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstest/natlab/vnet" ) var ( logTailscaled = flag.Bool("log-tailscaled", false, "log tailscaled output") pcapFile = flag.String("pcap", "", "write pcap to file") ) type natTest struct { tb testing.TB base string // base image tempDir string // for qcow2 images vnet *vnet.Server kernel string // linux kernel path } func newNatTest(tb testing.TB) *natTest { root, err := os.Getwd() if err != nil { tb.Fatal(err) } modRoot := filepath.Join(root, "../../..") nt := &natTest{ tb: tb, tempDir: tb.TempDir(), base: filepath.Join(modRoot, "gokrazy/tsapp.qcow2"), } if _, err := os.Stat(nt.base); err != nil { tb.Skipf("skipping test; base image %q not found", nt.base) } nt.kernel, err = findKernelPath(filepath.Join(modRoot, "gokrazy/tsapp/builddir/github.com/tailscale/gokrazy-kernel/go.mod")) if err != nil { tb.Skipf("skipping test; kernel not found: %v", err) } tb.Logf("found kernel: %v", nt.kernel) return nt } func findKernelPath(goMod string) (string, error) { b, err := os.ReadFile(goMod) if err != nil { return "", err } mf, err := modfile.Parse("go.mod", b, nil) if err != nil { return "", err } goModB, err := exec.Command("go", "env", "GOMODCACHE").CombinedOutput() if err != nil { return "", err } for _, r := range mf.Require { if r.Mod.Path == "github.com/tailscale/gokrazy-kernel" { return strings.TrimSpace(string(goModB)) + "/" + r.Mod.String() + "/vmlinuz", nil } } return "", fmt.Errorf("failed to find kernel in %v", goMod) } type addNodeFunc func(c *vnet.Config) *vnet.Node // returns nil to omit test func easy(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT)) } func easyAF(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyAFNAT)) } func sameLAN(c *vnet.Config) *vnet.Node { nw := c.FirstNetwork() if nw == nil { return nil } if !nw.CanTakeMoreNodes() { return nil } return c.AddNode(nw) } func one2one(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP fmt.Sprintf("172.16.%d.1/24", n), vnet.One2OneNAT)) } func easyPMP(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT, vnet.NATPMP)) } func hard(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT)) } func hardPMP(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP fmt.Sprintf("10.7.%d.1/24", n), vnet.HardNAT, vnet.NATPMP)) } func (nt *natTest) runTest(node1, node2 addNodeFunc) pingRoute { t := nt.tb var c vnet.Config c.SetPCAPFile(*pcapFile) nodes := []*vnet.Node{ node1(&c), node2(&c), } if nodes[0] == nil || nodes[1] == nil { t.Skip("skipping test; not applicable combination") } var err error nt.vnet, err = vnet.New(&c) if err != nil { t.Fatalf("newServer: %v", err) } nt.tb.Cleanup(func() { nt.vnet.Close() }) var wg sync.WaitGroup // waiting for srv.Accept goroutine defer wg.Wait() sockAddr := filepath.Join(nt.tempDir, "qemu.sock") srv, err := net.Listen("unix", sockAddr) if err != nil { t.Fatalf("Listen: %v", err) } defer srv.Close() wg.Add(1) go func() { defer wg.Done() for { c, err := srv.Accept() if err != nil { return } go nt.vnet.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) } }() for i, node := range nodes { disk := fmt.Sprintf("%s/node-%d.qcow2", nt.tempDir, i) out, err := exec.Command("qemu-img", "create", "-f", "qcow2", "-F", "qcow2", "-b", nt.base, disk).CombinedOutput() if err != nil { t.Fatalf("qemu-img create: %v, %s", err, out) } cmd := exec.Command("qemu-system-x86_64", "-M", "microvm,isa-serial=off", "-m", "384M", "-nodefaults", "-no-user-config", "-nographic", "-kernel", nt.kernel, "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1", "-drive", "id=blk0,file="+disk+",format=qcow2", "-device", "virtio-blk-device,drive=blk0", "-netdev", "stream,id=net0,addr.type=unix,addr.path="+sockAddr, "-device", "virtio-serial-device", "-device", "virtio-net-device,netdev=net0,mac="+node.MAC().String(), "-chardev", "stdio,id=virtiocon0,mux=on", "-device", "virtconsole,chardev=virtiocon0", "-mon", "chardev=virtiocon0,mode=readline", "-audio", "none", ) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { t.Fatalf("qemu: %v", err) } nt.tb.Cleanup(func() { cmd.Process.Kill() cmd.Wait() }) } ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() lc1 := nt.vnet.NodeAgentClient(nodes[0]) lc2 := nt.vnet.NodeAgentClient(nodes[1]) clients := []*vnet.NodeAgentClient{lc1, lc2} var eg errgroup.Group var sts [2]*ipnstate.Status for i, c := range clients { i, c := i, c eg.Go(func() error { if *logTailscaled { wg.Add(1) go func() { defer wg.Done() streamDaemonLogs(ctx, t, c, fmt.Sprintf("node%d:", i)) }() } st, err := c.Status(ctx) if err != nil { return fmt.Errorf("node%d status: %w", i, err) } t.Logf("node%d status: %v", i, st) if err := up(ctx, c); err != nil { return fmt.Errorf("node%d up: %w", i, err) } t.Logf("node%d up!", i) st, err = c.Status(ctx) if err != nil { return fmt.Errorf("node%d status: %w", i, err) } sts[i] = st if st.BackendState != "Running" { return fmt.Errorf("node%d state = %q", i, st.BackendState) } t.Logf("node%d up with %v", i, sts[i].Self.TailscaleIPs) return nil }) } if err := eg.Wait(); err != nil { t.Fatalf("initial setup: %v", err) } defer nt.vnet.Close() pingRes, err := ping(ctx, lc1, sts[1].Self.TailscaleIPs[0]) if err != nil { t.Fatalf("ping failure: %v", err) } route := classifyPing(pingRes) t.Logf("ping route: %v", route) return route } func classifyPing(pr *ipnstate.PingResult) pingRoute { if pr == nil { return routeNil } if pr.Endpoint != "" { ap, err := netip.ParseAddrPort(pr.Endpoint) if err == nil { if ap.Addr().IsPrivate() { return routeLocal } return routeDirect } } return routeDERP // presumably } type pingRoute string const ( routeDERP pingRoute = "derp" routeLocal pingRoute = "local" routeDirect pingRoute = "direct" routeNil pingRoute = "nil" // *ipnstate.PingResult is nil ) func streamDaemonLogs(ctx context.Context, t testing.TB, c *vnet.NodeAgentClient, nodeID string) { ctx, cancel := context.WithCancel(ctx) defer cancel() r, err := c.TailDaemonLogs(ctx) if err != nil { t.Errorf("tailDaemonLogs: %v", err) return } logger := log.New(os.Stderr, nodeID+" ", log.Lmsgprefix) dec := json.NewDecoder(r) for { // /{"logtail":{"client_time":"2024-08-08T17:42:31.95095956Z","proc_id":2024742977,"proc_seq":232},"text":"magicsock: derp-1 connected; connGen=1\n"} var logEntry struct { LogTail struct { ClientTime time.Time `json:"client_time"` } Text string `json:"text"` } if err := dec.Decode(&logEntry); err != nil { if err == io.EOF || errors.Is(err, context.Canceled) { return } t.Errorf("log entry: %v", err) return } logger.Printf("%s %s", logEntry.LogTail.ClientTime.Format("2006/01/02 15:04:05"), logEntry.Text) } } func ping(ctx context.Context, c *vnet.NodeAgentClient, target netip.Addr) (*ipnstate.PingResult, error) { n := 0 var res *ipnstate.PingResult anyPong := false for n < 10 { n++ pr, err := c.PingWithOpts(ctx, target, tailcfg.PingDisco, tailscale.PingOpts{}) if err != nil { if anyPong { return res, nil } return nil, err } if pr.Err != "" { return nil, errors.New(pr.Err) } if pr.DERPRegionID == 0 { return pr, nil } res = pr select { case <-ctx.Done(): case <-time.After(time.Second): } } if res == nil { return nil, errors.New("no ping response") } return res, nil } func up(ctx context.Context, c *vnet.NodeAgentClient) error { req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil) if err != nil { return err } res, err := c.HTTPClient.Do(req) if err != nil { return err } defer res.Body.Close() all, _ := io.ReadAll(res.Body) if res.StatusCode != 200 { return fmt.Errorf("unexpected status code %v: %s", res.Status, all) } return nil } type nodeType struct { name string fn addNodeFunc } var types = []nodeType{ {"easy", easy}, {"easyAF", easyAF}, {"hard", hard}, {"easyPMP", easyPMP}, {"hardPMP", hardPMP}, {"one2one", one2one}, {"sameLAN", sameLAN}, } func TestEasyEasy(t *testing.T) { nt := newNatTest(t) nt.runTest(easy, easy) } var pair = flag.String("pair", "", "comma-separated pair of types to test (easy, easyAF, hard, easyPMP, hardPMP, one2one, sameLAN)") func TestPair(t *testing.T) { t1, t2, ok := strings.Cut(*pair, ",") if !ok { t.Skipf("skipping test without --pair=type1,type2 set") } find := func(name string) addNodeFunc { for _, nt := range types { if nt.name == name { return nt.fn } } t.Fatalf("unknown type %q", name) return nil } nt := newNatTest(t) nt.runTest(find(t1), find(t2)) } var runGrid = flag.Bool("run-grid", false, "run grid test") func TestGrid(t *testing.T) { if !*runGrid { t.Skip("skipping grid test; set --run-grid to run") } t.Parallel() sem := syncs.NewSemaphore(2) var ( mu sync.Mutex res = make(map[string]pingRoute) ) for _, a := range types { for _, b := range types { key := a.name + "-" + b.name keyBack := b.name + "-" + a.name t.Run(key, func(t *testing.T) { t.Parallel() sem.Acquire() defer sem.Release() filename := key + ".cache" contents, _ := os.ReadFile(filename) if len(contents) == 0 { filename2 := keyBack + ".cache" contents, _ = os.ReadFile(filename2) } route := pingRoute(strings.TrimSpace(string(contents))) if route == "" { nt := newNatTest(t) route = nt.runTest(a.fn, b.fn) if err := os.WriteFile(filename, []byte(string(route)), 0666); err != nil { t.Fatalf("writeFile: %v", err) } } mu.Lock() defer mu.Unlock() res[key] = route t.Logf("results: %v", res) }) } } t.Cleanup(func() { mu.Lock() defer mu.Unlock() var hb bytes.Buffer pf := func(format string, args ...any) { fmt.Fprintf(&hb, format, args...) } rewrite := func(s string) string { return strings.ReplaceAll(s, "PMP", "+pm") } pf("
") for _, a := range types { pf(" | %s | ", rewrite(a.name)) } pf("||
%s | ", rewrite(a.name)) for _, b := range types { key := a.name + "-" + b.name key2 := b.name + "-" + a.name v := cmp.Or(res[key], res[key2], "-") if v == "derp" { pf("%s | ", v)
} else if v == "local" {
pf("%s | ", v)
} else {
pf("%s | ", v) } } pf("