From 602adde5dc1147ab543c8c81fddb8411ca646b58 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 7 Aug 2024 12:41:26 -0700 Subject: [PATCH] WIP Change-Id: Ib6804b5c56d8d8da4eb850ef09bc86fc3610ba92 Signed-off-by: Brad Fitzpatrick --- cmd/tta/tta.go | 35 +++- cmd/vnet/vnet-main.go | 2 +- gokrazy/Makefile | 3 + tstest/integration/nat/nat_test.go | 277 +++++++++++++++++++++++++++++ tstest/natlab/vnet/conf.go | 9 + tstest/natlab/vnet/vnet.go | 76 +++++--- 6 files changed, 375 insertions(+), 27 deletions(-) create mode 100644 tstest/integration/nat/nat_test.go diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index 0a11394ef..d33526031 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -19,12 +19,16 @@ import ( "log" "net" "net/http" + "net/http/httputil" + "net/url" "os" "os/exec" "strings" "sync" "time" + "tailscale.com/client/tailscale" + "tailscale.com/util/must" "tailscale.com/util/set" "tailscale.com/version/distro" ) @@ -33,11 +37,15 @@ var ( driverAddr = flag.String("driver", "test-driver.tailscale:8008", "address of the test driver; by default we use the DNS name test-driver.tailscale which is special cased in the emulated network's DNS server") ) -func serveCmd(w http.ResponseWriter, cmd string, args ...string) { +func absify(cmd string) string { if distro.Get() == distro.Gokrazy && !strings.Contains(cmd, "/") { - cmd = "/user/" + cmd + return "/user/" + cmd } - out, err := exec.Command(cmd, args...).CombinedOutput() + return cmd +} + +func serveCmd(w http.ResponseWriter, cmd string, args ...string) { + out, err := exec.Command(absify(cmd), args...).CombinedOutput() w.Header().Set("Content-Type", "text/plain; charset=utf-8") if err != nil { w.Header().Set("Exec-Err", err.Error()) @@ -46,6 +54,14 @@ func serveCmd(w http.ResponseWriter, cmd string, args ...string) { w.Write(out) } +type localClientRoundTripper struct { + lc *tailscale.LocalClient +} + +func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return rt.lc.DoLocalRequest(req) +} + func main() { if distro.Get() == distro.Gokrazy { cmdLine, _ := os.ReadFile("/proc/cmdline") @@ -57,6 +73,12 @@ func main() { } } flag.Parse() + + logc, err := net.Dial("tcp", "9.9.9.9:124") + if err == nil { + log.SetOutput(logc) + } + log.Printf("Tailscale Test Agent running.") var mux http.ServeMux @@ -84,6 +106,11 @@ func main() { } } conns := make(chan net.Conn, 1) + var lc tailscale.LocalClient + rp := httputil.NewSingleHostReverseProxy(must.Get(url.Parse("http://local-tailscaled.sock"))) + rp.Transport = localClientRoundTripper{&lc} + + mux.Handle("/localapi/", rp) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "TTA\n") @@ -97,7 +124,7 @@ func main() { }) mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { target := r.FormValue("target") - cmd := exec.Command("tailscale", "ping", target) + cmd := exec.Command(absify("tailscale"), "ping", target) w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.(http.Flusher).Flush() cmd.Stdout = w diff --git a/cmd/vnet/vnet-main.go b/cmd/vnet/vnet-main.go index 501c2d25f..3bc512995 100644 --- a/cmd/vnet/vnet-main.go +++ b/cmd/vnet/vnet-main.go @@ -82,7 +82,7 @@ func main() { } for { time.Sleep(5 * time.Second) - continue + //continue getStatus() } }() diff --git a/gokrazy/Makefile b/gokrazy/Makefile index f086dd26b..a0807abe5 100644 --- a/gokrazy/Makefile +++ b/gokrazy/Makefile @@ -6,3 +6,6 @@ image: qemu: image qemu-system-x86_64 -m 1G -drive file=tsapp.img,format=raw -boot d -netdev user,id=user.0 -device virtio-net-pci,netdev=user.0 -serial mon:stdio -audio none + +qcow2: image + qemu-img convert -O qcow2 tsapp.img tsapp.qcow2 diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go new file mode 100644 index 000000000..d5c43083d --- /dev/null +++ b/tstest/integration/nat/nat_test.go @@ -0,0 +1,277 @@ +package nat + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "tailscale.com/ipn/ipnstate" + "tailscale.com/tstest/natlab/vnet" +) + +type natTest struct { + tb testing.TB + base string // base image + tempDir string // for qcow2 images + vnet *vnet.Server +} + +func newNatTest(tb testing.TB) *natTest { + nt := &natTest{ + tb: tb, + tempDir: tb.TempDir(), + base: "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2", + } + + if _, err := os.Stat(nt.base); err != nil { + tb.Skipf("skipping test; base image %q not found", nt.base) + } + return nt +} + +type addNodeFunc func(c *vnet.Config) *vnet.Node + +func easy(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + return c.AddNode(c.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT)) +} + +func hard(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + return c.AddNode(c.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT)) +} + +func (nt *natTest) runTest(node1, node2 addNodeFunc) { + t := nt.tb + + var c vnet.Config + nodes := []*vnet.Node{ + node1(&c), + node2(&c), + } + + var err error + nt.vnet, err = vnet.New(&c) + if err != nil { + t.Fatalf("newServer: %v", err) + } + nt.tb.Cleanup(func() { + nt.vnet.Close() + }) + + var wg sync.WaitGroup // waiting for srv.Accept goroutine + defer wg.Wait() + + sockAddr := filepath.Join(nt.tempDir, "qemu.sock") + srv, err := net.Listen("unix", sockAddr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer srv.Close() + + wg.Add(1) + go func() { + defer wg.Done() + for { + c, err := srv.Accept() + if err != nil { + return + } + go nt.vnet.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) + } + }() + + for i, node := range nodes { + disk := fmt.Sprintf("%s/node-%d.qcow2", nt.tempDir, i) + out, err := exec.Command("qemu-img", "create", + "-f", "qcow2", + "-F", "qcow2", + "-b", nt.base, + disk).CombinedOutput() + if err != nil { + t.Fatalf("qemu-img create: %v, %s", err, out) + } + + cmd := exec.Command("qemu-system-x86_64", + "-M", "microvm,isa-serial=off", + "-m", "1G", + "-nodefaults", "-no-user-config", "-nographic", + "-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz", + "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1", + "-drive", "id=blk0,file="+disk+",format=qcow2", + "-device", "virtio-blk-device,drive=blk0", + "-netdev", "stream,id=net0,addr.type=unix,addr.path="+sockAddr, + "-device", "virtio-serial-device", + "-device", "virtio-net-device,netdev=net0,mac="+node.MAC().String(), + "-chardev", "stdio,id=virtiocon0,mux=on", + "-device", "virtconsole,chardev=virtiocon0", + "-mon", "chardev=virtiocon0,mode=readline", + "-audio", "none", + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + t.Fatalf("qemu: %v", err) + } + nt.tb.Cleanup(func() { + cmd.Process.Kill() + cmd.Wait() + }) + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])} + c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])} + + // var lc1 tailscale.LocalClient + // lc1.Dial = nt.vnet.NodeAgentDialer(nodes[0]) + //st, err := lc1.Status(ctx) + + for i, c := range []*http.Client{c1, c2} { + st, err := status(ctx, c) + if err != nil { + t.Fatalf("node%d status: %v", i, err) + } + t.Logf("XXX node%d status: %v", i, st) + if err := up(ctx, c); err != nil { + t.Fatalf("node%d up: %v", i, err) + } + t.Logf("XXX node%d up!", i) + } + + t.Logf("both up1") + + var sts []*ipnstate.Status + for i, c := range []*http.Client{c1, c2} { + st, err := status(ctx, c) + if err != nil { + t.Fatalf("node%d status second time: %v", i, err) + } + if st.BackendState != "Running" { + t.Fatalf("node%d state = %q", i, st.BackendState) + } + if len(st.Peer) != 1 { + t.Fatalf("node%d peer count = %d; want 1", i, len(st.Peer)) + } + sts = append(sts, st) + } + + t.Logf("both up2") + + route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String()) + t.Logf("ping route: %v, %v", route, err) +} + +func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) { + req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil) + if err != nil { + return nil, err + } + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + all, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("ReadAll: %w", err) + } + var st ipnstate.Status + if err := json.Unmarshal(all, &st); err != nil { + return nil, fmt.Errorf("JSON marshal error: %v; body was %q", err, all) + } + return &st, nil +} + +type routeType string + +const ( + routeDirect routeType = "direct" + routeDERP routeType = "derp" + routeLAN routeType = "lan" +) + +func ping(ctx context.Context, c *http.Client, target string) (routeType, error) { + req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil) + if err != nil { + return "", err + } + res, err := c.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + if res.StatusCode != 200 { + return "", fmt.Errorf("unexpected status code %v", res.Status) + } + all, _ := io.ReadAll(res.Body) + var route routeType + for _, line := range strings.Split(string(all), "\n") { + if strings.Contains(line, " via DERP") { + route = routeDERP + continue + } + // pong from foo (100.82.3.4) via ADDR:PORT in 69ms + if _, rest, ok := strings.Cut(line, " via "); ok { + ipPorStr, _, _ := strings.Cut(rest, " in ") + ipPort, err := netip.ParseAddrPort(ipPorStr) + if err == nil { + if ipPort.Addr().IsPrivate() { + route = routeLAN + } else { + route = routeDirect + } + continue + } + } + } + if route == "" { + return routeType(all), nil + } + return route, nil +} + +func up(ctx context.Context, c *http.Client) error { + req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil) + if err != nil { + return err + } + res, err := c.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + all, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + return fmt.Errorf("unexpected status code %v: %s", res.Status, all) + } + return nil +} + +func TestEasyEasy(t *testing.T) { + nt := newNatTest(t) + nt.runTest(easy, easy) +} + +func TestEasyHard(t *testing.T) { + nt := newNatTest(t) + nt.runTest(easy, hard) +} diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go index 89dfc9570..8cd91f4cd 100644 --- a/tstest/natlab/vnet/conf.go +++ b/tstest/natlab/vnet/conf.go @@ -27,6 +27,10 @@ type Config struct { networks []*Network } +func (c *Config) NumNodes() int { + return len(c.nodes) +} + // AddNode creates a new node in the world. // // The opts may be of the following types: @@ -110,6 +114,11 @@ type Node struct { nets []*Network } +// MAC returns the MAC address of the node. +func (n *Node) MAC() MAC { + return n.mac +} + // Network returns the first network this node is connected to, // or nil if none. func (n *Node) Network() *Network { diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index 2f81ec709..1601addaf 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -271,6 +271,20 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { return } + if destPort == 124 { + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + go func() { + defer tc.Close() + bs := bufio.NewScanner(tc) + for bs.Scan() { + line := bs.Text() + log.Printf("LOG from guest: %s", line) + } + }() + return + } + if destPort == 8008 && destIP == fakeTestAgentIP { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) @@ -448,6 +462,7 @@ func newDERPServer() *derpServer { type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc + blendReality bool derpIPs set.Set[netip.Addr] @@ -459,12 +474,14 @@ type Server struct { control *testcontrol.Server derps []*derpServer - mu sync.Mutex - agentConnWaiter map[*node]chan<- struct{} // signaled after added to set - agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all - agentRoundTripper map[*node]*http.Transport + mu sync.Mutex + agentConnWaiter map[*node]chan<- struct{} // signaled after added to set + agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all + agentDialer map[*node]DialFunc } +type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) + var derpMap = &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { @@ -532,6 +549,10 @@ func New(c *Config) (*Server, error) { return s, nil } +func (s *Server) Close() { + s.shutdownCancel() +} + func (s *Server) HWAddr(mac MAC) net.HardwareAddr { // TODO: cache return net.HardwareAddr(mac[:]) @@ -655,7 +676,7 @@ func (s *Server) routeUDPPacket(up UDPPacket) { if up.Dst.Port() == stunPort { // TODO(bradfitz): fake latency; time.AfterFunc the response if res, ok := makeSTUNReply(up); ok { - log.Printf("STUN reply: %+v", res) + //log.Printf("STUN reply: %+v", res) s.routeUDPPacket(res) } else { log.Printf("weird: STUN packet not handled") @@ -1015,15 +1036,18 @@ func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool { if !ok { return false } - if tcp.DstPort == 123 { + if tcp.DstPort == 123 || tcp.DstPort == 124 { return true } dstIP, _ := netip.AddrFromSlice(ipv4.DstIP.To4()) if tcp.DstPort == 80 || tcp.DstPort == 443 { switch dstIP { - case fakeProxyControlplaneIP, fakeControlIP, fakeDERP1IP, fakeDERP2IP: + case fakeControlIP, fakeDERP1IP, fakeDERP2IP: return true } + if dstIP == fakeProxyControlplaneIP { + return s.blendReality + } if s.derpIPs.Contains(dstIP) { return true } @@ -1294,12 +1318,15 @@ func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok b for { ac, ok := s.takeAgentConnOne(n) if ok { + log.Printf("got agent conn for %v", n.mac) return ac, true } s.mu.Lock() ready := make(chan struct{}) mak.Set(&s.agentConnWaiter, n, ready) s.mu.Unlock() + + log.Printf("waiting for agent conn for %v", n.mac) select { case <-ctx.Done(): return nil, false @@ -1318,36 +1345,41 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { for ac := range s.agentConns { if ac.node == n { s.agentConns.Delete(ac) + log.Printf("XXX takeAgentConnOne HIT for %v", n.mac) return ac, true } } + log.Printf("XXX takeAgentConnOne MISS for %v", n.mac) return nil, false } -func (s *Server) NodeAgentRoundTripper(ctx context.Context, n *Node) http.RoundTripper { +func (s *Server) NodeAgentDialer(n *Node) DialFunc { s.mu.Lock() defer s.mu.Unlock() - if rt, ok := s.agentRoundTripper[n.n]; ok { - return rt + if d, ok := s.agentDialer[n.n]; ok { + return d } - - var rt = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - ac, ok := s.takeAgentConn(ctx, n.n) - if !ok { - return nil, ctx.Err() - } - return ac.tc, nil - }, + d := func(ctx context.Context, network, addr string) (net.Conn, error) { + ac, ok := s.takeAgentConn(ctx, n.n) + if !ok { + return nil, ctx.Err() + } + return ac.tc, nil } + mak.Set(&s.agentDialer, n.n, d) + return d +} - mak.Set(&s.agentRoundTripper, n.n, rt) - return rt +func (s *Server) NodeAgentRoundTripper(n *Node) http.RoundTripper { + return &http.Transport{ + DisableKeepAlives: true, // XXX + DialContext: s.NodeAgentDialer(n), + } } func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) { - rt := s.NodeAgentRoundTripper(ctx, n) + rt := s.NodeAgentRoundTripper(n) req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil) if err != nil { return nil, err