diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go new file mode 100644 index 000000000..c7f587c4b --- /dev/null +++ b/cmd/tta/tta.go @@ -0,0 +1,159 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tta server is the Tailscale Test Agent. +// +// It runs on each Tailscale node being integration tested and permits the test +// harness to control the node. It connects out to the test drver (rather than +// accepting any TCP connections inbound, which might be blocked depending on +// the scenario being tested) and then the test driver turns the TCP connection +// around and sends request back. +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "os/exec" + "strings" + "sync" + "time" + + "tailscale.com/util/set" + "tailscale.com/version/distro" +) + +var ( + driverAddr = flag.String("driver", "test-driver.tailscale:8008", "address of the test driver; by default we use the DNS name test-driver.tailscale which is special cased in the emulated network's DNS server") +) + +type chanListener <-chan net.Conn + +func serveCmd(w http.ResponseWriter, cmd string, args ...string) { + if distro.Get() == distro.Gokrazy && !strings.Contains(cmd, "/") { + cmd = "/user/" + cmd + } + out, err := exec.Command(cmd, args...).CombinedOutput() + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + if err != nil { + w.Header().Set("Exec-Err", err.Error()) + w.WriteHeader(500) + } + w.Write(out) +} + +func main() { + if distro.Get() == distro.Gokrazy { + cmdLine, _ := os.ReadFile("/proc/cmdline") + if !bytes.Contains(cmdLine, []byte("tailscale-tta=1")) { + // "Exiting immediately with status code 0 when the + // GOKRAZY_FIRST_START=1 environment variable is set means “don’t + // start the program on boot”" + return + } + } + flag.Parse() + log.Printf("Tailscale Test Agent running.") + + var mux http.ServeMux + var hs http.Server + hs.Handler = &mux + var ( + stMu sync.Mutex + newSet = set.Set[net.Conn]{} // conns in StateNew + ) + needConnCh := make(chan bool, 1) + hs.ConnState = func(c net.Conn, s http.ConnState) { + stMu.Lock() + defer stMu.Unlock() + switch s { + case http.StateNew: + newSet.Add(c) + case http.StateClosed: + newSet.Delete(c) + } + if len(newSet) == 0 { + select { + case needConnCh <- true: + default: + } + } + } + conns := make(chan net.Conn, 1) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "TTA\n") + return + }) + mux.HandleFunc("/up", func(w http.ResponseWriter, r *http.Request) { + serveCmd(w, "tailscale", "up", "--auth-key=test") + }) + mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { + serveCmd(w, "tailscale", "status", "--json") + }) + mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + target := r.FormValue("target") + cmd := exec.Command("tailscale", "ping", target) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.(http.Flusher).Flush() + cmd.Stdout = w + cmd.Stderr = w + if err := cmd.Run(); err != nil { + fmt.Fprintf(w, "error: %v\n", err) + } + }) + go hs.Serve(chanListener(conns)) + + var lastErr string + needConnCh <- true + for { + <-needConnCh + c, err := connect() + log.Printf("Connect: %v", err) + if err != nil { + s := err.Error() + if s != lastErr { + log.Printf("Connect failure: %v", s) + } + lastErr = s + time.Sleep(time.Second) + continue + } + conns <- c + + time.Sleep(time.Second) + } +} + +func connect() (net.Conn, error) { + c, err := net.Dial("tcp", *driverAddr) + if err != nil { + return nil, err + } + return c, nil +} + +func (cl chanListener) Accept() (net.Conn, error) { + c, ok := <-cl + if !ok { + return nil, errors.New("closed") + } + return c, nil +} + +func (cl chanListener) Close() error { + return nil +} + +func (cl chanListener) Addr() net.Addr { + return &net.TCPAddr{ + IP: net.ParseIP("52.0.0.34"), // TS..DR(iver) + Port: 123, + } +} diff --git a/cmd/vnet/run-krazy.sh b/cmd/vnet/run-krazy.sh new file mode 100755 index 000000000..6cf608b62 --- /dev/null +++ b/cmd/vnet/run-krazy.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +echo "Type 'C-a c' to enter monitor; q to quit." + +set -eux +qemu-system-x86_64 -M microvm,isa-serial=off \ + -m 1G \ + -nodefaults -no-user-config -nographic \ + -kernel $HOME/src/github.com/tailscale/gokrazy-kernel/vmlinuz \ + -append "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1" \ + -drive id=blk0,file=$HOME/src/tailscale.com/gokrazy/tsapp.img,format=raw \ + -device virtio-blk-device,drive=blk0 \ + -netdev stream,id=net0,addr.type=unix,addr.path=/tmp/qemu.sock \ + -device virtio-serial-device \ + -device virtio-net-device,netdev=net0,mac=52:cc:cc:cc:cc:00 \ + -chardev stdio,id=virtiocon0,mux=on \ + -device virtconsole,chardev=virtiocon0 \ + -mon chardev=virtiocon0,mode=readline \ + -audio none + diff --git a/cmd/vnet/vnet-main.go b/cmd/vnet/vnet-main.go new file mode 100644 index 000000000..31e11f89f --- /dev/null +++ b/cmd/vnet/vnet-main.go @@ -0,0 +1,101 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The vnet binary runs a virtual network stack in userspace for qemu instances +// to connect to and simulate various network conditions. +package main + +import ( + "context" + "flag" + "log" + "net" + "os" + "time" + + "tailscale.com/tstest/natlab/vnet" +) + +var ( + listen = flag.String("listen", "/tmp/qemu.sock", "path to listen on") + nat = flag.String("nat", "easy", "type of NAT to use") + portmap = flag.Bool("portmap", false, "enable portmapping") + dgram = flag.Bool("dgram", false, "enable datagram mode; for use with macOS Hypervisor.Framework and VZFileHandleNetworkDeviceAttachment") +) + +func main() { + flag.Parse() + + if _, err := os.Stat(*listen); err == nil { + os.Remove(*listen) + } + + var srv net.Listener + var err error + var conn *net.UnixConn + if *dgram { + addr, err := net.ResolveUnixAddr("unixgram", *listen) + if err != nil { + log.Fatalf("ResolveUnixAddr: %v", err) + } + conn, err = net.ListenUnixgram("unixgram", addr) + if err != nil { + log.Fatalf("ListenUnixgram: %v", err) + } + defer conn.Close() + } else { + srv, err = net.Listen("unix", *listen) + } + if err != nil { + log.Fatal(err) + } + + var c vnet.Config + node1 := c.AddNode(c.AddNetwork("2.1.1.1", "192.168.1.1/24", vnet.NAT(*nat))) + c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", vnet.NAT(*nat))) + if *portmap { + node1.Network().AddService(vnet.NATPMP) + } + + s, err := vnet.New(&c) + if err != nil { + log.Fatalf("newServer: %v", err) + } + + if err := s.PopulateDERPMapIPs(); err != nil { + log.Printf("warning: ignoring failure to populate DERP map: %v", err) + } + + s.WriteStartingBanner(os.Stdout) + + go func() { + getStatus := func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + st, err := s.NodeStatus(ctx, node1) + if err != nil { + log.Printf("NodeStatus: %v", err) + return + } + log.Printf("NodeStatus: %q", st) + } + for { + time.Sleep(5 * time.Second) + getStatus() + } + }() + + if conn != nil { + s.ServeUnixConn(conn, vnet.ProtocolUnixDGRAM) + return + } + + for { + c, err := srv.Accept() + if err != nil { + log.Printf("Accept: %v", err) + continue + } + go s.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) + } +} diff --git a/go.mod b/go.mod index ff8adf5b4..1bf0ae650 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/golangci/golangci-lint v1.52.2 github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.18.0 + github.com/google/gopacket v1.1.19 github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 github.com/google/uuid v1.6.0 github.com/goreleaser/nfpm/v2 v2.33.1 diff --git a/go.sum b/go.sum index e5f8a4673..bf6a126ba 100644 --- a/go.sum +++ b/go.sum @@ -477,6 +477,8 @@ github.com/google/go-containerregistry v0.18.0/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/goterm v0.0.0-20200907032337-555d40f16ae2 h1:CVuJwN34x4xM2aT4sIKhmeib40NeBPhRihNjQmpJsA4= github.com/google/goterm v0.0.0-20200907032337-555d40f16ae2/go.mod h1:nOFQdrUlIlx6M6ODdSpBj1NVA+VgLC6kmw60mkw34H4= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= diff --git a/gokrazy/tsapp/builddir/tailscale.com/go.sum b/gokrazy/tsapp/builddir/tailscale.com/go.sum index 5e82db5d7..b3b73e2d0 100644 --- a/gokrazy/tsapp/builddir/tailscale.com/go.sum +++ b/gokrazy/tsapp/builddir/tailscale.com/go.sum @@ -122,8 +122,12 @@ github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:t github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1 h1:ycpNCSYwzZ7x4G4ioPNtKQmIY0G/3o4pVf8wCZq6blY= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso= +github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9 h1:81P7rjnikHKTJ75EkjppvbwUfKHDHYk6LJpO5PZy8pA= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= +github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= +github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0= github.com/tcnksm/go-httpstat v0.2.0/go.mod h1:s3JVJFtQxtBEBC9dwcdTTXS9xFnM3SXAZwPG41aurT8= github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9rrstGQ= @@ -170,6 +174,8 @@ golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0= +gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= +gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= k8s.io/client-go v0.30.1 h1:uC/Ir6A3R46wdkgCV3vbLyNOYyCJ8oZnjtJGKfytl/Q= k8s.io/client-go v0.30.1/go.mod h1:wrAqLNs2trwiCH/wxxmT/x3hKVH9PuV0GGW0oDoHVqc= k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= diff --git a/gokrazy/tsapp/config.json b/gokrazy/tsapp/config.json index 6445eb89e..d103a3601 100644 --- a/gokrazy/tsapp/config.json +++ b/gokrazy/tsapp/config.json @@ -1,16 +1,21 @@ { "Hostname": "tsapp", - "Update": { "NoPassword": true }, + "Update": { + "NoPassword": true + }, "SerialConsole": "ttyS0,115200", "Packages": [ "github.com/gokrazy/serial-busybox", "github.com/gokrazy/breakglass", "tailscale.com/cmd/tailscale", - "tailscale.com/cmd/tailscaled" + "tailscale.com/cmd/tailscaled", + "tailscale.com/cmd/tta" ], "PackageConfig": { "github.com/gokrazy/breakglass": { - "CommandLineFlags": [ "-authorized_keys=ec2" ] + "CommandLineFlags": [ + "-authorized_keys=ec2" + ] }, "tailscale.com/cmd/tailscale": { "ExtraFilePaths": { @@ -21,4 +26,4 @@ "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "github.com/tailscale/gokrazy-kernel", "InternalCompatibilityFlags": {} -} +} \ No newline at end of file diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go new file mode 100644 index 000000000..89dfc9570 --- /dev/null +++ b/tstest/natlab/vnet/conf.go @@ -0,0 +1,217 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vnet + +import ( + "cmp" + "fmt" + "net/netip" + "slices" + + "tailscale.com/util/set" +) + +// Note: the exported Node and Network are the configuration types; +// the unexported node and network are the runtime types that are actually +// used once the server is created. + +// Config is the requested state of the natlab virtual network. +// +// The zero value is a valid empty configuration. Call AddNode +// and AddNetwork to methods on the returned Node and Network +// values to modify the config before calling NewServer. +// Once the NewServer is called, Config is no longer used. +type Config struct { + nodes []*Node + networks []*Network +} + +// AddNode creates a new node in the world. +// +// The opts may be of the following types: +// - *Network: zero, one, or more networks to add this node to +// - TODO: more +// +// On an error or unknown opt type, AddNode returns a +// node with a carried error that gets returned later. +func (c *Config) AddNode(opts ...any) *Node { + num := len(c.nodes) + n := &Node{ + mac: MAC{0x52, 0xcc, 0xcc, 0xcc, 0xcc, byte(num)}, // 52=TS then 0xcc for ccclient + } + c.nodes = append(c.nodes, n) + for _, o := range opts { + switch o := o.(type) { + case *Network: + if !slices.Contains(o.nodes, n) { + o.nodes = append(o.nodes, n) + } + n.nets = append(n.nets, o) + default: + if n.err == nil { + n.err = fmt.Errorf("unknown AddNode option type %T", o) + } + } + } + return n +} + +// AddNetwork add a new network. +// +// The opts may be of the following types: +// - string IP address, for the network's WAN IP (if any) +// - string netip.Prefix, for the network's LAN IP (defaults to 192.168.0.0/24) +// - NAT, the type of NAT to use +// - NetworkService, a service to add to the network +// +// On an error or unknown opt type, AddNetwork returns a +// network with a carried error that gets returned later. +func (c *Config) AddNetwork(opts ...any) *Network { + num := len(c.networks) + n := &Network{ + mac: MAC{0x52, 0xee, 0xee, 0xee, 0xee, byte(num)}, // 52=TS then 0xee for 'etwork + } + c.networks = append(c.networks, n) + for _, o := range opts { + switch o := o.(type) { + case string: + if ip, err := netip.ParseAddr(o); err == nil { + n.wanIP = ip + } else if ip, err := netip.ParsePrefix(o); err == nil { + n.lanIP = ip + } else { + if n.err == nil { + n.err = fmt.Errorf("unknown string option %q", o) + } + } + case NAT: + n.natType = o + case NetworkService: + n.AddService(o) + default: + if n.err == nil { + n.err = fmt.Errorf("unknown AddNetwork option type %T", o) + } + } + } + return n +} + +// Node is the configuration of a node in the virtual network. +type Node struct { + err error + n *node // nil until NewServer called + + // TODO(bradfitz): this is halfway converted to supporting multiple NICs + // but not done. We need a MAC-per-Network. + + mac MAC + nets []*Network +} + +// Network returns the first network this node is connected to, +// or nil if none. +func (n *Node) Network() *Network { + if len(n.nets) == 0 { + return nil + } + return n.nets[0] +} + +// Network is the configuration of a network in the virtual network. +type Network struct { + mac MAC // MAC address of the router/gateway + natType NAT + + wanIP netip.Addr + lanIP netip.Prefix + nodes []*Node + + svcs set.Set[NetworkService] + + // ... + err error // carried error +} + +// NetworkService is a service that can be added to a network. +type NetworkService string + +const ( + NATPMP NetworkService = "NAT-PMP" + PCP NetworkService = "PCP" + UPnP NetworkService = "UPnP" +) + +// AddService adds a network service (such as port mapping protocols) to a +// network. +func (n *Network) AddService(s NetworkService) { + if n.svcs == nil { + n.svcs = set.Of(s) + } else { + n.svcs.Add(s) + } +} + +// initFromConfig initializes the server from the previous calls +// to NewNode and NewNetwork and returns an error if +// there were any configuration issues. +func (s *Server) initFromConfig(c *Config) error { + netOfConf := map[*Network]*network{} + for _, conf := range c.networks { + if conf.err != nil { + return conf.err + } + if !conf.lanIP.IsValid() { + conf.lanIP = netip.MustParsePrefix("192.168.0.0/24") + } + n := &network{ + s: s, + mac: conf.mac, + portmap: conf.svcs.Contains(NATPMP), // TODO: expand network.portmap + wanIP: conf.wanIP, + lanIP: conf.lanIP, + nodesByIP: map[netip.Addr]*node{}, + } + netOfConf[conf] = n + s.networks.Add(n) + if _, ok := s.networkByWAN[conf.wanIP]; ok { + return fmt.Errorf("two networks have the same WAN IP %v; Anycast not (yet?) supported", conf.wanIP) + } + s.networkByWAN[conf.wanIP] = n + } + for _, conf := range c.nodes { + if conf.err != nil { + return conf.err + } + n := &node{ + mac: conf.mac, + net: netOfConf[conf.Network()], + } + conf.n = n + if _, ok := s.nodeByMAC[n.mac]; ok { + return fmt.Errorf("two nodes have the same MAC %v", n.mac) + } + s.nodes = append(s.nodes, n) + s.nodeByMAC[n.mac] = n + + // Allocate a lanIP for the node. Use the network's CIDR and use final + // octet 101 (for first node), 102, etc. The node number comes from the + // last octent of the MAC address (0-based) + ip4 := n.net.lanIP.Addr().As4() + ip4[3] = 101 + n.mac[5] + n.lanIP = netip.AddrFrom4(ip4) + n.net.nodesByIP[n.lanIP] = n + } + + // Now that nodes are populated, set up NAT: + for _, conf := range c.networks { + n := netOfConf[conf] + natType := cmp.Or(conf.natType, EasyNAT) + if err := n.InitNAT(natType); err != nil { + return err + } + } + + return nil +} diff --git a/tstest/natlab/vnet/conf_test.go b/tstest/natlab/vnet/conf_test.go new file mode 100644 index 000000000..ae731d127 --- /dev/null +++ b/tstest/natlab/vnet/conf_test.go @@ -0,0 +1,71 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vnet + +import "testing" + +func TestConfig(t *testing.T) { + tests := []struct { + name string + setup func(*Config) + wantErr string + }{ + { + name: "simple", + setup: func(c *Config) { + c.AddNode(c.AddNetwork("2.1.1.1", "192.168.1.1/24", EasyNAT, NATPMP)) + c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT)) + }, + }, + { + name: "indirect", + setup: func(c *Config) { + n1 := c.AddNode(c.AddNetwork("2.1.1.1", "192.168.1.1/24", HardNAT)) + n1.Network().AddService(NATPMP) + c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", NAT("hard"))) + }, + }, + { + name: "multi-node-in-net", + setup: func(c *Config) { + net1 := c.AddNetwork("2.1.1.1", "192.168.1.1/24") + c.AddNode(net1) + c.AddNode(net1) + }, + }, + { + name: "dup-wan-ip", + setup: func(c *Config) { + c.AddNetwork("2.1.1.1", "192.168.1.1/24") + c.AddNetwork("2.1.1.1", "10.2.0.1/16") + }, + wantErr: "two networks have the same WAN IP 2.1.1.1; Anycast not (yet?) supported", + }, + { + name: "one-to-one-nat-with-multiple-nodes", + setup: func(c *Config) { + net1 := c.AddNetwork("2.1.1.1", "192.168.1.1/24", One2OneNAT) + c.AddNode(net1) + c.AddNode(net1) + }, + wantErr: "error creating NAT type \"one2one\" for network 2.1.1.1: can't use one2one NAT type on networks other than single-node networks", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var c Config + tt.setup(&c) + _, err := New(&c) + if err == nil { + if tt.wantErr == "" { + return + } + t.Fatalf("got success; wanted error %q", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("got error %q; want %q", err, tt.wantErr) + } + }) + } +} diff --git a/tstest/natlab/vnet/nat.go b/tstest/natlab/vnet/nat.go new file mode 100644 index 000000000..9ce04a23a --- /dev/null +++ b/tstest/natlab/vnet/nat.go @@ -0,0 +1,239 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vnet + +import ( + "errors" + "math/rand/v2" + "net/netip" + "time" + + "tailscale.com/util/mak" +) + +const ( + One2OneNAT NAT = "one2one" + EasyNAT NAT = "easy" + HardNAT NAT = "hard" +) + +// IPPool is the interface that a NAT implementation uses to get information +// about a network. +// +// Outside of tests, this is typically a *network. +type IPPool interface { + // WANIP returns the primary WAN IP address. + // + // TODO: add another method for networks with multiple WAN IP addresses. + WANIP() netip.Addr + + // SoleLanIP reports whether this network has a sole LAN client + // and if so, its IP address. + SoleLANIP() (_ netip.Addr, ok bool) + + // TODO: port availability stuff for interacting with portmapping +} + +// newTableFunc is a constructor for a NAT table. +// The provided IPPool is typically (outside of tests) a *network. +type newTableFunc func(IPPool) (NATTable, error) + +// NAT is a type of NAT that's known to natlab. +// +// For example, "easy" for Linux-style NAT, "hard" for FreeBSD-style NAT, etc. +type NAT string + +// natTypes are the known NAT types. +var natTypes = map[NAT]newTableFunc{} + +// registerNATType registers a NAT type. +func registerNATType(name NAT, f newTableFunc) { + if _, ok := natTypes[name]; ok { + panic("duplicate NAT type: " + name) + } + natTypes[name] = f +} + +// NATTable is what a NAT implementation is expected to do. +// +// This project tests Tailscale as it faces various combinations various NAT +// implementations (e.g. Linux easy style NAT vs FreeBSD hard/endpoint dependent +// NAT vs Cloud 1:1 NAT, etc) +// +// Implementations of NATTable need not handle concurrency; the natlab serializes +// all calls into a NATTable. +// +// The provided `at` value will typically be time.Now, except for tests. +// Implementations should not use real time and should only compare +// previously provided time values. +type NATTable interface { + // PickOutgoingSrc returns the source address to use for an outgoing packet. + // + // The result should either be invalid (to drop the packet) or a WAN (not + // private) IP address. + // + // Typically, the src is a LAN source IP address, but it might also be a WAN + // IP address if the packet is being forwarded for a source machine that has + // a public IP address. + PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) + + // PickIncomingDst returns the destination address to use for an incoming + // packet. The incoming src address is always a public WAN IP. + // + // The result should either be invalid (to drop the packet) or the IP + // address of a machine on the local network address, usually a private + // LAN IP. + PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst netip.AddrPort) +} + +// oneToOneNAT is a 1:1 NAT, like a typical EC2 VM. +type oneToOneNAT struct { + lanIP netip.Addr + wanIP netip.Addr +} + +func init() { + registerNATType(One2OneNAT, func(p IPPool) (NATTable, error) { + lanIP, ok := p.SoleLANIP() + if !ok { + return nil, errors.New("can't use one2one NAT type on networks other than single-node networks") + } + return &oneToOneNAT{lanIP: lanIP, wanIP: p.WANIP()}, nil + }) +} + +func (n *oneToOneNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) { + return netip.AddrPortFrom(n.wanIP, src.Port()) +} + +func (n *oneToOneNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst netip.AddrPort) { + return netip.AddrPortFrom(n.lanIP, dst.Port()) +} + +type hardKeyOut struct { + lanIP netip.Addr + dst netip.AddrPort +} + +type hardKeyIn struct { + wanPort uint16 + src netip.AddrPort +} + +type portMappingAndTime struct { + port uint16 + at time.Time +} + +type lanAddrAndTime struct { + lanAddr netip.AddrPort + at time.Time +} + +// hardNAT is an "Endpoint Dependent" NAT, like FreeBSD/pfSense/OPNsense. +// This is shown as "MappingVariesByDestIP: true" by netcheck, and what +// Tailscale calls "Hard NAT". +type hardNAT struct { + wanIP netip.Addr + + out map[hardKeyOut]portMappingAndTime + in map[hardKeyIn]lanAddrAndTime +} + +func init() { + registerNATType(HardNAT, func(p IPPool) (NATTable, error) { + return &hardNAT{wanIP: p.WANIP()}, nil + }) +} + +func (n *hardNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) { + ko := hardKeyOut{src.Addr(), dst} + if pm, ok := n.out[ko]; ok { + // Existing flow. + // TODO: bump timestamp + return netip.AddrPortFrom(n.wanIP, pm.port) + } + + // No existing mapping exists. Create one. + + // TODO: clean up old expired mappings + + // Instead of proper data structures that would be efficient, we instead + // just loop a bunch and look for a free port. This project is only used + // by tests and doesn't care about performance, this is good enough. + for { + port := rand.N(uint16(32<<10)) + 32<<10 // pick some "ephemeral" port + ki := hardKeyIn{wanPort: port, src: dst} + if _, ok := n.in[ki]; ok { + // Port already in use. + continue + } + mak.Set(&n.in, ki, lanAddrAndTime{lanAddr: src, at: at}) + mak.Set(&n.out, ko, portMappingAndTime{port: port, at: at}) + return netip.AddrPortFrom(n.wanIP, port) + } +} + +func (n *hardNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst netip.AddrPort) { + if dst.Addr() != n.wanIP { + return netip.AddrPort{} // drop; not for us. shouldn't happen if natlabd routing isn't broken. + } + ki := hardKeyIn{wanPort: dst.Port(), src: src} + if pm, ok := n.in[ki]; ok { + // Existing flow. + return pm.lanAddr + } + return netip.AddrPort{} // drop; no mapping +} + +// easyNAT is an "Endpoint Independent" NAT, like Linux and most home routers +// (many of which are Linux). +// +// This is shown as "MappingVariesByDestIP: false" by netcheck, and what +// Tailscale calls "Easy NAT". +// +// Unlike Linux, this implementation is capped at 32k entries and doesn't resort +// to other allocation strategies when all 32k WAN ports are taken. +type easyNAT struct { + wanIP netip.Addr + out map[netip.AddrPort]portMappingAndTime + in map[uint16]lanAddrAndTime +} + +func init() { + registerNATType(EasyNAT, func(p IPPool) (NATTable, error) { + return &easyNAT{wanIP: p.WANIP()}, nil + }) +} + +func (n *easyNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) { + if pm, ok := n.out[src]; ok { + // Existing flow. + // TODO: bump timestamp + return netip.AddrPortFrom(n.wanIP, pm.port) + } + + // Loop through all 32k high (ephemeral) ports, starting at a random + // position and looping back around to the start. + start := rand.N(uint16(32 << 10)) + for off := range uint16(32 << 10) { + port := 32<<10 + (start+off)%(32<<10) + if _, ok := n.in[port]; !ok { + wanAddr := netip.AddrPortFrom(n.wanIP, port) + + // Found a free port. + mak.Set(&n.out, src, portMappingAndTime{port: port, at: at}) + mak.Set(&n.in, port, lanAddrAndTime{lanAddr: src, at: at}) + return wanAddr + } + } + return netip.AddrPort{} // failed to allocate a mapping; TODO: fire an alert? +} + +func (n *easyNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst netip.AddrPort) { + if dst.Addr() != n.wanIP { + return netip.AddrPort{} // drop; not for us. shouldn't happen if natlabd routing isn't broken. + } + return n.in[dst.Port()].lanAddr +} diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go new file mode 100644 index 000000000..7ce86d512 --- /dev/null +++ b/tstest/natlab/vnet/vnet.go @@ -0,0 +1,1237 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package vnet simulates a virtual Internet containing a set of networks with various +// NAT behaviors. You can then plug VMs into the virtual internet at different points +// to test Tailscale working end-to-end in various network conditions. +// +// See https://github.com/tailscale/tailscale/issues/13038 +package vnet + +// TODO: +// - [ ] port mapping actually working +// - [ ] conf to let you firewall things +// - [ ] tests for NAT tables + +import ( + "bufio" + "context" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "net/netip" + "os/exec" + "strconv" + "sync" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "go4.org/mem" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" + "tailscale.com/net/stun" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const nicID = 1 +const stunPort = 3478 + +func (s *Server) PopulateDERPMapIPs() error { + out, err := exec.Command("tailscale", "debug", "derp-map").Output() + if err != nil { + return fmt.Errorf("tailscale debug derp-map: %v", err) + } + var dm tailcfg.DERPMap + if err := json.Unmarshal(out, &dm); err != nil { + return fmt.Errorf("unmarshal DERPMap: %v", err) + } + for _, r := range dm.Regions { + for _, n := range r.Nodes { + if n.IPv4 != "" { + s.derpIPs.Add(netip.MustParseAddr(n.IPv4)) + } + } + } + return nil +} + +func (n *network) InitNAT(natType NAT) error { + ctor, ok := natTypes[natType] + if !ok { + return fmt.Errorf("unknown NAT type %q", natType) + } + t, err := ctor(n) + if err != nil { + return fmt.Errorf("error creating NAT type %q for network %v: %w", natType, n.wanIP, err) + } + n.setNATTable(t) + n.natStyle.Store(natType) + return nil +} + +func (n *network) setNATTable(nt NATTable) { + n.natMu.Lock() + defer n.natMu.Unlock() + n.natTable = nt +} + +// SoleLANIP implements [IPPool]. +func (n *network) SoleLANIP() (netip.Addr, bool) { + if len(n.nodesByIP) != 1 { + return netip.Addr{}, false + } + for ip := range n.nodesByIP { + return ip, true + } + return netip.Addr{}, false +} + +// WANIP implements [IPPool]. +func (n *network) WANIP() netip.Addr { return n.wanIP } + +func (n *network) initStack() error { + n.ns = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + arp.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + icmp.NewProtocol4, + }, + }) + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := n.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr) + } + n.linkEP = channel.New(512, 1500, tcpip.LinkAddress(n.mac.HWAddr())) + if tcpipProblem := n.ns.CreateNIC(nicID, n.linkEP); tcpipProblem != nil { + return fmt.Errorf("CreateNIC: %v", tcpipProblem) + } + n.ns.SetPromiscuousMode(nicID, true) + n.ns.SetSpoofing(nicID, true) + + prefix := tcpip.AddrFrom4Slice(n.lanIP.Addr().AsSlice()).WithPrefix() + prefix.PrefixLen = n.lanIP.Bits() + if tcpProb := n.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: prefix, + }, stack.AddressProperties{}); tcpProb != nil { + return errors.New(tcpProb.String()) + } + + ipv4Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 4)), tcpip.MaskFromBytes(make([]byte, 4))) + if err != nil { + return fmt.Errorf("could not create IPv4 subnet: %v", err) + } + n.ns.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID, + }, + }) + + const tcpReceiveBufferSize = 0 // default + const maxInFlightConnectionAttempts = 8192 + tcpFwd := tcp.NewForwarder(n.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, n.acceptTCP) + n.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) { + return tcpFwd.HandlePacket(tei, pb) + }) + + go func() { + for { + pkt := n.linkEP.ReadContext(n.s.shutdownCtx) + if pkt == nil { + if n.s.shutdownCtx.Err() != nil { + // Return without logging. + return + } + continue + } + + ipRaw := pkt.ToView().AsSlice() + goPkt := gopacket.NewPacket( + ipRaw, + layers.LayerTypeIPv4, gopacket.Lazy) + layerV4 := goPkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + + dstIP, _ := netip.AddrFromSlice(layerV4.DstIP) + node, ok := n.nodesByIP[dstIP] + if !ok { + log.Printf("no MAC for dest IP %v", dstIP) + continue + } + eth := &layers.Ethernet{ + SrcMAC: n.mac.HWAddr(), + DstMAC: node.mac.HWAddr(), + EthernetType: layers.EthernetTypeIPv4, + } + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + sls := []gopacket.SerializableLayer{ + eth, + } + for _, layer := range goPkt.Layers() { + sl, ok := layer.(gopacket.SerializableLayer) + if !ok { + log.Fatalf("layer %s is not serializable", layer.LayerType().String()) + } + switch gl := layer.(type) { + case *layers.TCP: + gl.SetNetworkLayerForChecksum(layerV4) + case *layers.UDP: + gl.SetNetworkLayerForChecksum(layerV4) + } + sls = append(sls, sl) + } + + if err := gopacket.SerializeLayers(buffer, options, sls...); err != nil { + log.Printf("Serialize error: %v", err) + continue + } + if writeFunc, ok := n.writeFunc.Load(node.mac); ok { + writeFunc(buffer.Bytes()) + } else { + log.Printf("No writeFunc for %v", node.mac) + } + } + }() + return nil +} + +func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr { + switch s.Len() { + case 4: + return netip.AddrFrom4(s.As4()) + case 16: + return netip.AddrFrom16(s.As16()).Unmap() + } + return netip.Addr{} +} + +func stringifyTEI(tei stack.TransportEndpointID) string { + localHostPort := net.JoinHostPort(tei.LocalAddress.String(), strconv.Itoa(int(tei.LocalPort))) + remoteHostPort := net.JoinHostPort(tei.RemoteAddress.String(), strconv.Itoa(int(tei.RemotePort))) + return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort) +} + +func (n *network) acceptTCP(r *tcp.ForwarderRequest) { + reqDetails := r.ID() + + log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails)) + clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) + destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) + if !clientRemoteIP.IsValid() { + r.Complete(true) // sends a RST + return + } + + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + log.Printf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err) + r.Complete(true) // sends a RST + return + } + ep.SocketOptions().SetKeepAlive(true) + + if reqDetails.LocalPort == 123 { + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + io.WriteString(tc, "Hello from Go\nGoodbye.\n") + tc.Close() + return + } + + if reqDetails.LocalPort == 8008 && destIP == fakeTestAgentIP { + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + node := n.nodesByIP[clientRemoteIP] + ac := &agentConn{node, tc} + n.s.addIdleAgentConn(ac) + return + } + + var targetDial string + if n.s.derpIPs.Contains(destIP) { + targetDial = destIP.String() + ":" + strconv.Itoa(int(reqDetails.LocalPort)) + } else if destIP == fakeControlplaneIP { + targetDial = "controlplane.tailscale.com:" + strconv.Itoa(int(reqDetails.LocalPort)) + } + if targetDial != "" { + c, err := net.Dial("tcp", targetDial) + if err != nil { + r.Complete(true) + log.Printf("Dial controlplane: %v", err) + return + } + defer c.Close() + tc := gonet.NewTCPConn(&wq, ep) + defer tc.Close() + r.Complete(false) + errc := make(chan error, 2) + go func() { _, err := io.Copy(tc, c); errc <- err }() + go func() { _, err := io.Copy(c, tc); errc <- err }() + <-errc + } else { + r.Complete(true) // sends a RST + } +} + +var ( + fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11}) + fakeControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1}) + fakeTestAgentIP = netip.AddrFrom4([4]byte{52, 52, 0, 2}) +) + +type EthernetPacket struct { + le *layers.Ethernet + gp gopacket.Packet +} + +func (ep EthernetPacket) SrcMAC() MAC { + return MAC(ep.le.SrcMAC) +} + +func (ep EthernetPacket) DstMAC() MAC { + return MAC(ep.le.DstMAC) +} + +type MAC [6]byte + +func (m MAC) IsBroadcast() bool { + return m == MAC{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} +} + +func macOf(hwa net.HardwareAddr) (_ MAC, ok bool) { + if len(hwa) != 6 { + return MAC{}, false + } + return MAC(hwa), true +} + +func (m MAC) HWAddr() net.HardwareAddr { + return net.HardwareAddr(m[:]) +} + +func (m MAC) String() string { + return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", m[0], m[1], m[2], m[3], m[4], m[5]) +} + +type network struct { + s *Server + mac MAC + portmap bool + wanIP netip.Addr + lanIP netip.Prefix // with host bits set (e.g. 192.168.2.1/24) + nodesByIP map[netip.Addr]*node + + ns *stack.Stack + linkEP *channel.Endpoint + + natStyle syncs.AtomicValue[NAT] + natMu sync.Mutex // held while using + changing natTable + natTable NATTable + + // writeFunc is a map of MAC -> func to write to that MAC. + // It contains entries for connected nodes only. + writeFunc syncs.Map[MAC, func([]byte)] // MAC -> func to write to that MAC +} + +func (n *network) registerWriter(mac MAC, f func([]byte)) { + if f != nil { + n.writeFunc.Store(mac, f) + } else { + n.writeFunc.Delete(mac) + } +} + +func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) { + if n.lanIP.Addr() == ip { + return n.mac, true + } + if n, ok := n.nodesByIP[ip]; ok { + return n.mac, true + } + return MAC{}, false +} + +type node struct { + mac MAC + net *network + lanIP netip.Addr // must be in net.lanIP prefix + unique in net +} + +type Server struct { + shutdownCtx context.Context + shutdownCancel context.CancelFunc + + derpIPs set.Set[netip.Addr] + + nodes []*node + nodeByMAC map[MAC]*node + networks set.Set[*network] + networkByWAN map[netip.Addr]*network + + mu sync.Mutex + agentConnWaiter map[*node]chan<- struct{} // signaled after added to set + agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all + agentRoundTripper map[*node]*http.Transport +} + +func New(c *Config) (*Server, error) { + ctx, cancel := context.WithCancel(context.Background()) + s := &Server{ + shutdownCtx: ctx, + shutdownCancel: cancel, + + derpIPs: set.Of[netip.Addr](), + + nodeByMAC: map[MAC]*node{}, + networkByWAN: map[netip.Addr]*network{}, + networks: set.Of[*network](), + } + if err := s.initFromConfig(c); err != nil { + return nil, err + } + for n := range s.networks { + if err := n.initStack(); err != nil { + return nil, fmt.Errorf("newServer: initStack: %v", err) + } + } + return s, nil +} + +func (s *Server) HWAddr(mac MAC) net.HardwareAddr { + // TODO: cache + return net.HardwareAddr(mac[:]) +} + +// IPv4ForDNS returns the IP address for the given DNS query name (for IPv4 A +// queries only). +func (s *Server) IPv4ForDNS(qname string) (netip.Addr, bool) { + switch qname { + case "dns": + return fakeDNSIP, true + case "test-driver.tailscale": + return fakeTestAgentIP, true + case "controlplane.tailscale.com": + return fakeControlplaneIP, true + } + return netip.Addr{}, false +} + +type Protocol int + +const ( + ProtocolQEMU = Protocol(iota + 1) + ProtocolUnixDGRAM // for macOS Hypervisor.Framework and VZFileHandleNetworkDeviceAttachment +) + +// serveConn serves a single connection from a client. +func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { + log.Printf("Got conn %T %p", uc, uc) + defer uc.Close() + + bw := bufio.NewWriterSize(uc, 2<<10) + var writeMu sync.Mutex + writePkt := func(pkt []byte) { + if pkt == nil { + return + } + writeMu.Lock() + defer writeMu.Unlock() + if proto == ProtocolQEMU { + hdr := binary.BigEndian.AppendUint32(bw.AvailableBuffer()[:0], uint32(len(pkt))) + if _, err := bw.Write(hdr); err != nil { + log.Printf("Write hdr: %v", err) + return + } + } + if _, err := bw.Write(pkt); err != nil { + log.Printf("Write pkt: %v", err) + return + } + if err := bw.Flush(); err != nil { + log.Printf("Flush: %v", err) + } + } + + buf := make([]byte, 16<<10) + var srcNode *node + var netw *network // non-nil after first packet + for { + var packetRaw []byte + if proto == ProtocolUnixDGRAM { + n, _, err := uc.ReadFromUnix(buf) + if err != nil { + log.Printf("ReadFromUnix: %v", err) + continue + } + packetRaw = buf[:n] + } else if proto == ProtocolQEMU { + if _, err := io.ReadFull(uc, buf[:4]); err != nil { + log.Printf("ReadFull header: %v", err) + return + } + n := binary.BigEndian.Uint32(buf[:4]) + + if _, err := io.ReadFull(uc, buf[4:4+n]); err != nil { + log.Printf("ReadFull pkt: %v", err) + return + } + packetRaw = buf[4 : 4+n] // raw ethernet frame + } + + packet := gopacket.NewPacket(packetRaw, layers.LayerTypeEthernet, gopacket.Lazy) + le, ok := packet.LinkLayer().(*layers.Ethernet) + if !ok || len(le.SrcMAC) != 6 || len(le.DstMAC) != 6 { + continue + } + ep := EthernetPacket{le, packet} + + srcMAC := ep.SrcMAC() + if srcNode == nil { + srcNode, ok = s.nodeByMAC[srcMAC] + if !ok { + log.Printf("[conn %p] ignoring frame from unknown MAC %v", uc, srcMAC) + continue + } + log.Printf("[conn %p] MAC %v is node %v", uc, srcMAC, srcNode.lanIP) + netw = srcNode.net + netw.registerWriter(srcMAC, writePkt) + defer netw.registerWriter(srcMAC, nil) + } else { + if srcMAC != srcNode.mac { + log.Printf("[conn %p] ignoring frame from MAC %v, expected %v", uc, srcMAC, srcNode.mac) + continue + } + } + netw.HandleEthernetPacket(ep) + } +} + +func (s *Server) routeUDPPacket(up UDPPacket) { + // Find which network owns this based on the destination IP + // and all the known networks' wan IPs. + + // But certain things (like STUN) we do in-process. + if up.Dst.Port() == stunPort { + // TODO(bradfitz): fake latency; time.AfterFunc the response + if res, ok := makeSTUNReply(up); ok { + s.routeUDPPacket(res) + } + return + } + + netw, ok := s.networkByWAN[up.Dst.Addr()] + if !ok { + log.Printf("no network to route UDP packet for %v", up.Dst) + return + } + netw.HandleUDPPacket(up) +} + +// writeEth writes a raw Ethernet frame to all (0, 1, or multiple) connected +// clients on the network. +// +// This only delivers to client devices and not the virtual router/gateway +// device. +func (n *network) writeEth(res []byte) { + if len(res) < 12 { + return + } + dstMAC := MAC(res[0:6]) + srcMAC := MAC(res[6:12]) + if dstMAC.IsBroadcast() { + n.writeFunc.Range(func(mac MAC, writeFunc func([]byte)) bool { + writeFunc(res) + return true + }) + return + } + if srcMAC == dstMAC { + log.Printf("dropping write of packet from %v to itself", srcMAC) + return + } + if writeFunc, ok := n.writeFunc.Load(dstMAC); ok { + writeFunc(res) + return + } +} + +func (n *network) HandleEthernetPacket(ep EthernetPacket) { + packet := ep.gp + dstMAC := ep.DstMAC() + isBroadcast := dstMAC.IsBroadcast() + forRouter := dstMAC == n.mac || isBroadcast + + switch ep.le.EthernetType { + default: + log.Printf("Dropping non-IP packet: %v", ep.le.EthernetType) + return + case layers.EthernetTypeARP: + res, err := n.createARPResponse(packet) + if err != nil { + log.Printf("createARPResponse: %v", err) + } else { + n.writeEth(res) + } + return + case layers.EthernetTypeIPv6: + // One day. Low value for now. IPv4 NAT modes is the main thing + // this project wants to test. + return + case layers.EthernetTypeIPv4: + // Below + } + + // Send ethernet broadcasts and unicast ethernet frames to peers + // on the same network. This is all LAN traffic that isn't meant + // for the router/gw itself: + n.writeEth(ep.gp.Data()) + + if forRouter { + n.HandleEthernetIPv4PacketForRouter(ep) + } +} + +// HandleUDPPacket handles a UDP packet arriving from the internet, +// addressed to the router's WAN IP. It is then NATed back to a +// LAN IP here and wrapped in an ethernet layer and delivered +// to the network. +func (n *network) HandleUDPPacket(p UDPPacket) { + dst := n.doNATIn(p.Src, p.Dst) + if !dst.IsValid() { + return + } + p.Dst = dst + n.WriteUDPPacketNoNAT(p) +} + +// WriteUDPPacketNoNAT writes a UDP packet to the network, without +// doing any NAT translation. +// +// The packet will always have the ethernet src MAC of the router +// so this should not be used for packets between clients on the +// same ethernet segment. +func (n *network) WriteUDPPacketNoNAT(p UDPPacket) { + src, dst := p.Src, p.Dst + node, ok := n.nodesByIP[dst.Addr()] + if !ok { + log.Printf("no node for dest IP %v in UDP packet %v=>%v", dst.Addr(), p.Src, p.Dst) + return + } + + eth := &layers.Ethernet{ + SrcMAC: n.mac.HWAddr(), // of gateway + DstMAC: node.mac.HWAddr(), + EthernetType: layers.EthernetTypeIPv4, + } + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: src.Addr().AsSlice(), + DstIP: dst.Addr().AsSlice(), + } + udp := &layers.UDP{ + SrcPort: layers.UDPPort(src.Port()), + DstPort: layers.UDPPort(dst.Port()), + } + udp.SetNetworkLayerForChecksum(ip) + + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + if err := gopacket.SerializeLayers(buffer, options, eth, ip, udp, gopacket.Payload(p.Payload)); err != nil { + log.Printf("serializing UDP: %v", err) + return + } + ethRaw := buffer.Bytes() + n.writeEth(ethRaw) +} + +// HandleEthernetIPv4PacketForRouter handles an IPv4 packet that is +// directed to the router/gateway itself. The packet may be to the +// broadcast MAC address, or to the router's MAC address. The target +// IP may be the router's IP, or an internet (routed) IP. +func (n *network) HandleEthernetIPv4PacketForRouter(ep EthernetPacket) { + packet := ep.gp + writePkt := n.writeEth + + v4, ok := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if !ok { + return + } + srcIP, _ := netip.AddrFromSlice(v4.SrcIP) + dstIP, _ := netip.AddrFromSlice(v4.DstIP) + toForward := dstIP != n.lanIP.Addr() && dstIP != netip.IPv4Unspecified() + udp, isUDP := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + + if isDHCPRequest(packet) { + res, err := n.s.createDHCPResponse(packet) + if err != nil { + log.Printf("createDHCPResponse: %v", err) + return + } + writePkt(res) + return + } + + if isMDNSQuery(packet) || isIGMP(packet) { + // Don't log. Spammy for now. + return + } + + if isDNSRequest(packet) { + // TODO(bradfitz): restrict this to 4.11.4.11? add DNS + // on gateway instead? + res, err := n.s.createDNSResponse(packet) + if err != nil { + log.Printf("createDNSResponse: %v", err) + return + } + writePkt(res) + return + } + + if !toForward && isNATPMP(packet) { + n.handleNATPMPRequest(UDPPacket{ + Src: netip.AddrPortFrom(srcIP, uint16(udp.SrcPort)), + Dst: netip.AddrPortFrom(dstIP, uint16(udp.DstPort)), + Payload: udp.Payload, + }) + return + } + + if toForward && isUDP { + src := netip.AddrPortFrom(srcIP, uint16(udp.SrcPort)) + dst := netip.AddrPortFrom(dstIP, uint16(udp.DstPort)) + src = n.doNATOut(src, dst) + + n.s.routeUDPPacket(UDPPacket{ + Src: src, + Dst: dst, + Payload: udp.Payload, + }) + return + } + + if toForward && n.s.shouldInterceptTCP(packet) { + ipp := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + pktCopy := make([]byte, 0, len(ipp.Contents)+len(ipp.Payload)) + pktCopy = append(pktCopy, ipp.Contents...) + pktCopy = append(pktCopy, ipp.Payload...) + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(pktCopy), + }) + n.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) + packetBuf.DecRef() + return + } + + //log.Printf("Got packet: %v", packet) +} + +func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { + ethLayer := request.Layer(layers.LayerTypeEthernet).(*layers.Ethernet) + srcMAC, ok := macOf(ethLayer.SrcMAC) + if !ok { + return nil, nil + } + node, ok := s.nodeByMAC[srcMAC] + if !ok { + log.Printf("DHCP request from unknown node %v; ignoring", srcMAC) + return nil, nil + } + gwIP := node.net.lanIP.Addr() + + ipLayer := request.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + udpLayer := request.Layer(layers.LayerTypeUDP).(*layers.UDP) + dhcpLayer := request.Layer(layers.LayerTypeDHCPv4).(*layers.DHCPv4) + + response := &layers.DHCPv4{ + Operation: layers.DHCPOpReply, + HardwareType: layers.LinkTypeEthernet, + HardwareLen: 6, + Xid: dhcpLayer.Xid, + ClientHWAddr: dhcpLayer.ClientHWAddr, + Flags: dhcpLayer.Flags, + YourClientIP: node.lanIP.AsSlice(), + Options: []layers.DHCPOption{ + { + Type: layers.DHCPOptServerID, + Data: gwIP.AsSlice(), // DHCP server's IP + Length: 4, + }, + }, + } + + var msgType layers.DHCPMsgType + for _, opt := range dhcpLayer.Options { + if opt.Type == layers.DHCPOptMessageType && opt.Length > 0 { + msgType = layers.DHCPMsgType(opt.Data[0]) + } + } + switch msgType { + case layers.DHCPMsgTypeDiscover: + response.Options = append(response.Options, layers.DHCPOption{ + Type: layers.DHCPOptMessageType, + Data: []byte{byte(layers.DHCPMsgTypeOffer)}, + Length: 1, + }) + case layers.DHCPMsgTypeRequest: + response.Options = append(response.Options, + layers.DHCPOption{ + Type: layers.DHCPOptMessageType, + Data: []byte{byte(layers.DHCPMsgTypeAck)}, + Length: 1, + }, + layers.DHCPOption{ + Type: layers.DHCPOptLeaseTime, + Data: binary.BigEndian.AppendUint32(nil, 3600), // hour? sure. + Length: 4, + }, + layers.DHCPOption{ + Type: layers.DHCPOptRouter, + Data: gwIP.AsSlice(), + Length: 4, + }, + layers.DHCPOption{ + Type: layers.DHCPOptDNS, + Data: fakeDNSIP.AsSlice(), + Length: 4, + }, + layers.DHCPOption{ + Type: layers.DHCPOptSubnetMask, + Data: net.CIDRMask(node.net.lanIP.Bits(), 32), + Length: 4, + }, + ) + } + + eth := &layers.Ethernet{ + SrcMAC: node.net.mac.HWAddr(), + DstMAC: ethLayer.SrcMAC, + EthernetType: layers.EthernetTypeIPv4, + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: ipLayer.DstIP, + DstIP: ipLayer.SrcIP, + } + + udp := &layers.UDP{ + SrcPort: udpLayer.DstPort, + DstPort: udpLayer.SrcPort, + } + udp.SetNetworkLayerForChecksum(ip) + + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + if err := gopacket.SerializeLayers(buffer, options, + eth, + ip, + udp, + response, + ); err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +func isDHCPRequest(pkt gopacket.Packet) bool { + v4, ok := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if !ok || v4.Protocol != layers.IPProtocolUDP { + return false + } + udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) + return ok && udp.DstPort == 67 && udp.SrcPort == 68 +} + +func isIGMP(pkt gopacket.Packet) bool { + return pkt.Layer(layers.LayerTypeIGMP) != nil +} + +func isMDNSQuery(pkt gopacket.Packet) bool { + udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) + // TODO(bradfitz): also check IPv4 DstIP=224.0.0.251 (or whatever) + return ok && udp.SrcPort == 5353 && udp.DstPort == 5353 +} + +func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool { + tcp, ok := pkt.Layer(layers.LayerTypeTCP).(*layers.TCP) + if !ok { + return false + } + ipv4, ok := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if !ok { + return false + } + if tcp.DstPort == 123 { + return true + } + dstIP, _ := netip.AddrFromSlice(ipv4.DstIP.To4()) + if tcp.DstPort == 80 || tcp.DstPort == 443 { + if dstIP == fakeControlplaneIP || s.derpIPs.Contains(dstIP) { + return true + } + } + if tcp.DstPort == 8008 && dstIP == fakeTestAgentIP { + // Connection from cmd/tta. + return true + } + return false +} + +// isDNSRequest reports whether pkt is a DNS request to the fake DNS server. +func isDNSRequest(pkt gopacket.Packet) bool { + udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) + if !ok || udp.DstPort != 53 { + return false + } + ip, ok := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if !ok { + return false + } + dstIP, ok := netip.AddrFromSlice(ip.DstIP) + if !ok || dstIP != fakeDNSIP { + return false + } + dns, ok := pkt.Layer(layers.LayerTypeDNS).(*layers.DNS) + return ok && dns.QR == false && len(dns.Questions) > 0 +} + +func isNATPMP(pkt gopacket.Packet) bool { + udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) + return ok && udp.DstPort == 5351 && len(udp.Payload) > 0 && udp.Payload[0] == 0 // version 0, not 2 for PCP +} + +func makeSTUNReply(req UDPPacket) (res UDPPacket, ok bool) { + txid, err := stun.ParseBindingRequest(req.Payload) + if err != nil { + log.Printf("invalid STUN request: %v", err) + return res, false + } + return UDPPacket{ + Src: req.Dst, + Dst: req.Src, + Payload: stun.Response(txid, req.Src), + }, true +} + +func (s *Server) createDNSResponse(pkt gopacket.Packet) ([]byte, error) { + ethLayer := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet) + ipLayer := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + udpLayer := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP) + dnsLayer := pkt.Layer(layers.LayerTypeDNS).(*layers.DNS) + + if dnsLayer.OpCode != layers.DNSOpCodeQuery || dnsLayer.QR || len(dnsLayer.Questions) == 0 { + return nil, nil + } + + response := &layers.DNS{ + ID: dnsLayer.ID, + QR: true, + AA: true, + TC: false, + RD: dnsLayer.RD, + RA: true, + OpCode: layers.DNSOpCodeQuery, + ResponseCode: layers.DNSResponseCodeNoErr, + } + + var names []string + for _, q := range dnsLayer.Questions { + response.QDCount++ + response.Questions = append(response.Questions, q) + + if mem.HasSuffix(mem.B(q.Name), mem.S(".pool.ntp.org")) { + // Just drop DNS queries for NTP servers. For Debian/etc guests used + // during development. Not needed. Assume VM guests get correct time + // via their hypervisor. + return nil, nil + } + + names = append(names, q.Type.String()+"/"+string(q.Name)) + if q.Class != layers.DNSClassIN || q.Type != layers.DNSTypeA { + continue + } + + if ip, ok := s.IPv4ForDNS(string(q.Name)); ok { + response.ANCount++ + response.Answers = append(response.Answers, layers.DNSResourceRecord{ + Name: q.Name, + Type: q.Type, + Class: q.Class, + IP: ip.AsSlice(), + TTL: 60, + }) + } + } + + eth2 := &layers.Ethernet{ + SrcMAC: ethLayer.DstMAC, + DstMAC: ethLayer.SrcMAC, + EthernetType: layers.EthernetTypeIPv4, + } + ip2 := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: ipLayer.DstIP, + DstIP: ipLayer.SrcIP, + } + udp2 := &layers.UDP{ + SrcPort: udpLayer.DstPort, + DstPort: udpLayer.SrcPort, + } + udp2.SetNetworkLayerForChecksum(ip2) + + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + if err := gopacket.SerializeLayers(buffer, options, eth2, ip2, udp2, response); err != nil { + return nil, err + } + + const debugDNS = false + if debugDNS { + if len(response.Answers) > 0 { + back := gopacket.NewPacket(buffer.Bytes(), layers.LayerTypeEthernet, gopacket.Lazy) + log.Printf("Generated: %v", back) + } else { + log.Printf("made empty response for %q", names) + } + } + + return buffer.Bytes(), nil +} + +// doNATOut performs NAT on an outgoing packet from src to dst, where +// src is a LAN IP and dst is a WAN IP. +// +// It returns the souce WAN ip:port to use. +func (n *network) doNATOut(src, dst netip.AddrPort) (newSrc netip.AddrPort) { + n.natMu.Lock() + defer n.natMu.Unlock() + return n.natTable.PickOutgoingSrc(src, dst, time.Now()) +} + +// doNATIn performs NAT on an incoming packet from WAN src to WAN dst, returning +// a new destination LAN ip:port to use. +func (n *network) doNATIn(src, dst netip.AddrPort) (newDst netip.AddrPort) { + n.natMu.Lock() + defer n.natMu.Unlock() + return n.natTable.PickIncomingDst(src, dst, time.Now()) +} + +func (n *network) createARPResponse(pkt gopacket.Packet) ([]byte, error) { + ethLayer, ok := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet) + if !ok { + return nil, nil + } + arpLayer, ok := pkt.Layer(layers.LayerTypeARP).(*layers.ARP) + if !ok || + arpLayer.Operation != layers.ARPRequest || + arpLayer.AddrType != layers.LinkTypeEthernet || + arpLayer.Protocol != layers.EthernetTypeIPv4 || + arpLayer.HwAddressSize != 6 || + arpLayer.ProtAddressSize != 4 || + len(arpLayer.DstProtAddress) != 4 { + return nil, nil + } + + wantIP := netip.AddrFrom4([4]byte(arpLayer.DstProtAddress)) + foundMAC, ok := n.MACOfIP(wantIP) + if !ok { + return nil, nil + } + + eth := &layers.Ethernet{ + SrcMAC: foundMAC.HWAddr(), + DstMAC: ethLayer.SrcMAC, + EthernetType: layers.EthernetTypeARP, + } + + a2 := &layers.ARP{ + AddrType: layers.LinkTypeEthernet, + Protocol: layers.EthernetTypeIPv4, + HwAddressSize: 6, + ProtAddressSize: 4, + Operation: layers.ARPReply, + SourceHwAddress: foundMAC.HWAddr(), + SourceProtAddress: arpLayer.DstProtAddress, + DstHwAddress: ethLayer.SrcMAC, + DstProtAddress: arpLayer.SourceProtAddress, + } + + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + if err := gopacket.SerializeLayers(buffer, options, eth, a2); err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +func (n *network) handleNATPMPRequest(req UDPPacket) { + if string(req.Payload) == "\x00\x00" { + // https://www.rfc-editor.org/rfc/rfc6886#section-3.2 + + res := make([]byte, 0, 12) + res = append(res, + 0, // version 0 (NAT-PMP) + 128, // response to op 0 (128+0) + 0, 0, // result code success + ) + res = binary.BigEndian.AppendUint32(res, uint32(time.Now().Unix())) + wan4 := n.wanIP.As4() + res = append(res, wan4[:]...) + n.WriteUDPPacketNoNAT(UDPPacket{ + Src: req.Dst, + Dst: req.Src, + Payload: res, + }) + return + } + + log.Printf("TODO: handle NAT-PMP packet % 02x", req.Payload) + // TODO: handle NAT-PMP packet 00 01 00 00 ed 40 00 00 00 00 1c 20 +} + +// UDPPacket is a UDP packet. +// +// For the purposes of this project, a UDP packet +// (not a general IP packet) is the unit to be NAT'ed, +// as that's all that Tailscale uses. +type UDPPacket struct { + Src netip.AddrPort + Dst netip.AddrPort + Payload []byte // everything after UDP header +} + +func (s *Server) WriteStartingBanner(w io.Writer) { + fmt.Fprintf(w, "vnet serving clients:\n") + + for _, n := range s.nodes { + fmt.Fprintf(w, " %v %15v (%v, %v)\n", n.mac, n.lanIP, n.net.wanIP, n.net.natStyle.Load()) + } +} + +type agentConn struct { + node *node + tc *gonet.TCPConn +} + +func (s *Server) addIdleAgentConn(ac *agentConn) { + log.Printf("got agent conn from %v", ac.node.mac) + s.mu.Lock() + defer s.mu.Unlock() + + s.agentConns.Make() + s.agentConns.Add(ac) + + if waiter, ok := s.agentConnWaiter[ac.node]; ok { + select { + case waiter <- struct{}{}: + default: + } + } +} + +func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok bool) { + for { + ac, ok := s.takeAgentConnOne(n) + if ok { + return ac, true + } + s.mu.Lock() + ready := make(chan struct{}) + mak.Set(&s.agentConnWaiter, n, ready) + s.mu.Unlock() + select { + case <-ctx.Done(): + return nil, false + case <-ready: + case <-time.After(time.Second): + // Try again regularly anyway, in case we have multiple clients + // trying to hit the same node, or if a race means we weren't in the + // select by the time addIdleAgentConn tried to signal us. + } + } +} + +func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + for ac := range s.agentConns { + if ac.node == n { + s.agentConns.Delete(ac) + return ac, true + } + } + return nil, false +} + +func (s *Server) NodeAgentRoundTripper(ctx context.Context, n *Node) http.RoundTripper { + s.mu.Lock() + defer s.mu.Unlock() + + if rt, ok := s.agentRoundTripper[n.n]; ok { + return rt + } + + var rt = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + ac, ok := s.takeAgentConn(ctx, n.n) + if !ok { + return nil, ctx.Err() + } + return ac.tc, nil + }, + } + + mak.Set(&s.agentRoundTripper, n.n, rt) + return rt +} + +func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) { + rt := s.NodeAgentRoundTripper(ctx, n) + req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil) + if err != nil { + return nil, err + } + res, err := rt.RoundTrip(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != 200 { + body, _ := io.ReadAll(io.LimitReader(res.Body, 1<<20)) + return nil, fmt.Errorf("status: %v, %s, %v", res.Status, body, res.Header) + } + return io.ReadAll(res.Body) +}