diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index c11807992..99205d1f5 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -65,6 +65,8 @@ type localClientRoundTripper struct { } func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.RequestURI = "" return rt.lc.DoLocalRequest(req) } diff --git a/cmd/vnet/vnet-main.go b/cmd/vnet/vnet-main.go index 3bc512995..acf85db0e 100644 --- a/cmd/vnet/vnet-main.go +++ b/cmd/vnet/vnet-main.go @@ -68,12 +68,12 @@ func main() { } s.WriteStartingBanner(os.Stdout) - + nc := s.NodeAgentClient(node1) go func() { getStatus := func() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - st, err := s.NodeStatus(ctx, node1) + st, err := nc.Status(ctx) if err != nil { log.Printf("NodeStatus: %v", err) return diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go index 3717012f4..aee7ea526 100644 --- a/tstest/integration/nat/nat_test.go +++ b/tstest/integration/nat/nat_test.go @@ -2,23 +2,23 @@ import ( "context" - "encoding/json" + "errors" "fmt" "io" "net" "net/http" "net/netip" - "net/url" "os" "os/exec" "path/filepath" - "strings" "sync" "testing" "time" "golang.org/x/sync/errgroup" + "tailscale.com/client/tailscale" "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" "tailscale.com/tstest/natlab/vnet" ) @@ -33,7 +33,7 @@ func newNatTest(tb testing.TB) *natTest { nt := &natTest{ tb: tb, tempDir: tb.TempDir(), - base: "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2", + base: "/Users/maisem/dev/tailscale.com/gokrazy/tsapp.qcow2", } if _, err := os.Stat(nt.base); err != nil { @@ -113,7 +113,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { "-M", "microvm,isa-serial=off", "-m", "1G", "-nodefaults", "-no-user-config", "-nographic", - "-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz", + "-kernel", "/Users/maisem/dev/github.com/tailscale/gokrazy-kernel/vmlinuz", "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1", "-drive", "id=blk0,file="+disk+",format=qcow2", "-device", "virtio-blk-device,drive=blk0", @@ -139,15 +139,16 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])} - c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])} + lc1 := nt.vnet.NodeAgentClient(nodes[0]) + lc2 := nt.vnet.NodeAgentClient(nodes[1]) + clients := []*vnet.NodeAgentClient{lc1, lc2} var eg errgroup.Group var sts [2]*ipnstate.Status - for i, c := range []*http.Client{c1, c2} { + for i, c := range clients { i, c := i, c eg.Go(func() error { - st, err := status(ctx, c) + st, err := c.Status(ctx) if err != nil { return fmt.Errorf("node%d status: %w", i, err) } @@ -156,7 +157,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { return fmt.Errorf("node%d up: %w", i, err) } t.Logf("node%d up!", i) - st, err = status(ctx, c) + st, err = c.Status(ctx) if err != nil { return fmt.Errorf("node%d status: %w", i, err) } @@ -173,85 +174,40 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { t.Fatalf("initial setup: %v", err) } - route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String()) + route, err := ping(ctx, lc1, sts[1].Self.TailscaleIPs[0]) t.Logf("ping route: %v, %v", route, err) } -func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) { - req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil) - if err != nil { - return nil, err - } - res, err := c.Do(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - all, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("ReadAll: %w", err) - } - var st ipnstate.Status - if err := json.Unmarshal(all, &st); err != nil { - return nil, fmt.Errorf("JSON marshal error: %v; body was %q", err, all) - } - return &st, nil -} - -type routeType string - -const ( - routeDirect routeType = "direct" - routeDERP routeType = "derp" - routeLAN routeType = "lan" -) - -func ping(ctx context.Context, c *http.Client, target string) (routeType, error) { - req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil) - if err != nil { - return "", err - } - res, err := c.Do(req) - if err != nil { - return "", err - } - defer res.Body.Close() - if res.StatusCode != 200 { - return "", fmt.Errorf("unexpected status code %v", res.Status) - } - all, _ := io.ReadAll(res.Body) - var route routeType - for _, line := range strings.Split(string(all), "\n") { - if strings.Contains(line, " via DERP") { - route = routeDERP - continue - } - // pong from foo (100.82.3.4) via ADDR:PORT in 69ms - if _, rest, ok := strings.Cut(line, " via "); ok { - ipPorStr, _, _ := strings.Cut(rest, " in ") - ipPort, err := netip.ParseAddrPort(ipPorStr) - if err == nil { - if ipPort.Addr().IsPrivate() { - route = routeLAN - } else { - route = routeDirect - } - continue +func ping(ctx context.Context, c *vnet.NodeAgentClient, target netip.Addr) (*ipnstate.PingResult, error) { + n := 0 + var res *ipnstate.PingResult + anyPong := false + for { + n++ + pr, err := c.PingWithOpts(ctx, target, tailcfg.PingDisco, tailscale.PingOpts{}) + if err != nil { + if anyPong { + return res, nil } + return nil, err } + if pr.Err != "" { + return nil, errors.New(pr.Err) + } + if pr.DERPRegionID == 0 { + return pr, nil + } + time.Sleep(time.Second) + res = pr } - if route == "" { - return routeType(all), nil - } - return route, nil } -func up(ctx context.Context, c *http.Client) error { +func up(ctx context.Context, c *vnet.NodeAgentClient) error { req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil) if err != nil { return err } - res, err := c.Do(req) + res, err := c.HTTPClient.Do(req) if err != nil { return err } @@ -269,6 +225,7 @@ func TestEasyEasy(t *testing.T) { } func TestEasyHard(t *testing.T) { + t.Skip() nt := newNatTest(t) nt.runTest(easy, hard) } diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index b02a9a945..f0a80c35b 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -46,6 +46,7 @@ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" + "tailscale.com/client/tailscale" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netutil" @@ -279,7 +280,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { bs := bufio.NewScanner(tc) for bs.Scan() { line := bs.Text() - log.Printf("LOG from guest: %s", line) + log.Printf("LOG from guest %v: %s", clientRemoteIP, line) } }() return @@ -1356,6 +1357,11 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { return nil, false } +type NodeAgentClient struct { + *tailscale.LocalClient + HTTPClient *http.Client +} + func (s *Server) NodeAgentDialer(n *Node) DialFunc { s.mu.Lock() defer s.mu.Unlock() @@ -1374,26 +1380,16 @@ func (s *Server) NodeAgentDialer(n *Node) DialFunc { return d } -func (s *Server) NodeAgentRoundTripper(n *Node) http.RoundTripper { - return &http.Transport{ - DialContext: s.NodeAgentDialer(n), +func (s *Server) NodeAgentClient(n *Node) *NodeAgentClient { + d := s.NodeAgentDialer(n) + return &NodeAgentClient{ + LocalClient: &tailscale.LocalClient{ + Dial: d, + }, + HTTPClient: &http.Client{ + Transport: &http.Transport{ + DialContext: d, + }, + }, } } - -func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) { - rt := s.NodeAgentRoundTripper(n) - req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil) - if err != nil { - return nil, err - } - res, err := rt.RoundTrip(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - if res.StatusCode != 200 { - body, _ := io.ReadAll(io.LimitReader(res.Body, 1<<20)) - return nil, fmt.Errorf("status: %v, %s, %v", res.Status, body, res.Header) - } - return io.ReadAll(res.Body) -}