From 2f00bead5a81ace639cf76155b84a945d3b2a56f Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Wed, 7 Aug 2024 21:31:50 -0700 Subject: [PATCH] natlab: add NodeAgentClient This adds a new NodeAgentClient type that can be used to invoke the LocalAPI using the LocalClient instead of handcrafted URLs. However, there are certain cases where it does make sense for the node agent to provide more functionality than whats possible with just the LocalClient, as such it also exposes a http.Client to make requests directly. Signed-off-by: Maisem Ali --- cmd/tta/tta.go | 2 + cmd/vnet/vnet-main.go | 4 +- tstest/integration/nat/nat_test.go | 111 +++++++++-------------------- tstest/natlab/vnet/vnet.go | 40 +++++------ 4 files changed, 56 insertions(+), 101 deletions(-) diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index c11807992..99205d1f5 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -65,6 +65,8 @@ type localClientRoundTripper struct { } func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.RequestURI = "" return rt.lc.DoLocalRequest(req) } diff --git a/cmd/vnet/vnet-main.go b/cmd/vnet/vnet-main.go index 3bc512995..acf85db0e 100644 --- a/cmd/vnet/vnet-main.go +++ b/cmd/vnet/vnet-main.go @@ -68,12 +68,12 @@ func main() { } s.WriteStartingBanner(os.Stdout) - + nc := s.NodeAgentClient(node1) go func() { getStatus := func() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - st, err := s.NodeStatus(ctx, node1) + st, err := nc.Status(ctx) if err != nil { log.Printf("NodeStatus: %v", err) return diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go index 3717012f4..aee7ea526 100644 --- a/tstest/integration/nat/nat_test.go +++ b/tstest/integration/nat/nat_test.go @@ -2,23 +2,23 @@ import ( "context" - "encoding/json" + "errors" "fmt" "io" "net" "net/http" "net/netip" - "net/url" "os" "os/exec" "path/filepath" - "strings" "sync" "testing" "time" "golang.org/x/sync/errgroup" + "tailscale.com/client/tailscale" "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" "tailscale.com/tstest/natlab/vnet" ) @@ -33,7 +33,7 @@ func newNatTest(tb testing.TB) *natTest { nt := &natTest{ tb: tb, tempDir: tb.TempDir(), - base: "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2", + base: "/Users/maisem/dev/tailscale.com/gokrazy/tsapp.qcow2", } if _, err := os.Stat(nt.base); err != nil { @@ -113,7 +113,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { "-M", "microvm,isa-serial=off", "-m", "1G", "-nodefaults", "-no-user-config", "-nographic", - "-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz", + "-kernel", "/Users/maisem/dev/github.com/tailscale/gokrazy-kernel/vmlinuz", "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1", "-drive", "id=blk0,file="+disk+",format=qcow2", "-device", "virtio-blk-device,drive=blk0", @@ -139,15 +139,16 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])} - c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])} + lc1 := nt.vnet.NodeAgentClient(nodes[0]) + lc2 := nt.vnet.NodeAgentClient(nodes[1]) + clients := []*vnet.NodeAgentClient{lc1, lc2} var eg errgroup.Group var sts [2]*ipnstate.Status - for i, c := range []*http.Client{c1, c2} { + for i, c := range clients { i, c := i, c eg.Go(func() error { - st, err := status(ctx, c) + st, err := c.Status(ctx) if err != nil { return fmt.Errorf("node%d status: %w", i, err) } @@ -156,7 +157,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { return fmt.Errorf("node%d up: %w", i, err) } t.Logf("node%d up!", i) - st, err = status(ctx, c) + st, err = c.Status(ctx) if err != nil { return fmt.Errorf("node%d status: %w", i, err) } @@ -173,85 +174,40 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) { t.Fatalf("initial setup: %v", err) } - route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String()) + route, err := ping(ctx, lc1, sts[1].Self.TailscaleIPs[0]) t.Logf("ping route: %v, %v", route, err) } -func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) { - req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil) - if err != nil { - return nil, err - } - res, err := c.Do(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - all, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("ReadAll: %w", err) - } - var st ipnstate.Status - if err := json.Unmarshal(all, &st); err != nil { - return nil, fmt.Errorf("JSON marshal error: %v; body was %q", err, all) - } - return &st, nil -} - -type routeType string - -const ( - routeDirect routeType = "direct" - routeDERP routeType = "derp" - routeLAN routeType = "lan" -) - -func ping(ctx context.Context, c *http.Client, target string) (routeType, error) { - req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil) - if err != nil { - return "", err - } - res, err := c.Do(req) - if err != nil { - return "", err - } - defer res.Body.Close() - if res.StatusCode != 200 { - return "", fmt.Errorf("unexpected status code %v", res.Status) - } - all, _ := io.ReadAll(res.Body) - var route routeType - for _, line := range strings.Split(string(all), "\n") { - if strings.Contains(line, " via DERP") { - route = routeDERP - continue - } - // pong from foo (100.82.3.4) via ADDR:PORT in 69ms - if _, rest, ok := strings.Cut(line, " via "); ok { - ipPorStr, _, _ := strings.Cut(rest, " in ") - ipPort, err := netip.ParseAddrPort(ipPorStr) - if err == nil { - if ipPort.Addr().IsPrivate() { - route = routeLAN - } else { - route = routeDirect - } - continue +func ping(ctx context.Context, c *vnet.NodeAgentClient, target netip.Addr) (*ipnstate.PingResult, error) { + n := 0 + var res *ipnstate.PingResult + anyPong := false + for { + n++ + pr, err := c.PingWithOpts(ctx, target, tailcfg.PingDisco, tailscale.PingOpts{}) + if err != nil { + if anyPong { + return res, nil } + return nil, err } + if pr.Err != "" { + return nil, errors.New(pr.Err) + } + if pr.DERPRegionID == 0 { + return pr, nil + } + time.Sleep(time.Second) + res = pr } - if route == "" { - return routeType(all), nil - } - return route, nil } -func up(ctx context.Context, c *http.Client) error { +func up(ctx context.Context, c *vnet.NodeAgentClient) error { req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil) if err != nil { return err } - res, err := c.Do(req) + res, err := c.HTTPClient.Do(req) if err != nil { return err } @@ -269,6 +225,7 @@ func TestEasyEasy(t *testing.T) { } func TestEasyHard(t *testing.T) { + t.Skip() nt := newNatTest(t) nt.runTest(easy, hard) } diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index b02a9a945..f0a80c35b 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -46,6 +46,7 @@ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" + "tailscale.com/client/tailscale" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netutil" @@ -279,7 +280,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { bs := bufio.NewScanner(tc) for bs.Scan() { line := bs.Text() - log.Printf("LOG from guest: %s", line) + log.Printf("LOG from guest %v: %s", clientRemoteIP, line) } }() return @@ -1356,6 +1357,11 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { return nil, false } +type NodeAgentClient struct { + *tailscale.LocalClient + HTTPClient *http.Client +} + func (s *Server) NodeAgentDialer(n *Node) DialFunc { s.mu.Lock() defer s.mu.Unlock() @@ -1374,26 +1380,16 @@ func (s *Server) NodeAgentDialer(n *Node) DialFunc { return d } -func (s *Server) NodeAgentRoundTripper(n *Node) http.RoundTripper { - return &http.Transport{ - DialContext: s.NodeAgentDialer(n), +func (s *Server) NodeAgentClient(n *Node) *NodeAgentClient { + d := s.NodeAgentDialer(n) + return &NodeAgentClient{ + LocalClient: &tailscale.LocalClient{ + Dial: d, + }, + HTTPClient: &http.Client{ + Transport: &http.Transport{ + DialContext: d, + }, + }, } } - -func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) { - rt := s.NodeAgentRoundTripper(n) - req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil) - if err != nil { - return nil, err - } - res, err := rt.RoundTrip(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - if res.StatusCode != 200 { - body, _ := io.ReadAll(io.LimitReader(res.Body, 1<<20)) - return nil, fmt.Errorf("status: %v, %s, %v", res.Status, body, res.Header) - } - return io.ReadAll(res.Body) -}