tailscale/tstest/integration/nat/nat_test.go

278 lines
6.5 KiB
Go
Raw Normal View History

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package nat
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"golang.org/x/sync/errgroup"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tstest/natlab/vnet"
)
// natTest holds the state for one run of the two-node NAT integration test.
// It is created by newNatTest and driven by its runTest method.
type natTest struct {
// tb is the test handle used for logging, failing and registering cleanups.
tb testing.TB
// base is the path of the qcow2 base image; runTest passes it to
// "qemu-img create -b" as the backing file of each node's disk.
base string // base image
// tempDir is the per-test temporary directory. runTest puts the per-node
// qcow2 images and the QEMU unix socket in it.
tempDir string // for qcow2 images
// vnet is the virtual network server; it is nil until runTest sets it.
vnet *vnet.Server
}
// newNatTest builds a natTest bound to tb, with a fresh temporary directory
// from tb.TempDir and a fixed base image path.
//
// If os.Stat on the base image returns any error, the test is skipped via
// tb.Skipf.
func newNatTest(tb testing.TB) *natTest {
	const baseImage = "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2"

	// The temporary directory is requested before the existence check.
	dir := tb.TempDir()

	if _, statErr := os.Stat(baseImage); statErr != nil {
		tb.Skipf("skipping test; base image %q not found", baseImage)
	}
	return &natTest{
		tb:      tb,
		base:    baseImage,
		tempDir: dir,
	}
}
// addNodeFunc adds a node to the vnet.Config c and returns the new node.
// runTest takes one addNodeFunc per node; easy and hard are the
// implementations defined in this file.
type addNodeFunc func(c *vnet.Config) *vnet.Node
// easy adds a node to c on a new network configured with vnet.EasyNAT and
// returns that node.
//
// Addresses are derived from the node's ordinal n, which is one more than
// the number of nodes already in c: the public IP is 2.n.n.n and the second
// network argument is 192.168.n.1/24.
func easy(c *vnet.Config) *vnet.Node {
	ordinal := c.NumNodes() + 1
	publicIP := fmt.Sprintf("2.%d.%d.%d", ordinal, ordinal, ordinal)
	privateCIDR := fmt.Sprintf("192.168.%d.1/24", ordinal)

	network := c.AddNetwork(publicIP, privateCIDR, vnet.EasyNAT)
	return c.AddNode(network)
}
// hard adds a node to c on a new network configured with vnet.HardNAT and
// returns that node.
//
// Addresses are derived from the node's ordinal n, which is one more than
// the number of nodes already in c: the public IP is 2.n.n.n and the second
// network argument is 10.0.n.1/24.
func hard(c *vnet.Config) *vnet.Node {
	ordinal := c.NumNodes() + 1
	publicIP := fmt.Sprintf("2.%d.%d.%d", ordinal, ordinal, ordinal)
	privateCIDR := fmt.Sprintf("10.0.%d.1/24", ordinal)

	network := c.AddNetwork(publicIP, privateCIDR, vnet.HardNAT)
	return c.AddNode(network)
}
// runTest builds a two-node virtual network, boots one QEMU VM per node,
// brings both nodes up through their node agents, and then pings node 2's
// first Tailscale IP from node 1.
//
// node1 and node2 are called in that order, each adding one node to the
// shared vnet.Config.
//
// The test fails (via Fatalf on nt.tb) if the virtual network, the unix
// listener, disk creation or QEMU startup fails, if either node does not
// reach BackendState "Running" with at least one Tailscale IP, or if the
// ping returns an error. The status/up/ping phase shares a single 60 second
// timeout.
func (nt *natTest) runTest(node1, node2 addNodeFunc) {
	t := nt.tb

	var c vnet.Config
	nodes := []*vnet.Node{
		node1(&c),
		node2(&c),
	}

	var err error
	nt.vnet, err = vnet.New(&c)
	if err != nil {
		t.Fatalf("newServer: %v", err)
	}
	nt.tb.Cleanup(func() {
		nt.vnet.Close()
	})

	var wg sync.WaitGroup // waiting for srv.Accept goroutine
	defer wg.Wait()

	sockAddr := filepath.Join(nt.tempDir, "qemu.sock")
	srv, err := net.Listen("unix", sockAddr)
	if err != nil {
		t.Fatalf("Listen: %v", err)
	}
	// Deferred after wg.Wait, so it runs first: closing the listener makes
	// Accept return an error, which ends the accept goroutine below and lets
	// wg.Wait return.
	defer srv.Close()

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			conn, err := srv.Accept()
			if err != nil {
				return
			}
			// Each QEMU instance connects to the socket; hand the
			// connection to the virtual network server.
			go nt.vnet.ServeUnixConn(conn.(*net.UnixConn), vnet.ProtocolQEMU)
		}
	}()

	for i, node := range nodes {
		// Create a per-node qcow2 disk that uses nt.base as its backing
		// file (-b).
		disk := filepath.Join(nt.tempDir, fmt.Sprintf("node-%d.qcow2", i))
		out, err := exec.Command("qemu-img", "create",
			"-f", "qcow2",
			"-F", "qcow2",
			"-b", nt.base,
			disk).CombinedOutput()
		if err != nil {
			t.Fatalf("qemu-img create: %v, %s", err, out)
		}

		cmd := exec.Command("qemu-system-x86_64",
			"-M", "microvm,isa-serial=off",
			"-m", "1G",
			"-nodefaults", "-no-user-config", "-nographic",
			"-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz",
			"-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1",
			"-drive", "id=blk0,file="+disk+",format=qcow2",
			"-device", "virtio-blk-device,drive=blk0",
			"-netdev", "stream,id=net0,addr.type=unix,addr.path="+sockAddr,
			"-device", "virtio-serial-device",
			"-device", "virtio-net-device,netdev=net0,mac="+node.MAC().String(),
			"-chardev", "stdio,id=virtiocon0,mux=on",
			"-device", "virtconsole,chardev=virtiocon0",
			"-mon", "chardev=virtiocon0,mode=readline",
			"-audio", "none",
		)
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Start(); err != nil {
			t.Fatalf("qemu: %v", err)
		}
		// Kill and reap the VM when the test finishes.
		nt.tb.Cleanup(func() {
			cmd.Process.Kill()
			cmd.Wait()
		})
	}

	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])}
	c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])}

	var eg errgroup.Group
	var sts [2]*ipnstate.Status
	for i, c := range []*http.Client{c1, c2} {
		i, c := i, c
		eg.Go(func() error {
			st, err := status(ctx, c)
			if err != nil {
				return fmt.Errorf("node%d status: %w", i, err)
			}
			t.Logf("node%d status: %v", i, st)
			if err := up(ctx, c); err != nil {
				return fmt.Errorf("node%d up: %w", i, err)
			}
			t.Logf("node%d up!", i)
			st, err = status(ctx, c)
			if err != nil {
				return fmt.Errorf("node%d status: %w", i, err)
			}
			sts[i] = st
			if st.BackendState != "Running" {
				return fmt.Errorf("node%d state = %q", i, st.BackendState)
			}
			// Guard the dereference/indexing below and in the ping call
			// so a malformed status is reported as a test failure
			// instead of a panic.
			if st.Self == nil || len(st.Self.TailscaleIPs) == 0 {
				return fmt.Errorf("node%d has no Tailscale IPs", i)
			}
			t.Logf("node%d up with %v", i, st.Self.TailscaleIPs)
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		t.Fatalf("initial setup: %v", err)
	}

	route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String())
	t.Logf("ping route: %v, %v", route, err)
	if err != nil {
		// A ping that errors means the nodes could not reach each other
		// (or the agent failed); that must fail the test, not just log.
		t.Fatalf("ping: %v", err)
	}
}
// status issues a GET for "/status" through c and decodes the JSON response
// body into an ipnstate.Status.
//
// The request is bound to ctx. The URL host "unused" is presumably ignored
// by the client's transport (NOTE(review): confirm against the transport the
// caller installs).
//
// It returns an error if the request fails, the body cannot be read, the
// response status is not 200, or the body is not valid JSON for
// ipnstate.Status. The status-code and JSON errors include the response body
// to aid debugging.
func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil)
	if err != nil {
		return nil, err
	}
	res, err := c.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	all, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, fmt.Errorf("ReadAll: %w", err)
	}
	// Match ping and up: report a non-200 response explicitly rather than
	// as a confusing JSON decoding failure.
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code %v: %s", res.Status, all)
	}
	var st ipnstate.Status
	if err := json.Unmarshal(all, &st); err != nil {
		return nil, fmt.Errorf("JSON unmarshal error: %w; body was %q", err, all)
	}
	return &st, nil
}
// routeType names the kind of path a ping reply took, as classified by ping.
type routeType string

