natlab: add NodeAgentClient

This adds a new NodeAgentClient type that can be used to
invoke the LocalAPI using the LocalClient instead of
handcrafted URLs. However, there are certain cases where
it does make sense for the node agent to provide more
functionality than whats possible with just the LocalClient,
as such it also exposes a http.Client to make requests directly.

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-08-07 21:31:50 -07:00
parent 19a7b2143b
commit 2f00bead5a
4 changed files with 56 additions and 101 deletions

View File

@ -65,6 +65,8 @@ type localClientRoundTripper struct {
} }
func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.RequestURI = ""
return rt.lc.DoLocalRequest(req) return rt.lc.DoLocalRequest(req)
} }

View File

@ -68,12 +68,12 @@ func main() {
} }
s.WriteStartingBanner(os.Stdout) s.WriteStartingBanner(os.Stdout)
nc := s.NodeAgentClient(node1)
go func() { go func() {
getStatus := func() { getStatus := func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
st, err := s.NodeStatus(ctx, node1) st, err := nc.Status(ctx)
if err != nil { if err != nil {
log.Printf("NodeStatus: %v", err) log.Printf("NodeStatus: %v", err)
return return

View File

@ -2,23 +2,23 @@
import ( import (
"context" "context"
"encoding/json" "errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
"net/url"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/tailcfg"
"tailscale.com/tstest/natlab/vnet" "tailscale.com/tstest/natlab/vnet"
) )
@ -33,7 +33,7 @@ func newNatTest(tb testing.TB) *natTest {
nt := &natTest{ nt := &natTest{
tb: tb, tb: tb,
tempDir: tb.TempDir(), tempDir: tb.TempDir(),
base: "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2", base: "/Users/maisem/dev/tailscale.com/gokrazy/tsapp.qcow2",
} }
if _, err := os.Stat(nt.base); err != nil { if _, err := os.Stat(nt.base); err != nil {
@ -113,7 +113,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) {
"-M", "microvm,isa-serial=off", "-M", "microvm,isa-serial=off",
"-m", "1G", "-m", "1G",
"-nodefaults", "-no-user-config", "-nographic", "-nodefaults", "-no-user-config", "-nographic",
"-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz", "-kernel", "/Users/maisem/dev/github.com/tailscale/gokrazy-kernel/vmlinuz",
"-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1", "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1",
"-drive", "id=blk0,file="+disk+",format=qcow2", "-drive", "id=blk0,file="+disk+",format=qcow2",
"-device", "virtio-blk-device,drive=blk0", "-device", "virtio-blk-device,drive=blk0",
@ -139,15 +139,16 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel() defer cancel()
c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])} lc1 := nt.vnet.NodeAgentClient(nodes[0])
c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])} lc2 := nt.vnet.NodeAgentClient(nodes[1])
clients := []*vnet.NodeAgentClient{lc1, lc2}
var eg errgroup.Group var eg errgroup.Group
var sts [2]*ipnstate.Status var sts [2]*ipnstate.Status
for i, c := range []*http.Client{c1, c2} { for i, c := range clients {
i, c := i, c i, c := i, c
eg.Go(func() error { eg.Go(func() error {
st, err := status(ctx, c) st, err := c.Status(ctx)
if err != nil { if err != nil {
return fmt.Errorf("node%d status: %w", i, err) return fmt.Errorf("node%d status: %w", i, err)
} }
@ -156,7 +157,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) {
return fmt.Errorf("node%d up: %w", i, err) return fmt.Errorf("node%d up: %w", i, err)
} }
t.Logf("node%d up!", i) t.Logf("node%d up!", i)
st, err = status(ctx, c) st, err = c.Status(ctx)
if err != nil { if err != nil {
return fmt.Errorf("node%d status: %w", i, err) return fmt.Errorf("node%d status: %w", i, err)
} }
@ -173,85 +174,40 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) {
t.Fatalf("initial setup: %v", err) t.Fatalf("initial setup: %v", err)
} }
route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String()) route, err := ping(ctx, lc1, sts[1].Self.TailscaleIPs[0])
t.Logf("ping route: %v, %v", route, err) t.Logf("ping route: %v, %v", route, err)
} }
func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) { func ping(ctx context.Context, c *vnet.NodeAgentClient, target netip.Addr) (*ipnstate.PingResult, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil) n := 0
if err != nil { var res *ipnstate.PingResult
return nil, err anyPong := false
} for {
res, err := c.Do(req) n++
if err != nil { pr, err := c.PingWithOpts(ctx, target, tailcfg.PingDisco, tailscale.PingOpts{})
return nil, err if err != nil {
} if anyPong {
defer res.Body.Close() return res, nil
all, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("ReadAll: %w", err)
}
var st ipnstate.Status
if err := json.Unmarshal(all, &st); err != nil {
return nil, fmt.Errorf("JSON marshal error: %v; body was %q", err, all)
}
return &st, nil
}
type routeType string
const (
routeDirect routeType = "direct"
routeDERP routeType = "derp"
routeLAN routeType = "lan"
)
func ping(ctx context.Context, c *http.Client, target string) (routeType, error) {
req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil)
if err != nil {
return "", err
}
res, err := c.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
if res.StatusCode != 200 {
return "", fmt.Errorf("unexpected status code %v", res.Status)
}
all, _ := io.ReadAll(res.Body)
var route routeType
for _, line := range strings.Split(string(all), "\n") {
if strings.Contains(line, " via DERP") {
route = routeDERP
continue
}
// pong from foo (100.82.3.4) via ADDR:PORT in 69ms
if _, rest, ok := strings.Cut(line, " via "); ok {
ipPorStr, _, _ := strings.Cut(rest, " in ")
ipPort, err := netip.ParseAddrPort(ipPorStr)
if err == nil {
if ipPort.Addr().IsPrivate() {
route = routeLAN
} else {
route = routeDirect
}
continue
} }
return nil, err
} }
if pr.Err != "" {
return nil, errors.New(pr.Err)
}
if pr.DERPRegionID == 0 {
return pr, nil
}
time.Sleep(time.Second)
res = pr
} }
if route == "" {
return routeType(all), nil
}
return route, nil
} }
func up(ctx context.Context, c *http.Client) error { func up(ctx context.Context, c *vnet.NodeAgentClient) error {
req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil) req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil)
if err != nil { if err != nil {
return err return err
} }
res, err := c.Do(req) res, err := c.HTTPClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
@ -269,6 +225,7 @@ func TestEasyEasy(t *testing.T) {
} }
func TestEasyHard(t *testing.T) { func TestEasyHard(t *testing.T) {
t.Skip()
nt := newNatTest(t) nt := newNatTest(t)
nt.runTest(easy, hard) nt.runTest(easy, hard)
} }

