mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 21:15:39 +00:00
278 lines
6.5 KiB
Go
278 lines
6.5 KiB
Go
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
package nat
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"net/netip"
|
||
|
"net/url"
|
||
|
"os"
|
||
|
"os/exec"
|
||
|
"path/filepath"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"golang.org/x/sync/errgroup"
|
||
|
"tailscale.com/ipn/ipnstate"
|
||
|
"tailscale.com/tstest/natlab/vnet"
|
||
|
)
|
||
|
|
||
|
type natTest struct {
|
||
|
tb testing.TB
|
||
|
base string // base image
|
||
|
tempDir string // for qcow2 images
|
||
|
vnet *vnet.Server
|
||
|
}
|
||
|
|
||
|
func newNatTest(tb testing.TB) *natTest {
|
||
|
nt := &natTest{
|
||
|
tb: tb,
|
||
|
tempDir: tb.TempDir(),
|
||
|
base: "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2",
|
||
|
}
|
||
|
|
||
|
if _, err := os.Stat(nt.base); err != nil {
|
||
|
tb.Skipf("skipping test; base image %q not found", nt.base)
|
||
|
}
|
||
|
return nt
|
||
|
}
|
||
|
|
||
|
// addNodeFunc adds a node (and whatever network it needs) to the given
// vnet configuration and returns the newly added node.
type addNodeFunc func(c *vnet.Config) *vnet.Node
|
||
|
|
||
|
func easy(c *vnet.Config) *vnet.Node {
|
||
|
n := c.NumNodes() + 1
|
||
|
return c.AddNode(c.AddNetwork(
|
||
|
fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP
|
||
|
fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT))
|
||
|
}
|
||
|
|
||
|
func hard(c *vnet.Config) *vnet.Node {
|
||
|
n := c.NumNodes() + 1
|
||
|
return c.AddNode(c.AddNetwork(
|
||
|
fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP
|
||
|
fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT))
|
||
|
}
|
||
|
|
||
|
func (nt *natTest) runTest(node1, node2 addNodeFunc) {
|
||
|
t := nt.tb
|
||
|
|
||
|
var c vnet.Config
|
||
|
nodes := []*vnet.Node{
|
||
|
node1(&c),
|
||
|
node2(&c),
|
||
|
}
|
||
|
|
||
|
var err error
|
||
|
nt.vnet, err = vnet.New(&c)
|
||
|
if err != nil {
|
||
|
t.Fatalf("newServer: %v", err)
|
||
|
}
|
||
|
nt.tb.Cleanup(func() {
|
||
|
nt.vnet.Close()
|
||
|
})
|
||
|
|
||
|
var wg sync.WaitGroup // waiting for srv.Accept goroutine
|
||
|
defer wg.Wait()
|
||
|
|
||
|
sockAddr := filepath.Join(nt.tempDir, "qemu.sock")
|
||
|
srv, err := net.Listen("unix", sockAddr)
|
||
|
if err != nil {
|
||
|
t.Fatalf("Listen: %v", err)
|
||
|
}
|
||
|
defer srv.Close()
|
||
|
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
for {
|
||
|
c, err := srv.Accept()
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
go nt.vnet.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
for i, node := range nodes {
|
||
|
disk := fmt.Sprintf("%s/node-%d.qcow2", nt.tempDir, i)
|
||
|
out, err := exec.Command("qemu-img", "create",
|
||
|
"-f", "qcow2",
|
||
|
"-F", "qcow2",
|
||
|
"-b", nt.base,
|
||
|
disk).CombinedOutput()
|
||
|
if err != nil {
|
||
|
t.Fatalf("qemu-img create: %v, %s", err, out)
|
||
|
}
|
||
|
|
||
|
cmd := exec.Command("qemu-system-x86_64",
|
||
|
"-M", "microvm,isa-serial=off",
|
||
|
"-m", "1G",
|
||
|
"-nodefaults", "-no-user-config", "-nographic",
|
||
|
"-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz",
|
||
|
"-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1",
|
||
|
"-drive", "id=blk0,file="+disk+",format=qcow2",
|
||
|
"-device", "virtio-blk-device,drive=blk0",
|
||
|
"-netdev", "stream,id=net0,addr.type=unix,addr.path="+sockAddr,
|
||
|
"-device", "virtio-serial-device",
|
||
|
"-device", "virtio-net-device,netdev=net0,mac="+node.MAC().String(),
|
||
|
"-chardev", "stdio,id=virtiocon0,mux=on",
|
||
|
"-device", "virtconsole,chardev=virtiocon0",
|
||
|
"-mon", "chardev=virtiocon0,mode=readline",
|
||
|
"-audio", "none",
|
||
|
)
|
||
|
cmd.Stdout = os.Stdout
|
||
|
cmd.Stderr = os.Stderr
|
||
|
if err := cmd.Start(); err != nil {
|
||
|
t.Fatalf("qemu: %v", err)
|
||
|
}
|
||
|
nt.tb.Cleanup(func() {
|
||
|
cmd.Process.Kill()
|
||
|
cmd.Wait()
|
||
|
})
|
||
|
}
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||
|
defer cancel()
|
||
|
|
||
|
c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])}
|
||
|
c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])}
|
||
|
|
||
|
var eg errgroup.Group
|
||
|
var sts [2]*ipnstate.Status
|
||
|
for i, c := range []*http.Client{c1, c2} {
|
||
|
i, c := i, c
|
||
|
eg.Go(func() error {
|
||
|
st, err := status(ctx, c)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("node%d status: %w", i, err)
|
||
|
}
|
||
|
t.Logf("node%d status: %v", i, st)
|
||
|
if err := up(ctx, c); err != nil {
|
||
|
return fmt.Errorf("node%d up: %w", i, err)
|
||
|
}
|
||
|
t.Logf("node%d up!", i)
|
||
|
st, err = status(ctx, c)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("node%d status: %w", i, err)
|
||
|
}
|
||
|
sts[i] = st
|
||
|
|
||
|
if st.BackendState != "Running" {
|
||
|
return fmt.Errorf("node%d state = %q", i, st.BackendState)
|
||
|
}
|
||
|
t.Logf("node%d up with %v", i, sts[i].Self.TailscaleIPs)
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
if err := eg.Wait(); err != nil {
|
||
|
t.Fatalf("initial setup: %v", err)
|
||
|
}
|
||
|
|
||
|
route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String())
|
||
|
t.Logf("ping route: %v, %v", route, err)
|
||
|
}
|
||
|
|
||
|
func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) {
|
||
|
req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
res, err := c.Do(req)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer res.Body.Close()
|
||
|
all, err := io.ReadAll(res.Body)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("ReadAll: %w", err)
|
||
|
}
|
||
|
var st ipnstate.Status
|
||
|
if err := json.Unmarshal(all, &st); err != nil {
|
||
|
return nil, fmt.Errorf("JSON marshal error: %v; body was %q", err, all)
|
||
|
}
|
||
|
return &st, nil
|
||
|
}
|
||
|
|
||
|
// routeType names the kind of path a ping took between two nodes, as
// classified by ping. When ping cannot classify the output, it returns the
// raw response body converted to a routeType instead of one of the
// constants.
type routeType string
|
||
|
|
||
|
const (
|
||
|
routeDirect routeType = "direct"
|
||
|
routeDERP routeType = "derp"
|
||
|
routeLAN routeType = "lan"
|
||
|
)
|
||
|
|
||
|
func ping(ctx context.Context, c *http.Client, target string) (routeType, error) {
|
||
|
req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
res, err := c.Do(req)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
defer res.Body.Close()
|
||
|
if res.StatusCode != 200 {
|
||
|
return "", fmt.Errorf("unexpected status code %v", res.Status)
|
||
|
}
|
||
|
all, _ := io.ReadAll(res.Body)
|
||
|
var route routeType
|
||
|
for _, line := range strings.Split(string(all), "\n") {
|
||
|
if strings.Contains(line, " via DERP") {
|
||
|
route = routeDERP
|
||
|
continue
|
||
|
}
|
||
|
// pong from foo (100.82.3.4) via ADDR:PORT in 69ms
|
||
|
if _, rest, ok := strings.Cut(line, " via "); ok {
|
||
|
ipPorStr, _, _ := strings.Cut(rest, " in ")
|
||
|
ipPort, err := netip.ParseAddrPort(ipPorStr)
|
||
|
if err == nil {
|
||
|
if ipPort.Addr().IsPrivate() {
|
||
|
route = routeLAN
|
||
|
} else {
|
||
|
route = routeDirect
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if route == "" {
|
||
|
return routeType(all), nil
|
||
|
}
|
||
|
return route, nil
|
||
|
}
|
||
|
|
||
|
// up issues a GET for /up through c. It returns nil when the response status
// is 200; otherwise it returns an error carrying the response status and
// body. Request construction and transport errors are returned as-is.
func up(ctx context.Context, c *http.Client) error {
	const endpoint = "http://unused/up"

	request, reqErr := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if reqErr != nil {
		return reqErr
	}

	response, doErr := c.Do(request)
	if doErr != nil {
		return doErr
	}
	defer response.Body.Close()

	// The body is only used to enrich the error message, so a read error
	// is deliberately ignored.
	body, _ := io.ReadAll(response.Body)
	if response.StatusCode == 200 {
		return nil
	}
	return fmt.Errorf("unexpected status code %v: %s", response.Status, body)
}
|
||
|
|
||
|
func TestEasyEasy(t *testing.T) {
|
||
|
nt := newNatTest(t)
|
||
|
nt.runTest(easy, easy)
|
||
|
}
|
||
|
|
||
|
func TestEasyHard(t *testing.T) {
|
||
|
nt := newNatTest(t)
|
||
|
nt.runTest(easy, hard)
|
||
|
}
|