const (
	routeDirect routeType = "direct" // reply came via a non-private IP:port
	routeDERP   routeType = "derp"   // reply line contained " via DERP"
	routeLAN    routeType = "lan"    // reply came via a private IP:port
)

// ping POSTs to "/ping?target=<target>" through c and classifies the route
// reported in the response body.
//
// Each body line is inspected; the last line that can be classified wins:
//   - a line containing " via DERP" yields routeDERP;
//   - a line of the form "... via ADDR:PORT in ..." yields routeLAN when
//     ADDR is a private address and routeDirect otherwise.
//
// If no line can be classified, the whole response body is returned as the
// routeType with a nil error, so callers can see the raw output.
//
// It returns an error if the request fails, the response status is not 200
// (the body is included in the message), or the body cannot be read.
func ping(ctx context.Context, c *http.Client, target string) (routeType, error) {
	req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil)
	if err != nil {
		return "", err
	}
	res, err := c.Do(req)
	if err != nil {
		return "", err
	}
	defer res.Body.Close()

	all, readErr := io.ReadAll(res.Body)
	if res.StatusCode != http.StatusOK {
		// The body is only diagnostic here, so a partial read is fine.
		return "", fmt.Errorf("unexpected status code %v: %s", res.Status, all)
	}
	if readErr != nil {
		// A truncated body must not be classified as if it were complete.
		return "", fmt.Errorf("ReadAll: %w", readErr)
	}

	var route routeType
	for _, line := range strings.Split(string(all), "\n") {
		if strings.Contains(line, " via DERP") {
			route = routeDERP
			continue
		}
		// pong from foo (100.82.3.4) via ADDR:PORT in 69ms
		_, rest, ok := strings.Cut(line, " via ")
		if !ok {
			continue
		}
		ipPortStr, _, _ := strings.Cut(rest, " in ")
		ipPort, err := netip.ParseAddrPort(ipPortStr)
		if err != nil {
			continue
		}
		if ipPort.Addr().IsPrivate() {
			route = routeLAN
		} else {
			route = routeDirect
		}
	}
	if route == "" {
		return routeType(all), nil
	}
	return route, nil
}
// up issues a GET for "/up" through c and reports whether the response
// status was 200.
//
// The request is bound to ctx. The response body is always read in full;
// on a non-200 status it is included in the returned error. Errors from
// building or sending the request are returned as-is.
func up(ctx context.Context, c *http.Client) error {
	request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://unused/up", nil)
	if err != nil {
		return err
	}

	resp, err := c.Do(request)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// The read error is deliberately ignored: the body is only used to
	// enrich the error message below.
	body, _ := io.ReadAll(resp.Body)
	if resp.StatusCode == 200 {
		return nil
	}
	return fmt.Errorf("unexpected status code %v: %s", resp.Status, body)
}
// TestEasyEasy runs the NAT test with both nodes added by easy.
func TestEasyEasy(t *testing.T) {
	newNatTest(t).runTest(easy, easy)
}
// TestEasyHard runs the NAT test with the first node added by easy and the
// second by hard.
func TestEasyHard(t *testing.T) {
	newNatTest(t).runTest(easy, hard)
}