View File

@ -46,6 +46,7 @@
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
"tailscale.com/client/tailscale"
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/derp/derphttp" "tailscale.com/derp/derphttp"
"tailscale.com/net/netutil" "tailscale.com/net/netutil"
@ -279,7 +280,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
bs := bufio.NewScanner(tc) bs := bufio.NewScanner(tc)
for bs.Scan() { for bs.Scan() {
line := bs.Text() line := bs.Text()
log.Printf("LOG from guest: %s", line) log.Printf("LOG from guest %v: %s", clientRemoteIP, line)
} }
}() }()
return return
@ -1356,6 +1357,11 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) {
return nil, false return nil, false
} }
type NodeAgentClient struct {
*tailscale.LocalClient
HTTPClient *http.Client
}
func (s *Server) NodeAgentDialer(n *Node) DialFunc { func (s *Server) NodeAgentDialer(n *Node) DialFunc {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -1374,26 +1380,16 @@ func (s *Server) NodeAgentDialer(n *Node) DialFunc {
return d return d
} }
func (s *Server) NodeAgentRoundTripper(n *Node) http.RoundTripper { func (s *Server) NodeAgentClient(n *Node) *NodeAgentClient {
return &http.Transport{ d := s.NodeAgentDialer(n)
DialContext: s.NodeAgentDialer(n), return &NodeAgentClient{
LocalClient: &tailscale.LocalClient{
Dial: d,
},
HTTPClient: &http.Client{
Transport: &http.Transport{
DialContext: d,
},
},
} }
} }
func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) {
rt := s.NodeAgentRoundTripper(n)
req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil)
if err != nil {
return nil, err
}
res, err := rt.RoundTrip(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != 200 {
body, _ := io.ReadAll(io.LimitReader(res.Body, 1<<20))
return nil, fmt.Errorf("status: %v, %s, %v", res.Status, body, res.Header)
}
return io.ReadAll(res.Body)
}