mirror of
https://github.com/tailscale/tailscale.git
synced 2024-12-02 06:25:37 +00:00
natlab: add NodeAgentClient
This adds a new NodeAgentClient type that can be used to invoke the LocalAPI using the LocalClient instead of handcrafted URLs. However, there are certain cases where it does make sense for the node agent to provide more functionality than whats possible with just the LocalClient, as such it also exposes a http.Client to make requests directly. Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
19a7b2143b
commit
2f00bead5a
@ -65,6 +65,8 @@ type localClientRoundTripper struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (rt localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
req = req.Clone(req.Context())
|
||||||
|
req.RequestURI = ""
|
||||||
return rt.lc.DoLocalRequest(req)
|
return rt.lc.DoLocalRequest(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,12 +68,12 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.WriteStartingBanner(os.Stdout)
|
s.WriteStartingBanner(os.Stdout)
|
||||||
|
nc := s.NodeAgentClient(node1)
|
||||||
go func() {
|
go func() {
|
||||||
getStatus := func() {
|
getStatus := func() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
st, err := s.NodeStatus(ctx, node1)
|
st, err := nc.Status(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("NodeStatus: %v", err)
|
log.Printf("NodeStatus: %v", err)
|
||||||
return
|
return
|
||||||
|
@ -2,23 +2,23 @@
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
"tailscale.com/client/tailscale"
|
||||||
"tailscale.com/ipn/ipnstate"
|
"tailscale.com/ipn/ipnstate"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/tstest/natlab/vnet"
|
"tailscale.com/tstest/natlab/vnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ func newNatTest(tb testing.TB) *natTest {
|
|||||||
nt := &natTest{
|
nt := &natTest{
|
||||||
tb: tb,
|
tb: tb,
|
||||||
tempDir: tb.TempDir(),
|
tempDir: tb.TempDir(),
|
||||||
base: "/Users/bradfitz/src/tailscale.com/gokrazy/tsapp.qcow2",
|
base: "/Users/maisem/dev/tailscale.com/gokrazy/tsapp.qcow2",
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := os.Stat(nt.base); err != nil {
|
if _, err := os.Stat(nt.base); err != nil {
|
||||||
@ -113,7 +113,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) {
|
|||||||
"-M", "microvm,isa-serial=off",
|
"-M", "microvm,isa-serial=off",
|
||||||
"-m", "1G",
|
"-m", "1G",
|
||||||
"-nodefaults", "-no-user-config", "-nographic",
|
"-nodefaults", "-no-user-config", "-nographic",
|
||||||
"-kernel", "/Users/bradfitz/src/github.com/tailscale/gokrazy-kernel/vmlinuz",
|
"-kernel", "/Users/maisem/dev/github.com/tailscale/gokrazy-kernel/vmlinuz",
|
||||||
"-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1",
|
"-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-dd02023b0001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet tailscale-tta=1",
|
||||||
"-drive", "id=blk0,file="+disk+",format=qcow2",
|
"-drive", "id=blk0,file="+disk+",format=qcow2",
|
||||||
"-device", "virtio-blk-device,drive=blk0",
|
"-device", "virtio-blk-device,drive=blk0",
|
||||||
@ -139,15 +139,16 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
c1 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[0])}
|
lc1 := nt.vnet.NodeAgentClient(nodes[0])
|
||||||
c2 := &http.Client{Transport: nt.vnet.NodeAgentRoundTripper(nodes[1])}
|
lc2 := nt.vnet.NodeAgentClient(nodes[1])
|
||||||
|
clients := []*vnet.NodeAgentClient{lc1, lc2}
|
||||||
|
|
||||||
var eg errgroup.Group
|
var eg errgroup.Group
|
||||||
var sts [2]*ipnstate.Status
|
var sts [2]*ipnstate.Status
|
||||||
for i, c := range []*http.Client{c1, c2} {
|
for i, c := range clients {
|
||||||
i, c := i, c
|
i, c := i, c
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
st, err := status(ctx, c)
|
st, err := c.Status(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("node%d status: %w", i, err)
|
return fmt.Errorf("node%d status: %w", i, err)
|
||||||
}
|
}
|
||||||
@ -156,7 +157,7 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) {
|
|||||||
return fmt.Errorf("node%d up: %w", i, err)
|
return fmt.Errorf("node%d up: %w", i, err)
|
||||||
}
|
}
|
||||||
t.Logf("node%d up!", i)
|
t.Logf("node%d up!", i)
|
||||||
st, err = status(ctx, c)
|
st, err = c.Status(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("node%d status: %w", i, err)
|
return fmt.Errorf("node%d status: %w", i, err)
|
||||||
}
|
}
|
||||||
@ -173,85 +174,40 @@ func (nt *natTest) runTest(node1, node2 addNodeFunc) {
|
|||||||
t.Fatalf("initial setup: %v", err)
|
t.Fatalf("initial setup: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
route, err := ping(ctx, c1, sts[1].Self.TailscaleIPs[0].String())
|
route, err := ping(ctx, lc1, sts[1].Self.TailscaleIPs[0])
|
||||||
t.Logf("ping route: %v, %v", route, err)
|
t.Logf("ping route: %v, %v", route, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func status(ctx context.Context, c *http.Client) (*ipnstate.Status, error) {
|
func ping(ctx context.Context, c *vnet.NodeAgentClient, target netip.Addr) (*ipnstate.PingResult, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/status", nil)
|
n := 0
|
||||||
|
var res *ipnstate.PingResult
|
||||||
|
anyPong := false
|
||||||
|
for {
|
||||||
|
n++
|
||||||
|
pr, err := c.PingWithOpts(ctx, target, tailcfg.PingDisco, tailscale.PingOpts{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if anyPong {
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
res, err := c.Do(req)
|
if pr.Err != "" {
|
||||||
if err != nil {
|
return nil, errors.New(pr.Err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
if pr.DERPRegionID == 0 {
|
||||||
all, err := io.ReadAll(res.Body)
|
return pr, nil
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("ReadAll: %w", err)
|
|
||||||
}
|
}
|
||||||
var st ipnstate.Status
|
time.Sleep(time.Second)
|
||||||
if err := json.Unmarshal(all, &st); err != nil {
|
res = pr
|
||||||
return nil, fmt.Errorf("JSON marshal error: %v; body was %q", err, all)
|
|
||||||
}
|
}
|
||||||
return &st, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type routeType string
|
func up(ctx context.Context, c *vnet.NodeAgentClient) error {
|
||||||
|
|
||||||
const (
|
|
||||||
routeDirect routeType = "direct"
|
|
||||||
routeDERP routeType = "derp"
|
|
||||||
routeLAN routeType = "lan"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ping(ctx context.Context, c *http.Client, target string) (routeType, error) {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", "http://unused/ping?target="+url.QueryEscape(target), nil)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
res, err := c.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
return "", fmt.Errorf("unexpected status code %v", res.Status)
|
|
||||||
}
|
|
||||||
all, _ := io.ReadAll(res.Body)
|
|
||||||
var route routeType
|
|
||||||
for _, line := range strings.Split(string(all), "\n") {
|
|
||||||
if strings.Contains(line, " via DERP") {
|
|
||||||
route = routeDERP
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// pong from foo (100.82.3.4) via ADDR:PORT in 69ms
|
|
||||||
if _, rest, ok := strings.Cut(line, " via "); ok {
|
|
||||||
ipPorStr, _, _ := strings.Cut(rest, " in ")
|
|
||||||
ipPort, err := netip.ParseAddrPort(ipPorStr)
|
|
||||||
if err == nil {
|
|
||||||
if ipPort.Addr().IsPrivate() {
|
|
||||||
route = routeLAN
|
|
||||||
} else {
|
|
||||||
route = routeDirect
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if route == "" {
|
|
||||||
return routeType(all), nil
|
|
||||||
}
|
|
||||||
return route, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func up(ctx context.Context, c *http.Client) error {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res, err := c.Do(req)
|
res, err := c.HTTPClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -269,6 +225,7 @@ func TestEasyEasy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEasyHard(t *testing.T) {
|
func TestEasyHard(t *testing.T) {
|
||||||
|
t.Skip()
|
||||||
nt := newNatTest(t)
|
nt := newNatTest(t)
|
||||||
nt.runTest(easy, hard)
|
nt.runTest(easy, hard)
|
||||||
}
|
}
|
||||||
|
@ -46,6 +46,7 @@
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
"tailscale.com/client/tailscale"
|
||||||
"tailscale.com/derp"
|
"tailscale.com/derp"
|
||||||
"tailscale.com/derp/derphttp"
|
"tailscale.com/derp/derphttp"
|
||||||
"tailscale.com/net/netutil"
|
"tailscale.com/net/netutil"
|
||||||
@ -279,7 +280,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
|
|||||||
bs := bufio.NewScanner(tc)
|
bs := bufio.NewScanner(tc)
|
||||||
for bs.Scan() {
|
for bs.Scan() {
|
||||||
line := bs.Text()
|
line := bs.Text()
|
||||||
log.Printf("LOG from guest: %s", line)
|
log.Printf("LOG from guest %v: %s", clientRemoteIP, line)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return
|
return
|
||||||
@ -1356,6 +1357,11 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NodeAgentClient struct {
|
||||||
|
*tailscale.LocalClient
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) NodeAgentDialer(n *Node) DialFunc {
|
func (s *Server) NodeAgentDialer(n *Node) DialFunc {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@ -1374,26 +1380,16 @@ func (s *Server) NodeAgentDialer(n *Node) DialFunc {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NodeAgentRoundTripper(n *Node) http.RoundTripper {
|
func (s *Server) NodeAgentClient(n *Node) *NodeAgentClient {
|
||||||
return &http.Transport{
|
d := s.NodeAgentDialer(n)
|
||||||
DialContext: s.NodeAgentDialer(n),
|
return &NodeAgentClient{
|
||||||
|
LocalClient: &tailscale.LocalClient{
|
||||||
|
Dial: d,
|
||||||
|
},
|
||||||
|
HTTPClient: &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: d,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) {
|
|
||||||
rt := s.NodeAgentRoundTripper(n)
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
res, err := rt.RoundTrip(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
body, _ := io.ReadAll(io.LimitReader(res.Body, 1<<20))
|
|
||||||
return nil, fmt.Errorf("status: %v, %s, %v", res.Status, body, res.Header)
|
|
||||||
}
|
|
||||||
return io.ReadAll(res.Body)
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user