diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index f810d1b4f..85c991ad8 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -130,11 +130,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/google/gnostic-models/jsonschema from github.com/google/gnostic-models/compiler github.com/google/gnostic-models/openapiv2 from k8s.io/client-go/discovery+ github.com/google/gnostic-models/openapiv3 from k8s.io/kube-openapi/pkg/handler3+ - 💣 github.com/google/go-cmp/cmp from k8s.io/apimachinery/pkg/util/diff+ - github.com/google/go-cmp/cmp/internal/diff from github.com/google/go-cmp/cmp - github.com/google/go-cmp/cmp/internal/flags from github.com/google/go-cmp/cmp+ - github.com/google/go-cmp/cmp/internal/function from github.com/google/go-cmp/cmp - 💣 github.com/google/go-cmp/cmp/internal/value from github.com/google/go-cmp/cmp + github.com/google/go-cmp/cmp from k8s.io/apimachinery/pkg/util/diff+ github.com/google/gofuzz from k8s.io/apimachinery/pkg/apis/meta/v1+ github.com/google/gofuzz/bytesource from github.com/google/gofuzz L github.com/google/nftables from tailscale.com/util/linuxfw @@ -1195,7 +1191,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ math from archive/tar+ math/big from crypto/dsa+ math/bits from compress/flate+ - math/rand from github.com/google/go-cmp/cmp+ + math/rand from github.com/fxamacker/cbor/v2+ math/rand/v2 from tailscale.com/derp+ mime from github.com/prometheus/common/expfmt+ mime/multipart from github.com/go-openapi/swag+ diff --git a/go.mod b/go.mod index 3d7514158..531ded6da 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,8 @@ require ( github.com/coder/websocket v1.8.12 github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf + github.com/creachadair/command v0.1.24 + github.com/creachadair/flax v0.0.5 github.com/creachadair/taskgroup v0.13.2 github.com/creack/pty v1.1.23 github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa @@ -42,7 +44,7 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da github.com/golang/snappy v0.0.4 github.com/golangci/golangci-lint v1.57.1 - github.com/google/go-cmp v0.6.0 + github.com/google/go-cmp v0.7.0 github.com/google/go-containerregistry v0.20.2 github.com/google/go-tpm v0.9.4 github.com/google/gopacket v1.1.19 @@ -123,6 +125,7 @@ require ( sigs.k8s.io/controller-tools v0.17.0 sigs.k8s.io/yaml v1.4.0 software.sslmate.com/src/go-pkcs12 v0.4.0 + tailscale.com/client/tailscale/v2 v2.0.0-20250729171440-3f3f51970e08 ) require ( diff --git a/go.sum b/go.sum index 995b93010..756a9e4ba 100644 --- a/go.sum +++ b/go.sum @@ -240,8 +240,12 @@ github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8 github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/creachadair/mds v0.17.1 h1:lXQbTGKmb3nE3aK6OEp29L1gCx6B5ynzlQ6c1KOBurc= -github.com/creachadair/mds v0.17.1/go.mod h1:4b//mUiL8YldH6TImXjmW45myzTLNS1LLjOmrk888eg= +github.com/creachadair/command v0.1.24 h1:YdoK1t3swYioFq9woVJ4QtCIIPB0DLjVH38HuHk61R4= +github.com/creachadair/command v0.1.24/go.mod h1:+18hYRwZNDE9rMbMgy1P1gfgmU5bGMd7zFRBC0ARk+Y= 
+github.com/creachadair/flax v0.0.5 h1:zt+CRuXQASxwQ68e9GHAOnEgAU29nF0zYMHOCrL5wzE= +github.com/creachadair/flax v0.0.5/go.mod h1:F1PML0JZLXSNDMNiRGK2yjm5f+L9QCHchyHBldFymj8= +github.com/creachadair/mds v0.25.2 h1:xc0S0AfDq5GX9KUR5sLvi5XjA61/P6S5e0xFs1vA18Q= +github.com/creachadair/mds v0.25.2/go.mod h1:+s4CFteFRj4eq2KcGHW8Wei3u9NyzSPzNV32EvjyK/Q= github.com/creachadair/taskgroup v0.13.2 h1:3KyqakBuFsm3KkXi/9XIb0QcA8tEzLHLgaoidf0MdVc= github.com/creachadair/taskgroup v0.13.2/go.mod h1:i3V1Zx7H8RjwljUEeUWYT30Lmb9poewSb2XI1yTwD0g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -484,8 +488,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-containerregistry v0.20.2 h1:B1wPJ1SN/S7pB+ZAimcciVD+r+yV/l/DSArMxlbwseo= github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= github.com/google/go-tpm v0.9.4 h1:awZRf9FwOeTunQmHoDYSHJps3ie6f1UlhS1fOdPEt1I= @@ -1548,3 +1552,5 @@ sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= +tailscale.com/client/tailscale/v2 v2.0.0-20250729171440-3f3f51970e08 h1:UNUkHzyOasjkoM53eKt28bCazm8NPfusEBfnRchCNYY= +tailscale.com/client/tailscale/v2 v2.0.0-20250729171440-3f3f51970e08/go.mod h1:4akEJPbysqHWAP+t7CZLQ5ZH8/vZWeH6+Hv+fEJUMp0= diff --git a/util/chaos/README.md b/util/chaos/README.md new file mode 100644 index 000000000..8678daef3 --- /dev/null +++ b/util/chaos/README.md @@ -0,0 +1,26 @@ +# chaos + +Chaos is a CLI framework that aims to make it easy to implement variants of a Chaos tailnet scenario, where a large number of Tailscale nodes join a tailnet and then perform some actions. + +It is currently under development, so the interface is expected to change, which means this readme can be out of date. However here are some good starting points: + +- `chaos.go` is the main entry point setting up the sub-command structure and has some helper code for API interaction. + - When adding a new scenario, you will need to add a new sub-command here. +- `scenario.go` contains the structure of a scenario, it defines the steps and how they are ran. +- `node.go` contains two different implementations of a "node" or Tailscale client: + - `NodeDirect` is a lightweight implementation that sets up a basic Direct client and minimal map logic, but does the full authentication flow. + - `NodeTSNet` is a full Tailscale client that uses the tsnet package to set up a full userspace client. 
+- `scenario-join-n-nodes.go` implements the original chaos tailnet scenario, where N nodes join a tailnet, and also serves as a nice example for how to create a scenario with different configurable variables via flags. + + +### Remove nodes from tailnet + +A helper to clean out all the nodes in the tailnet can be ran as follow: + +```bash + go run ./cmd/chaos \ + --apikey \ + --tailnet \ + --login-server http://127.0.0.1:31544 \ + remove-all-nodes +``` diff --git a/util/chaos/chaos.go b/util/chaos/chaos.go new file mode 100644 index 000000000..d51036e5f --- /dev/null +++ b/util/chaos/chaos.go @@ -0,0 +1,313 @@ +// chaos is a command-line tool and "framework" for generating different +// types of loads based on defined scenarios to a Tailscale control server. +// It can be used to test the control server's performance and resilience. +// +// Scenarios are implemented as subcommands and each can register their own +// flags and options allowing them to be modified at runtime. +package chaos + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "runtime/debug" + "strings" + "time" + + "github.com/creachadair/command" + "github.com/creachadair/flax" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "tailscale.com/client/tailscale/v2" + "tailscale.com/safeweb" + "tailscale.com/tsweb" + "tailscale.com/util/prompt" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" +) + +var baseArgs struct { + LoginServer string `flag:"login-server,Address of the tailcontrol server"` + Tailnet string `flag:"tailnet,default=example.com,TailnetSID of the test tailnet"` + AuthKey string `flag:"authkey,AuthKey for tailnet in tailcontrol"` + ApiKey string `flag:"apikey,API Key for tailcontrol"` + DebugServer string `flag:"debug-server,ip:port for a debug webserver"` + RemoveAll bool `flag:"remove-all,Remove all nodes from the tailnet before the scenario starts (if implemented by the scenario), must be passed with --force"` + Force bool `flag:"force,Force the operation without checks"` + FullTagLabels bool `flag:"full-tag-labels,Use full tag values in metric labels, instead of truncating numeric suffixes"` + NetmapTracker bool `flag:"netmap-tracker,default=true,Enable netmap latency tracking"` + TailcontrolArgs string `flag:"tailcontrol-args,default=,Args and flags passed to tailcontrol"` +} + +type NewControlFunc func(loginServer, tailnet, apikey string) (ControlServer, error) + +// NewControl is a function that creates a new ControlServer instance. +// It and allow for different implementations of the ControlServer interface +// to be used. +var NewControl NewControlFunc = NewTailControl + +// NewChaosCommandEnv creates a new command environment for the chaos tool. +// It is the main entry point for the command-line interface, and where scenarios +// are registered as subcommands. +// It should be called in main, and any alternative implementations of the ControlServer +// should be registered before calling this function by overriding NewControl. 
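+//
+// A minimal wiring from a main package could look like the sketch below
+// (this assumes the cmd/chaos binary referenced in the README, and that
+// command.RunOrFail is the runner helper of the creachadair/command
+// module imported above):
+//
+//	func main() {
+//		command.RunOrFail(chaos.NewChaosCommandEnv(context.Background()), os.Args[1:])
+//	}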
+func NewChaosCommandEnv(ctx context.Context) *command.Env { + root := command.C{ + Name: filepath.Base(os.Args[0]), + Usage: "command [flags] ...\nhelp [command]", + Help: `A command-line tool for testing load against a tailcontrol server`, + + Commands: []*command.C{ + // Scenarios are registered as additional subcommands + // and be invoked from the command line with specific flags + // as the user sees fit. + joinNNodesCmd, + ciChurnCmd, + + command.HelpCommand(nil), + command.VersionCommand(), + { + Name: "remove-all-nodes", + Help: `Removes all nodes currently present in the tailnet, use with caution.`, + Run: func(env *command.Env) error { + tc, err := NewControl(baseArgs.LoginServer, baseArgs.Tailnet, baseArgs.ApiKey) + if err != nil { + return err + } + + c := NewChaos(tc) + + return c.RemoveAllNodes(env) + }, + }, + }, + SetFlags: command.Flags(flax.MustBind, &baseArgs), + Init: func(env *command.Env) error { + if baseArgs.DebugServer != "" { + log.Printf("Starting debug server on %s", baseArgs.DebugServer) + go func() { + mux := http.NewServeMux() + tsweb.Debugger(mux) + httpServer, err := safeweb.NewServer(safeweb.Config{BrowserMux: mux}) + if err != nil { + log.Fatalf("safeweb.NewServer: %v", err) + } + ln, err := net.Listen("tcp", baseArgs.DebugServer) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + defer ln.Close() + + if err := httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("http: %v", err) + } + }() + } + + runLabels := map[string]string{ + "start": time.Now().Format(time.RFC3339), + "chaos_args": strings.Join(os.Args, " "), + "tailcontrol_args": baseArgs.TailcontrolArgs, + } + if build, ok := debug.ReadBuildInfo(); ok { + for _, setting := range build.Settings { + runLabels[settingLabel(setting.Key)] = setting.Value + } + } + promauto.NewGauge(prometheus.GaugeOpts{ + Name: "chaos_run", + Help: "details about this chaos run", + ConstLabels: runLabels, + }).Set(1) + return nil + }, + } + + return root.NewEnv(nil).SetContext(ctx).MergeFlags(true) +} + +func settingLabel(key string) string { + return "build_" + strings.Trim(strings.Map(func(r rune) rune { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { + return r + } + return '_' + }, key), "_") +} + +// Chaos is the main structure for the chaos tool. +type Chaos struct { + Control ControlServer +} + +// ControlServer defines the interface for interacting with a Tailscale control server. +type ControlServer interface { + Tailnet() string + BaseURL() string + SetACL(ctx context.Context, pol tailscale.ACL) (*tailscale.ACL, error) + ListDevices(ctx context.Context) ([]tailscale.Device, error) + RemoveDevice(ctx context.Context, nodeID string) error + CreatAuthKey(ctx context.Context, ephemeral bool, tags []string) (string, error) +} + +// TailControl is a concrete implementation of the ControlServer interface +// that uses the Tailscale API client to interact with a Tailscale control server. +type TailControl struct { + c *tailscale.Client +} + +// NewTailControl creates a new TailControl instance. 
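+// It returns an error if tailnet or apikey is empty, and parses loginServer
+// (with any trailing slash removed) as the base URL of the API client.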
+func NewTailControl(loginServer, tailnet, apikey string) (ControlServer, error) { + c := &tailscale.Client{ + Tailnet: tailnet, + APIKey: apikey, + } + c.UserAgent = "tailscale-chaos" + var err error + c.BaseURL, err = url.Parse(strings.TrimSuffix(loginServer, "/")) + if err != nil { + return nil, fmt.Errorf("parse url: %w", err) + } + + if tailnet == "" { + return nil, errors.New("tailnet is required for API client") + } + + if apikey == "" { + return nil, errors.New("apikey is required for API client") + } + + return &TailControl{ + c: c, + }, nil +} + +// SetACL sets the ACL for the tailnet. +func (tc *TailControl) SetACL(ctx context.Context, pol tailscale.ACL) (*tailscale.ACL, error) { + return tc.c.PolicyFile().SetAndGet(ctx, pol, pol.ETag) +} + +// Tailnet returns the tailnet domain. +func (tc *TailControl) Tailnet() string { + return tc.c.Tailnet +} + +// BaseURL returns the base URL of the Tailscale control server. +func (tc *TailControl) BaseURL() string { + return tc.c.BaseURL.String() +} + +// ListDevices lists all devices in the tailnet. +func (tc *TailControl) ListDevices(ctx context.Context) ([]tailscale.Device, error) { + return tc.c.Devices().List(ctx) +} + +// RemoveDevice removes a device from the tailnet by its node ID. +func (tc *TailControl) RemoveDevice(ctx context.Context, nodeID string) error { + return tc.c.Devices().Delete(ctx, nodeID) +} + +// CreatAuthKey creates a new Tailscale auth key with the specified options. +func (tc *TailControl) CreatAuthKey(ctx context.Context, eph bool, tags []string) (string, error) { + var req tailscale.CreateKeyRequest + req.Capabilities.Devices.Create.Preauthorized = true + req.Capabilities.Devices.Create.Reusable = true + req.Capabilities.Devices.Create.Tags = tags + req.Capabilities.Devices.Create.Ephemeral = eph + key, err := tc.c.Keys().Create(ctx, req) + if err != nil { + return "", err + } + return key.Key, err +} + +// NewChaos creates a new Chaos instance with the provided ControlServer. +func NewChaos(control ControlServer) *Chaos { + return &Chaos{ + Control: control, + } +} + +// SetACL sets the ACL for the tailnet. +func (c *Chaos) SetACL(ctx context.Context, pol tailscale.ACL) (*tailscale.ACL, error) { + return c.Control.SetACL(ctx, pol) +} + +const defaultACLs = ` +// Example/default ACLs for unrestricted connections. +{ + // Define grants that govern access for users, groups, autogroups, tags, + // Tailscale IP addresses, and subnet ranges. + "grants": [ + // Allow all connections. + // Comment this section out if you want to define specific restrictions. + {"src": ["*"], "dst": ["*"], "ip": ["*"]}, + ], + // Define users and devices that can use Tailscale SSH. + "ssh": [ + // Allow all users to SSH into their own devices in check mode. + // Comment this section out if you want to define specific restrictions. + { + "action": "check", + "src": ["autogroup:member"], + "dst": ["autogroup:self"], + "users": ["autogroup:nonroot", "root"], + }, + ], +} +` + +// ResetACL resets the ACL for the tailnet to the default policy. +func (c *Chaos) ResetACL(ctx context.Context) error { + var pol tailscale.ACL + if err := json.Unmarshal([]byte(defaultACLs), &pol); err != nil { + return err + } + + if _, err := c.Control.SetACL(ctx, pol); err != nil { + return err + } + + return nil +} + +// RemoveAllNodes removes all nodes from the tailnet. +// It prompts the user for confirmation unless the --force flag is set. 
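+// If --force is set the prompt is skipped; if the prompt is declined, the
+// call logs a message and returns nil without removing anything.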
+func (c *Chaos) RemoveAllNodes(env *command.Env) error { + if baseArgs.Force { + log.Printf("Force flag passed, proceeding with removal of all nodes") + } else if !prompt.YesNo(fmt.Sprintf("Remove all nodes in tailnet %q on tailcontrol %q?", c.Control.Tailnet(), c.Control.BaseURL())) { + log.Printf("removal of all nodes requested, but not confirmed: aborting removal of all nodes") + return nil + } + + if err := c.removeAllNodes(env.Context()); err != nil { + return err + } + + return nil +} + +func (c *Chaos) removeAllNodes(ctx context.Context) error { + devs, err := c.Control.ListDevices(ctx) + if err != nil { + return fmt.Errorf("getting devices: %w", err) + } + + for _, dev := range devs { + log.Printf("Deleting device %q (%s)", dev.Name, dev.NodeID) + if err := c.Control.RemoveDevice(ctx, dev.NodeID); err != nil { + return fmt.Errorf("deleting device %q (%s): %w", dev.Name, dev.NodeID, err) + } + } + + return nil +} diff --git a/util/chaos/node.go b/util/chaos/node.go new file mode 100644 index 000000000..36dc6444c --- /dev/null +++ b/util/chaos/node.go @@ -0,0 +1,684 @@ +package chaos + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "sync" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + "tailscale.com/client/local" + cc "tailscale.com/control/controlclient" + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/ipn/store/mem" + "tailscale.com/net/netmon" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/util/multierr" +) + +var ( + nodeJoins = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "chaos_node_joins_total", + Help: "Incremented every time a node joins", + }, []string{"tags"}) + nodeDisconnects = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "chaos_node_disconnects_total", + Help: "Incremented when a node disconnects", + }, []string{"tags"}) + connectedNodes = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "chaos_online_nodes", + Help: "Number of online nodes", + }, []string{"tags"}) + joining = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "chaos_joining_nodes", + Help: "Number of nodes in the process of joining the network", + }, []string{"tags"}) + joinLatencies = promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "chaos_join_latency_seconds", + Help: "Time it took to join", + Buckets: prometheus.ExponentialBucketsRange(0.01, 30, 20), + }) +) + +// Node is the interface that represent a Tailscale client. +// It can be implemented to use helper functions in this library +// to make it easier to create reusable scenarios. +// +// The standard implementation of Node is based on tsnet, which +// implements a full userspace tailscale client, but it can be +// quite heavy on both memory and cpu usage, so there might be +// cases where you want to implement your own Client via the node +// interface. +type Node interface { + // Name returns the name of the underlying instance of Tailscale. + Name() string + + // WaitRunning blocks until the underlying Tailscale reports running. + WaitRunning(context.Context) error + + // Start starts Tailscale. + Start(context.Context) error + + // Close stops Tailscale. 
+ Close(context.Context) error + + // Status returns the current status of the Tailscale client. + Status(context.Context) (*ipnstate.Status, error) + + // Stats returns the node's connection stats (latency measurements, etc). + Stats() *NodeStats +} + +// NodeStats contains statistics about a Node's +// connection to the Tailscale network. +type NodeStats struct { + // Name of node + Name string + + // Durations spent logging in in milliseconds + LoginDur time.Duration + + // Duration spent to get the first netmap in milliseconds + FirstNetMapDur time.Duration + + // Number of peers in the first netmap + PeerCount int +} + +// NodeTSNet is a Node implementation based on tsnet, +// a userspace Tailscale client. +type NodeTSNet struct { + tagsLabel string + uuid uuid.UUID + dir string + + ts *tsnet.Server + lc *local.Client +} + +// NodeOpts describes configuration options for nodes +// to be used during a chaos scenario. All options might +// not have affect on all implementations of the tailscale +// clients implemented. +type NodeOpts struct { + loginServer string + authKey string + ephemeral bool + tags []string + logf logger.Logf + userLogf logger.Logf +} + +// NewNodeTSNet returns a Node based on tsnet, a userspace Tailscale +// client. +func NewNodeTSNet(ctx context.Context, opts NodeOpts) (Node, error) { + n := &NodeTSNet{ + uuid: uuid.New(), + ts: new(tsnet.Server), + tagsLabel: tagsMetricLabel(opts.tags, baseArgs.FullTagLabels), + } + + if opts.authKey == "" { + return nil, fmt.Errorf("AuthKey is required") + } + + var err error + if opts.ephemeral { + n.ts.Store = new(mem.Store) + } else { + n.dir, err = os.MkdirTemp("", "chaos-node"+n.uuid.String()) + if err != nil { + return nil, fmt.Errorf("failed to create temporary directory: %w", err) + } + n.ts.Dir = n.dir + } + + n.ts.Hostname = n.Name() + n.ts.ControlURL = opts.loginServer + n.ts.AuthKey = opts.authKey + n.ts.Ephemeral = opts.ephemeral + n.ts.Logf = opts.logf + n.ts.UserLogf = opts.userLogf + + lc, err := n.ts.LocalClient() + if err != nil { + return nil, err + } + n.lc = lc + + return n, nil +} + +// NewNodeTSNetAsync returns a Node based on tsnet, a userspace Tailscale +// client, the function returns a channel which will get a Node or nil +// when it is ready. +func NewNodeTSNetAsync(ctx context.Context, opts NodeOpts) <-chan Node { + ch := make(chan Node, 1) + go func() { + defer close(ch) + n, err := NewNodeTSNet(ctx, opts) + if err != nil { + log.Printf("failed to create node: %v", err) + return + } + ch <- n + }() + return ch +} + +// waitForNotification waits for a notification that satisfies fn. +func (n *NodeTSNet) waitForNotification(ctx context.Context, fn func(n *ipn.Notify) bool) error { + if n.lc == nil { + return fmt.Errorf("LocalClient is nil") + } + + watcher, err := n.lc.WatchIPNBus(ctx, ipn.NotifyInitialState) + if err != nil { + return err + } + + for { + n, err := watcher.Next() + if err != nil { + return fmt.Errorf("watching ipn: %w", err) + } + + if fn(&n) { + return nil + } + } +} + +// Name returns the name of the node. +func (n *NodeTSNet) Name() string { + return uuidToHostname(n.uuid) +} + +// waitRunning waits for the node to be in the Running state. 
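+// It watches the node's IPN bus and returns once a notification reports the
+// ipn.Running state, or an error if the watch fails first.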
+func (n *NodeTSNet) WaitRunning(ctx context.Context) error { + err := n.waitForNotification(ctx, func(n *ipn.Notify) bool { + return n.State != nil && *n.State == ipn.Running + }) + if err != nil { + return fmt.Errorf("node %s failed to get to a running state: %w", n.Name(), err) + } + return nil +} + +// Start starts the Tailscale client and waits for it to be in the Running state. +func (n *NodeTSNet) Start(_ context.Context) error { + joining.WithLabelValues(n.tagsLabel).Inc() + defer joining.WithLabelValues(n.tagsLabel).Dec() + start := time.Now() + + err := n.ts.Start() + if err == nil { + nodeJoins.WithLabelValues(n.tagsLabel).Inc() + connectedNodes.WithLabelValues(n.tagsLabel).Inc() + joinLatencies.Observe(time.Since(start).Seconds()) + } + return err +} + +// Close stops the Tailscale client and cleans up any resources used by it. +func (n *NodeTSNet) Close(_ context.Context) error { + defer nodeDisconnects.WithLabelValues(n.tagsLabel).Inc() + defer connectedNodes.WithLabelValues(n.tagsLabel).Dec() + defer func() { + if n.dir != "" { + if err := os.RemoveAll(n.dir); err != nil { + log.Printf("failed to remove temporary directory %q: %v", n.dir, err) + } + } + }() + return n.ts.Close() +} + +// Status returns the current status of the Tailscale client. +func (n *NodeTSNet) Status(ctx context.Context) (*ipnstate.Status, error) { + return n.lc.Status(ctx) +} + +// TODO(kradalby): Implement stats for tsnet +func (n *NodeTSNet) Stats() *NodeStats { + return &NodeStats{ + Name: n.Name(), + } +} + +// NodeMap is a collection of Nodes and helper functions +// that can be used to do common scenario tasks on a lot +// of nodes. +type NodeMap struct { + m map[string]Node + mu sync.Mutex +} + +func (nm *NodeMap) Len() int { + nm.mu.Lock() + defer nm.mu.Unlock() + return len(nm.m) +} + +// NewNodeMap creates a new NodeMap. +func NewNodeMap(nodeCount int) *NodeMap { + return &NodeMap{ + m: make(map[string]Node, nodeCount), + } +} + +type NewNodeFunc func(context.Context, NodeOpts) <-chan Node + +// NewNodeMapWithNodes creates N amount of Node instances and returns +// a new NodeMap. +func NewNodeMapWithNodes(ctx context.Context, newNode NewNodeFunc, nodeCount int, opts NodeOpts) (*NodeMap, error) { + nm := NewNodeMap(nodeCount) + + return nm, nm.AddNodes(ctx, newNode, nodeCount, opts) +} + +// AddNodes adds N amount of new nodes to the nodeMap. +func (nm *NodeMap) AddNodes(ctx context.Context, newNode NewNodeFunc, nodeCount int, opts NodeOpts) error { + var errg errgroup.Group + for range nodeCount { + ch := newNode(ctx, opts) + + errg.Go(func() error { + n, ok := <-ch + if ok { + if n == nil { + return fmt.Errorf("error creating node") + } + nm.mu.Lock() + nm.m[n.Name()] = n + nm.mu.Unlock() + } + return nil + }) + } + + return errg.Wait() +} + +// WaitForReady waits for all nodes in the nodeMap to enter +// a "Running" ready state. An error will return if any of the +// nodes failed to reach that state within the limits of the +// passed context. +func (nm *NodeMap) WaitForReady(ctx context.Context) error { + var errg errgroup.Group + for _, n := range nm.m { + n := n + errg.Go(func() error { + return n.WaitRunning(ctx) + }) + } + + return errg.Wait() +} + +// StartAll starts all nodes in the nodeMap. 
The concurrency limit +// restricts the number of nodes being started at the same time, +func (nm *NodeMap) StartAll(ctx context.Context, concurrency int) error { + errChan := make(chan error, nm.Len()) + var wg sync.WaitGroup + sem := semaphore.NewWeighted(int64(concurrency)) + + count := 0 + for _, n := range nm.m { + if err := sem.Acquire(ctx, 1); err != nil { + errChan <- fmt.Errorf("Failed to acquire semaphore: %v", err) + break + } + + count++ + wg.Add(1) + node := n + go func() { + defer sem.Release(1) + defer wg.Done() + if err := node.Start(ctx); err != nil { + errChan <- fmt.Errorf("starting node %q: %w", node.Name(), err) + } + }() + } + wg.Wait() + close(errChan) + + // Drain errors and combine + var errs []error + for err := range errChan { + errs = append(errs, err) + } + + return multierr.New(errs...) +} + +// SaveStatusToFile saves the stats of all nodes in the NodeMap +// to a JSON file at the specified path. +func (nm *NodeMap) SaveStatusToFile(path string) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("opening file: %w", err) + } + defer f.Close() + + nm.mu.Lock() + var stats []*NodeStats + for _, n := range nm.m { + stats = append(stats, n.Stats()) + } + nm.mu.Unlock() + + b, err := json.Marshal(stats) + if err != nil { + return fmt.Errorf("marshalling stats: %w", err) + } + + _, err = f.Write(b) + if err != nil { + return fmt.Errorf("writing to file: %w", err) + } + + return nil +} + +// CloseAll closes all running nodes. +func (nm *NodeMap) CloseAll(ctx context.Context) error { + errChan := make(chan error, nm.Len()) + var wg sync.WaitGroup + for _, n := range nm.m { + wg.Add(1) + + node := n + go func() { + defer wg.Done() + if err := node.Close(ctx); err != nil { + errChan <- fmt.Errorf("closing node %q: %w", node.Name(), err) + } + }() + } + wg.Wait() + close(errChan) + + // Drain errors and combine + var errs []error + for err := range errChan { + errs = append(errs, err) + } + + return multierr.New(errs...) +} + +// CloseAndDeleteAll closes down all running nodes +// and deletes the from the tailcontrol server if they exists. +func (nm *NodeMap) CloseAndDeleteAll(ctx context.Context, c *Chaos) error { + errChan := make(chan error, nm.Len()) + var wg sync.WaitGroup + for _, n := range nm.m { + wg.Add(1) + + node := n + go func() { + defer wg.Done() + status, err := node.Status(ctx) + if err != nil { + errChan <- fmt.Errorf("getting status: %w", err) + return + } + + err = node.Close(ctx) + if err != nil { + errChan <- fmt.Errorf("closing node %q: %w", node.Name(), err) + return + } + + log.Printf("Deleting device %q (%s)", node.Name(), status.Self.ID) + if err := c.Control.RemoveDevice(ctx, string(status.Self.ID)); err != nil { + errChan <- fmt.Errorf("deleting device %q (%s): %w", node.Name(), status.Self.ID, err) + } + }() + } + wg.Wait() + close(errChan) + + // Drain errors and combine + var errs []error + for err := range errChan { + errs = append(errs, err) + } + + return multierr.New(errs...) +} + +type NodeDirect struct { + *cc.Direct + ephemeral bool + loggedIn bool + nodeID tailcfg.NodeID + stableID tailcfg.StableNodeID + nmChan <-chan *netmap.NetworkMap + logf logger.Logf + + uuid uuid.UUID + tagsLabel string + stats NodeStats + tracker *netmapLatencyTracker +} + +// NewNodeDirect returns a Node based on cc.Direct, a tiny tailscale +// client made for direct connection and testing. 
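+// The resulting node performs only the authentication flow and netmap
+// polling, with no userspace network stack, which keeps per-node memory and
+// CPU cost low for large scenarios.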
+func NewNodeDirect(nOpts NodeOpts) (Node, error) { + node := &NodeDirect{ + uuid: uuid.New(), + ephemeral: nOpts.ephemeral, + loggedIn: false, + logf: nOpts.logf, + tagsLabel: tagsMetricLabel(nOpts.tags, baseArgs.FullTagLabels), + } + node.stats = NodeStats{Name: node.Name()} + if baseArgs.NetmapTracker { + node.tracker = newLatencyTracker() + } + + hi := &tailcfg.Hostinfo{ + Hostname: node.Name(), + + // Is required for the node to be able to connect to the tailcontrol server. + BackendLogID: "go-test-only", + FrontendLogID: "go-test-only", + } + opts := cc.Options{ + ServerURL: nOpts.loginServer, + AuthKey: nOpts.authKey, + Hostinfo: hi, + Dialer: tsdial.NewDialer(netmon.NewStatic()), + DiscoPublicKey: key.NewDisco().Public(), + HealthTracker: new(health.Tracker), + Logf: nOpts.logf, + } + if opts.GetMachinePrivateKey == nil { + opts.GetMachinePrivateKey = func() (key.MachinePrivate, error) { return key.NewMachine(), nil } + } + + var err error + node.Direct, err = cc.NewDirect(opts) + if err != nil { + return nil, fmt.Errorf("NewDirect: %w", err) + } + + return node, nil +} + +// NewNodeDirectAsync returns a Node based on cc.Direct, a tiny tailscale +// client made for direct connection and testing. +func NewNodeDirectAsync(_ context.Context, opts NodeOpts) <-chan Node { + ch := make(chan Node, 1) + go func() { + defer close(ch) + n, err := NewNodeDirect(opts) + if err != nil { + return + } + ch <- n + }() + return ch +} + +// Name returns the name of the node. +func (n *NodeDirect) Name() string { + return uuidToHostname(n.uuid) +} + +// WaitRunning blocks until the node is logged in and has received +// the first netmap update. +func (n *NodeDirect) WaitRunning(ctx context.Context) error { + for !n.loggedIn { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(500 * time.Millisecond): + } + } + + if n.nmChan == nil { + return fmt.Errorf("nmChan is nil, netmap channel not started") + } + + return nil +} + +// Start starts the node and waits for it to be logged in. +// When the node is logged in it will start listening for netmap updates. 
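+// It records the login duration, the time to the first netmap, and the peer
+// count of that first netmap in the node's Stats.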
+func (n *NodeDirect) Start(ctx context.Context) error { + joining.WithLabelValues(n.tagsLabel).Inc() + defer joining.WithLabelValues(n.tagsLabel).Dec() + + loginStart := time.Now() + loginFlag := cc.LoginDefault + if n.ephemeral { + loginFlag |= cc.LoginEphemeral + } + _, err := n.Direct.TryLogin(ctx, loginFlag) + if err != nil { + return fmt.Errorf("TryLogin: %w", err) + } + + if n.tracker != nil { + n.tracker.Start(n.uuid, n.tagsLabel) + } + + n.loggedIn = true + loginDone := time.Since(loginStart) + + firstNetMapStart := time.Now() + nm, nmChan, err := n.waitForNetmapUpdates(ctx) + if err != nil { + return fmt.Errorf("getting initial netmap: %w", err) + } + firstNetMapStartDone := time.Since(firstNetMapStart) + + n.nmChan = nmChan + n.stableID = nm.SelfNode.StableID() + n.nodeID = nm.SelfNode.ID() + n.stats.LoginDur = loginDone + n.stats.FirstNetMapDur = firstNetMapStartDone + n.stats.PeerCount = len(nm.Peers) + + log.Printf("node %q joined, login: %s, firstnm: %s, peercount: %d", n.Name(), loginDone.String(), firstNetMapStartDone.String(), len(nm.Peers)) + + nodeJoins.WithLabelValues(n.tagsLabel).Inc() + connectedNodes.WithLabelValues(n.tagsLabel).Inc() + joinLatencies.Observe(time.Since(loginStart).Seconds()) + + return err +} + +func (n *NodeDirect) Close(ctx context.Context) error { + defer nodeDisconnects.WithLabelValues(n.tagsLabel).Inc() + defer connectedNodes.WithLabelValues(n.tagsLabel).Dec() + err := n.Direct.TryLogout(ctx) + if err != nil { + return err + } + return n.Direct.Close() +} + +func (n *NodeDirect) Status(context.Context) (*ipnstate.Status, error) { + st := &ipnstate.Status{ + Self: &ipnstate.PeerStatus{ + ID: n.stableID, + }, + BackendState: ipn.Stopped.String(), + } + if n.loggedIn { + st.BackendState = ipn.Running.String() + } + + return st, nil +} + +func (n *NodeDirect) Stats() *NodeStats { + return &n.stats +} + +// NetmapUpdaterFunc implements controlclient.NetmapUpdater using a func. +type NetmapUpdaterFunc func(*netmap.NetworkMap) + +func (f NetmapUpdaterFunc) UpdateFullNetmap(nm *netmap.NetworkMap) { + f(nm) +} + +// WaitForNetmapUpdates starts a netmap poll in a new goroutine and returns the +// first netmap and a channel to listen on for future netmap updates. It also +// returns a channel to listen on for errors. The channels are closed after the +// netmap poll returns, and are automatically drained on test completion. +func (n *NodeDirect) waitForNetmapUpdates(ctx context.Context) (*netmap.NetworkMap, <-chan *netmap.NetworkMap, error) { + // buffered channel with netmaps. 50 is chosen arbitrarily. + nmChan := make(chan *netmap.NetworkMap, 50) + name := n.Name() + + go func() { + defer close(nmChan) + if n.tracker != nil { + defer n.tracker.Done(n.uuid) + } + + count := 0 + n.PollNetMap(ctx, NetmapUpdaterFunc(func(nm *netmap.NetworkMap) { + count++ + log.Printf("Received %q netmap update (#%d), self: %s, peercount %d", name, count, nm.SelfNode.Name(), len(nm.Peers)) + + // Only send the first netmap update, currently there is nothing + // draining these and only the first one is used to determine if + // the node is running. + // TODO(kradalby): Put them back on the channel when there is a use + // for them. 
+ if count == 1 { + nmChan <- nm + } + + if n.tracker != nil { + n.tracker.ProcessNetmap(n.uuid, nm) + } + })) + }() + nm, ok := <-nmChan + if !ok { + return nil, nil, fmt.Errorf("did not receive initial netmap") + } + return nm, nmChan, nil +} diff --git a/util/chaos/scenario-ci-churn.go b/util/chaos/scenario-ci-churn.go new file mode 100644 index 000000000..196aa2041 --- /dev/null +++ b/util/chaos/scenario-ci-churn.go @@ -0,0 +1,295 @@ +package chaos + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/creachadair/command" + "github.com/creachadair/flax" + xmaps "golang.org/x/exp/maps" + "golang.org/x/sync/semaphore" + "tailscale.com/client/tailscale/v2" +) + +var ciChurnArgs struct { + Verbose int `flag:"verbose,default=0,Print verbose output, 0: no, 1: yes, 2: very verbose"` + JoinTimeout time.Duration `flag:"join-timeout,default=30s,Timeout for a node joining the tailnet"` + JoinParallelism int `flag:"join-parallelism,default=150,Number of nodes to join in parallel"` + NodeType string `flag:"node-type,default=direct,Type of node to create, one of: direct (lightweight) or tsnet (full)"` +} + +var ciChurnCmd = &command.C{ + Name: "ci-churn", + Help: `Join a set of service nodes and a set of high Churn CI nodes to a tailscale network.`, + + Run: command.Adapt(runCIChurn), + SetFlags: command.Flags(flax.MustBind, &ciChurnArgs), +} + +func runCIChurn(env *command.Env) error { + tc, err := NewControl(baseArgs.LoginServer, baseArgs.Tailnet, baseArgs.ApiKey) + if err != nil { + return err + } + + chaos := NewChaos(tc) + + if baseArgs.ApiKey == "" { + return fmt.Errorf("--apikey is required") + } + + type taggedNodesSpec struct { + opts NodeOpts + count int + newFunc NewNodeFunc + } + + jobSpec := func(c int, eph bool) taggedNodesSpec { + o := NodeOpts{ + loginServer: baseArgs.LoginServer, + ephemeral: eph, + } + setVerboseOptionsFromFlag(&o, ciChurnArgs.Verbose) + return taggedNodesSpec{ + count: c, + opts: o, + newFunc: NewNodeDirectAsync, + } + } + + // 100 admins that can access anything + userSpec := map[string]taggedNodesSpec{ + "tag:user-admins": jobSpec(100, false), + } + // 100 groups of developers that can access their own services, + // 3 devices per group. + const numDevs = 100 + for i := range numDevs { + userSpec[fmt.Sprintf("tag:user-dev%d", i)] = jobSpec(3, false) + } + + commonTaggedSpec := map[string]taggedNodesSpec{} + // 100 common services that can be accessed from all CI jobs. + for i := range 100 { + commonTaggedSpec[fmt.Sprintf("tag:svc-common%d", i)] = jobSpec(3, false) + } + appTaggedSpec := map[string]taggedNodesSpec{} + // 300 app-specific services that can be accessed from app-specific CI jobs. + const numApps = 300 + for i := range numApps { + appTaggedSpec[fmt.Sprintf("tag:svc-app%d", i)] = jobSpec(3, false) + } + + ciSpec := map[string]taggedNodesSpec{ + // 4100 nodes in the common CI pool. + "tag:ci-common": jobSpec(4100, true), + } + // 300 app-specific CI services. + for i := range numApps { + ciSpec[fmt.Sprintf("tag:ci-app%d", i)] = jobSpec(3, true) + } + + s := Scenario{ + BeforeSteps: func() error { + if baseArgs.RemoveAll { + if err := chaos.RemoveAllNodes(env); err != nil { + return fmt.Errorf("removing all nodes: %w", err) + } + } + + // TODO: can make this read by CLI + o := []string{"insecure@example.com"} + allTags := append(append(append(xmaps.Keys(userSpec), xmaps.Keys(commonTaggedSpec)...), xmaps.Keys(ciSpec)...), xmaps.Keys(appTaggedSpec)...) 
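+			// Build a policy where admins can reach all nodes, every CI pool can
+			// reach the common services on port 80, and app-specific CI pools and
+			// developer groups can only reach their own app services.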
+ pol := tailscale.ACL{ + TagOwners: tagsToTagOwners(o, allTags), + ACLs: []tailscale.ACLEntry{ + { + // Admins can access everything. + Action: "accept", + Source: []string{"tag:user-admins"}, + Destination: []string{"*:22", "*:80", "*:443"}, + }, + { + // All CI jobs can access common tagged services. + Action: "accept", + Source: xmaps.Keys(ciSpec), + Destination: tagsToDst(xmaps.Keys(commonTaggedSpec), "80"), + }, + }, + } + for i := range numApps { + pol.ACLs = append(pol.ACLs, tailscale.ACLEntry{ + // App-specific CI jobs can access app-specific services. + Action: "accept", + Source: []string{fmt.Sprintf("tag:ci-app%d", i)}, + Destination: []string{fmt.Sprintf("tag:svc-app%d:80", i)}, + }) + } + for i := range numDevs { + pol.ACLs = append(pol.ACLs, tailscale.ACLEntry{ + // Developers can access their services + Action: "accept", + Source: []string{fmt.Sprintf("tag:user-dev%d", i)}, + Destination: []string{ + fmt.Sprintf("tag:svc-app%d:80", i), + fmt.Sprintf("tag:svc-app%d:80", 100+i), + fmt.Sprintf("tag:svc-app%d:80", 200+i), + }, + }) + } + ctx, cancel := context.WithTimeout(env.Context(), 10*time.Second) + defer cancel() + + if _, err := chaos.SetACL(ctx, pol); err != nil { + return err + } + + return nil + }, + Steps: []Step{ + {Run: func() error { + parallelismSem := semaphore.NewWeighted(int64(ciChurnArgs.JoinParallelism)) + var wg sync.WaitGroup + + // CI services. + ciTicker := newJitteredTicker(env.Context(), 3*time.Minute, time.Second/20) + for tag, spec := range ciSpec { + wg.Add(1) + go func() { + defer wg.Done() + NodeGroupSimulator(env.Context(), chaos, spec.opts, tag, spec.count, parallelismSem, ciTicker) + }() + } + + // Tagged services, churning every 30 min. + taggedTicker := newJitteredTicker(env.Context(), 3*time.Minute, 30*time.Minute) + for tag, spec := range commonTaggedSpec { + wg.Add(1) + go func() { + defer wg.Done() + NodeGroupSimulator(env.Context(), chaos, spec.opts, tag, spec.count, parallelismSem, taggedTicker) + }() + } + for tag, spec := range appTaggedSpec { + wg.Add(1) + go func() { + defer wg.Done() + NodeGroupSimulator(env.Context(), chaos, spec.opts, tag, spec.count, parallelismSem, taggedTicker) + }() + } + + // User nodes, churning every 1hr. + userTicker := newJitteredTicker(env.Context(), 3*time.Minute, 1*time.Hour) + for tag, spec := range userSpec { + wg.Add(1) + go func() { + defer wg.Done() + NodeGroupSimulator(env.Context(), chaos, spec.opts, tag, spec.count, parallelismSem, userTicker) + }() + } + + wg.Wait() + return nil + }}, + }, + TearDown: func() error { + log.Printf("Tearing down scenario") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := chaos.removeAllNodes(ctx); err != nil { + return fmt.Errorf("removing all nodes: %w", err) + } + return nil + }, + } + + err = s.Run() + if err != nil { + log.Printf("Error running scenario: %v", err) + return err + } + return nil +} + +// NodeGroupSimulator simulates a group of nodes joining and leaving a network. +// When a node joins the network, it will join with a tag and an authkey. +// The node will then leave the network after a random amount of time. +// A new node will join the network for a new random amount of time. 
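+// A single reusable, pre-authorized auth key tagged with tag is created up
+// front and shared by all nodes in the group. At most c nodes from the group
+// run at once (per-group semaphore), joins are additionally bounded by
+// parallelismSem, and each node lives until the next tick of the stop ticker
+// (or context cancellation) before being closed.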
+// TODO(kradalby): rename +func NodeGroupSimulator(ctx context.Context, chaos *Chaos, opts NodeOpts, tag string, c int, parallelismSem *semaphore.Weighted, stop jticker) { + sem := semaphore.NewWeighted(int64(c)) + + key, err := chaos.Control.CreatAuthKey(ctx, opts.ephemeral, []string{tag}) + if err != nil { + log.Printf("failed to create authkey: %s", err) + errCount.WithLabelValues("authkey").Inc() + return + } + opts.authKey = key + opts.tags = []string{tag} + + for { + if err := sem.Acquire(ctx, 1); err != nil { + log.Printf("failed to acquire semaphore: %v", err) + return + } + if err := parallelismSem.Acquire(ctx, 1); err != nil { + log.Printf("failed to acquire parallelism semaphore: %v", err) + return + } + + go func() { + defer sem.Release(1) + err := NewLimitedLifetimeNode(ctx, nodeFuncFromFlag(ciChurnArgs.NodeType), func() { + parallelismSem.Release(1) + }, opts, stop) + if err != nil { + log.Printf("error creating limited lifetime node: %v", err) + errCount.WithLabelValues("createnode").Inc() + } + }() + } +} + +// NewLimitedLifetimeNode creates a new node, starts it, waits for it to be running, +// and then closes it after the given lifetime. +// The node is created using the given NewNodeFunc. +// This function should be spawned in a go routine, it is closed by declaring the context Done. +func NewLimitedLifetimeNode(ctx context.Context, newFunc NewNodeFunc, loginDoneFunc func(), opts NodeOpts, stop jticker) error { + pending := newFunc(ctx, opts) + node, ok := <-pending + if !ok { + loginDoneFunc() + return fmt.Errorf("failed to create node") + } + + err := node.Start(ctx) + if err != nil { + loginDoneFunc() + return fmt.Errorf("failed to start node: %w", err) + } + + loginDoneFunc() + + err = node.WaitRunning(ctx) + if err != nil { + return fmt.Errorf("failed to wait for node to be running: %w", err) + } + + select { + case <-stop: + case <-ctx.Done(): + } + + closeCtx, close := context.WithTimeout(context.Background(), 30*time.Second) + defer close() + err = node.Close(closeCtx) + if err != nil { + return fmt.Errorf("failed to close node: %w", err) + } + + return nil +} diff --git a/util/chaos/scenario-join-n-nodes.go b/util/chaos/scenario-join-n-nodes.go new file mode 100644 index 000000000..856f309f2 --- /dev/null +++ b/util/chaos/scenario-join-n-nodes.go @@ -0,0 +1,138 @@ +package chaos + +import ( + "context" + "fmt" + "log" + "os" + "path" + "runtime/pprof" + "time" + + "github.com/creachadair/command" + "github.com/creachadair/flax" +) + +var joinNNodesArgs struct { + Count int `flag:"node-count,default=1,Number of nodes to join to the network"` + Verbose int `flag:"verbose,default=0,Print verbose output, 0: no, 1: yes, 2: very verbose"` + JoinTimeout time.Duration `flag:"join-timeout,default=30s,Timeout for a node joining the tailnet"` + JoinParallelism int `flag:"join-parallelism,default=50,Number of nodes to join in parallel"` + MemoryHeapProfilePath string `flag:"heap-pprof-path,Save a memory profile after the main step to this path"` + NodeType string `flag:"node-type,default=direct,Type of node to create, one of: direct (lightweight) or tsnet (full)"` + OutputDir string `flag:"output-dir,Directory to save output files"` +} + +var joinNNodesCmd = &command.C{ + Name: "join-n-nodes", + Help: `Join N nodes to a tailscale network.`, + + Run: command.Adapt(runJoinNNNodes), + SetFlags: command.Flags(flax.MustBind, &joinNNodesArgs), +} + +func runJoinNNNodes(env *command.Env) error { + tc, err := NewControl(baseArgs.LoginServer, baseArgs.Tailnet, baseArgs.ApiKey) + if 
err != nil { + return err + } + + chaos := NewChaos(tc) + + authKey := baseArgs.AuthKey + if authKey == "" { + if baseArgs.ApiKey == "" { + return fmt.Errorf("either --authkey or --apikey is required") + } + log.Printf("Auth key not provided; creating one...") + key, err := chaos.Control.CreatAuthKey(env.Context(), false, nil) + if err != nil { + return err + } + + authKey = key + } + opts := NodeOpts{ + loginServer: baseArgs.LoginServer, + authKey: authKey, + ephemeral: true, + } + + setVerboseOptionsFromFlag(&opts, joinNNodesArgs.Verbose) + + var nm *NodeMap + + s := Scenario{ + BeforeSteps: func() error { + if baseArgs.RemoveAll { + if err := chaos.RemoveAllNodes(env); err != nil { + return fmt.Errorf("removing all nodes: %w", err) + } + } + return nil + }, + Steps: []Step{ + { + Run: func() error { + var err error + log.Printf("Login server: %s, authkey: %s", opts.loginServer, opts.authKey) + log.Printf("Creating %d nodes", joinNNodesArgs.Count) + nm, err = NewNodeMapWithNodes(env.Context(), nodeFuncFromFlag(joinNNodesArgs.NodeType), joinNNodesArgs.Count, opts) + if err != nil { + return fmt.Errorf("creating nodes: %w", err) + } + + log.Printf("Joining %d nodes to the network", joinNNodesArgs.Count) + + if err := nm.StartAll(env.Context(), joinNNodesArgs.JoinParallelism); err != nil { + return fmt.Errorf("starting nodes: %w", err) + } + + ctx, cancel := context.WithTimeout(env.Context(), joinNNodesArgs.JoinTimeout) + defer cancel() + + ready := time.Now() + if err := nm.WaitForReady(ctx); err != nil { + return fmt.Errorf("waiting for ts-es to be ready: %w", err) + } + log.Printf("All nodes are ready in %s", time.Since(ready)) + + return nil + }, + AfterStep: func() error { + if joinNNodesArgs.MemoryHeapProfilePath != "" { + f, err := os.Create(joinNNodesArgs.MemoryHeapProfilePath) + if err != nil { + return err + } + pprof.WriteHeapProfile(f) + f.Close() + } + + return nil + }, + }, + }, + TearDown: func() error { + log.Printf("Tearing down scenario") + // Use a new context here to be able to clean up the nodes if the + // main context is canceled. + ctx := context.Background() + + if dir := joinNNodesArgs.OutputDir; dir != "" { + if _, err := os.Stat(dir); os.IsNotExist(err) { + err = os.MkdirAll(dir, 0755) + if err != nil { + return fmt.Errorf("creating output directory: %w", err) + } + } + p := path.Join(joinNNodesArgs.OutputDir, fmt.Sprintf("join-n-nodes-%d-%s.json", joinNNodesArgs.Count, time.Now().Format(TimeFileNameFormat))) + nm.SaveStatusToFile(p) + } + + return nm.CloseAndDeleteAll(ctx, chaos) + }, + } + + return s.Run() +} diff --git a/util/chaos/scenario.go b/util/chaos/scenario.go new file mode 100644 index 000000000..57cf9dddd --- /dev/null +++ b/util/chaos/scenario.go @@ -0,0 +1,91 @@ +package chaos + +import ( + "fmt" + "log" + + "tailscale.com/util/multierr" +) + +type Scenario struct { + // BeforeSteps and AfterSteps are run before any step and after all steps. + BeforeSteps func() error + AfterSteps func() error + + // BeforeStep and AfterStep are run before and after each step, respectively. + // Can be used to start a profiling of control for a given step. + BeforeStep func() error + AfterStep func() error + + // Steps to run in order. + Steps []Step + + // TearDown is run after all steps are run, regardless of success or failure. + TearDown func() error + + // ContinueOnError, if true, will continue to run steps even if one fails. + ContinueOnError bool +} + +type Step struct { + Run func() error + + // BeforeStep and AfterStep are run before and after the step. 
+ // Can be used to start a profiling of control for a given step. + BeforeStep func() error + AfterStep func() error +} + +func (s *Scenario) Run() (err error) { + defer func() { + if s.TearDown != nil { + terr := s.TearDown() + if terr != nil { + err = fmt.Errorf("TearDown: %w", terr) + } + } + }() + if s.BeforeSteps != nil { + if err := s.BeforeSteps(); err != nil { + return fmt.Errorf("BeforeSteps: %w", err) + } + } + var errs []error + for _, step := range s.Steps { + if s.BeforeStep != nil { + if err := s.BeforeStep(); err != nil { + return fmt.Errorf("Before each step: %w", err) + } + } + if step.BeforeStep != nil { + if err := step.BeforeStep(); err != nil { + return fmt.Errorf("BeforeStep %w", err) + } + } + if err := step.Run(); err != nil { + log.Printf("Step failed: %s", err) + errs = append(errs, err) + if !s.ContinueOnError { + break + } + } + if step.AfterStep != nil { + if err := step.AfterStep(); err != nil { + return fmt.Errorf("AfterStep: %w", err) + } + } + + if s.AfterStep != nil { + if err := s.AfterStep(); err != nil { + return fmt.Errorf("After each step: %w", err) + } + } + } + if s.AfterSteps != nil { + if err := s.AfterSteps(); err != nil { + return fmt.Errorf("AfterSteps: %w", err) + } + } + + return multierr.New(errs...) +} diff --git a/util/chaos/scenario_test.go b/util/chaos/scenario_test.go new file mode 100644 index 000000000..585bd74a3 --- /dev/null +++ b/util/chaos/scenario_test.go @@ -0,0 +1,129 @@ +package chaos + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestScenario(t *testing.T) { + tests := []struct { + name string + scenarioFunc func() ([]int, error) + want []int + }{ + { + name: "scenario-only-step", + scenarioFunc: func() ([]int, error) { + var got []int + s := Scenario{ + Steps: []Step{ + { + Run: func() error { + got = append(got, 1) + return nil + }, + }, + }, + } + return got, s.Run() + }, + want: []int{1}, + }, + { + name: "scenario-only-steps", + scenarioFunc: func() ([]int, error) { + var got []int + s := Scenario{ + Steps: []Step{ + { + Run: func() error { + got = append(got, 1) + return nil + }, + }, + { + Run: func() error { + got = append(got, 2) + return nil + }, + }, + }, + } + return got, s.Run() + }, + want: []int{1, 2}, + }, + { + name: "scenario-everything", + scenarioFunc: func() ([]int, error) { + var got []int + s := Scenario{ + BeforeSteps: func() error { + got = append(got, 1) + return nil + }, + BeforeStep: func() error { + got = append(got, 2) + return nil + }, + Steps: []Step{ + { + Run: func() error { + got = append(got, 3) + return nil + }, + }, + { + BeforeStep: func() error { + got = append(got, 4) + return nil + }, + Run: func() error { + got = append(got, 5) + return nil + }, + AfterStep: func() error { + got = append(got, 6) + return nil + }, + }, + }, + AfterStep: func() error { + got = append(got, 7) + return nil + }, + AfterSteps: func() error { + got = append(got, 8) + return nil + }, + TearDown: func() error { + got = append(got, 9) + return nil + }, + } + return got, s.Run() + }, + want: []int{1, 2, 3, + // "out of order" is expected as this + // is the AfterStep and BeforeStep called + // for each function + 7, 2, + 4, 5, 6, 7, 8, 9}, + }, + // TODO(kradalby): Add test cases for errors and continueOnError + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.scenarioFunc() + if err != nil { + t.Errorf("scenarioFunc() error = %v", err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("unexpected scenario order (-want +got):\n%s", 
diff) + } + }) + } +} diff --git a/util/chaos/util.go b/util/chaos/util.go new file mode 100644 index 000000000..8d36e46d5 --- /dev/null +++ b/util/chaos/util.go @@ -0,0 +1,393 @@ +package chaos + +import ( + "context" + "errors" + "fmt" + "log" + "math/rand/v2" + "strings" + "sync" + "time" + "unicode" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "tailscale.com/syncs" + "tailscale.com/types/netmap" +) + +var errCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "chaos_errors_total", + Help: "Number of errors", +}, []string{"type"}) + +const TimeFileNameFormat = "20060102T150405Z" + +func setVerboseOptionsFromFlag(opts *NodeOpts, verbose int) { + switch verbose { + case 0: + opts.userLogf = func(format string, args ...any) {} + case 1: + opts.userLogf = log.Printf + case 2: + opts.userLogf = log.Printf + opts.logf = log.Printf + } +} + +func nodeFuncFromFlag(flag string) NewNodeFunc { + switch flag { + case "direct": + return NewNodeDirectAsync + case "tsnet": + return NewNodeTSNetAsync + default: + log.Fatalf("Unknown node type: %s", joinNNodesArgs.NodeType) + return nil + } +} + +func tagsToDst(tags []string, port string) []string { + dsts := make([]string, len(tags)) + for i, tag := range tags { + dsts[i] = fmt.Sprintf("%s:%s", tag, port) + } + return dsts +} + +func tagsToTagOwners(owners []string, tags []string) map[string][]string { + m := make(map[string][]string) + for _, tag := range tags { + m[tag] = owners + } + + return m +} + +// tagsMetricLabel returns a string representation of the tags for use in +// metric labels. If noSuffix is true, it will remove any numeric suffixes +// from the tags. +func tagsMetricLabel(tags []string, fullLabels bool) string { + if len(tags) == 0 { + return "" + } + trim := func(tag string) string { + if !fullLabels { + tag = removeNumericSuffix(tag) + } + return strings.TrimPrefix(tag, "tag:") + } + var b strings.Builder + b.WriteString(trim(tags[0])) + for _, tag := range tags[1:] { + b.WriteString(",") + b.WriteString(trim(tag)) + } + return b.String() +} + +// removeNumericSuffix removes the numeric suffix from the input string. +func removeNumericSuffix(input string) string { + // Find the position where the numeric suffix starts + for i := len(input) - 1; i >= 0; i-- { + if !unicode.IsDigit(rune(input[i])) { + return input[:i+1] + } + } + // If the whole string is numeric, return an empty string + return input +} + +// netmapLatencyTracker measures latency between the time a new node +// joins the network and the time it first appears in any of the other nodes' +// netmaps. It relies on chaos nodes having a hostname of the form "chaos-". +type netmapLatencyTracker struct { + countNeverSeen prometheus.Counter + countNotFullySeen prometheus.Counter + + firstSeenLatencies *prometheus.HistogramVec + allSeenLatencies *prometheus.HistogramVec + numUnseenFirst *prometheus.GaugeVec + numUnseenAll *prometheus.GaugeVec + + // visibilityUpdates is our queue of updates to unseedFirst/unseenAll which + // can block. This is a syncs.Map just to ensure it has an independent + // synchronisation mechanism which is not mu. + visibilityUpdates syncs.Map[uuid.UUID, visibilityUpdate] + // visibilityUpdateReadyCh is updated when more work is available on + // visibilityUpdates. + visibilityUpdateReadyCh chan struct{} + + mu sync.Mutex + + // visibility is each node's list of peers' IDs. 
+ // A node does not appear in visibility until we receive its first netmap. + visibility map[uuid.UUID]map[uuid.UUID]time.Time // node => peers => first seen + + // unseenFirst is a map of node IDs that have joined the network but + // have not yet appeared in a netmap of any other node. + unseenFirst map[uuid.UUID]nodeStart + // unseenAll is a map of node IDs that have joined the network but + // have not yet appeared in the netmaps of all other nodes. + unseenAll map[uuid.UUID]nodeStart +} + +type visibilityUpdate struct { + t time.Time + self uuid.UUID + peers map[uuid.UUID]bool + deleted bool +} + +type nodeStart struct { + start time.Time + tagsLabel string +} + +var latencyTracker *netmapLatencyTracker +var latencyTrackerOnce sync.Once + +// newLatencyTracker returns a new netmapLatencyTracker singleton. +func newLatencyTracker() *netmapLatencyTracker { + latencyTrackerOnce.Do(func() { + latencyTracker = &netmapLatencyTracker{ + countNeverSeen: promauto.NewCounter(prometheus.CounterOpts{ + Name: "chaos_netmap_tracker_never_seen_nodes", + Help: "Number of nodes that disappeared before they were seen in any other netmaps", + }), + countNotFullySeen: promauto.NewCounter(prometheus.CounterOpts{ + Name: "chaos_netmap_tracker_not_fully_seen_nodes", + Help: "Number of nodes that disappeared before they were seen in all other nodes' netmaps", + }), + + firstSeenLatencies: promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "chaos_netmap_distribution_latency_seconds", + Help: "Time it took for a new node to be visible in a single netmap", + Buckets: prometheus.ExponentialBucketsRange(0.01, 30, 20), + }, []string{"tags"}), + allSeenLatencies: promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "chaos_netmap_distribution_all_latency_seconds", + Help: "Time it took for a new node to be visible in all netmaps", + Buckets: prometheus.ExponentialBucketsRange(0.01, 30, 20), + }, []string{"tags"}), + numUnseenFirst: promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "chaos_netmap_tracker_pending_nodes", + Help: "Number of nodes yet to appear in any other node's netmap", + }, []string{"tags"}), + numUnseenAll: promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "chaos_netmap_tracker_pending_all_nodes", + Help: "Number of nodes yet to be appear in all other nodes' netmaps", + }, []string{"tags"}), + + visibilityUpdateReadyCh: make(chan struct{}), + visibility: make(map[uuid.UUID]map[uuid.UUID]time.Time), + + unseenFirst: make(map[uuid.UUID]nodeStart), + unseenAll: make(map[uuid.UUID]nodeStart), + } + go latencyTracker.backgroundNetmapUpdater() + }) + return latencyTracker +} + +func (t *netmapLatencyTracker) processUpdate(u visibilityUpdate) { + t.mu.Lock() + defer t.mu.Unlock() + + if u.deleted { + id := u.self + a, neverSeen := t.unseenFirst[id] + b, notAllSeen := t.unseenAll[id] + delete(t.unseenAll, id) + delete(t.unseenFirst, id) + + if neverSeen { + t.numUnseenFirst.WithLabelValues(a.tagsLabel).Dec() + t.countNeverSeen.Inc() + } + if notAllSeen { + t.numUnseenAll.WithLabelValues(b.tagsLabel).Dec() + t.countNotFullySeen.Inc() + } + + seen, ok := t.visibility[id] + if ok { + delete(t.visibility, id) + for p := range seen { + t.checkAllVisible(p) + } + } + return + } + + // Patch the visibility match. 
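+	// Stale sightings (older than the time already recorded for that peer)
+	// are dropped from u.peers; everything else updates the recorded time
+	// for that peer.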
+	if t.visibility[u.self] == nil {
+		t.visibility[u.self] = make(map[uuid.UUID]time.Time)
+	}
+
+	for p := range u.peers {
+		vt, ok := t.visibility[u.self][p]
+		if ok && u.t.Before(vt) {
+			delete(u.peers, p)
+			continue
+		}
+		t.visibility[u.self][p] = u.t
+	}
+
+	// u.peers now contains only newly-visible peers.
+	for p := range u.peers {
+		if node, ok := t.unseenFirst[p]; ok {
+			t.numUnseenFirst.WithLabelValues(node.tagsLabel).Dec()
+			t.firstSeenLatencies.WithLabelValues(node.tagsLabel).Observe(u.t.Sub(node.start).Seconds())
+			delete(t.unseenFirst, p)
+		}
+		t.checkAllVisible(p)
+	}
+	t.checkAllVisible(u.self)
+}
+
+// sendUpdate queues a visibility update and wakes the background updater.
+func (t *netmapLatencyTracker) sendUpdate(u visibilityUpdate) {
+	t.visibilityUpdates.Store(u.self, u)
+	select {
+	case t.visibilityUpdateReadyCh <- struct{}{}:
+	default:
+	}
+}
+
+// backgroundNetmapUpdater drains queued visibility updates and applies them.
+func (t *netmapLatencyTracker) backgroundNetmapUpdater() {
+	for {
+		<-t.visibilityUpdateReadyCh
+		for {
+			_, upd, found := takeItem(&t.visibilityUpdates)
+			if !found {
+				break
+			}
+			t.processUpdate(upd)
+		}
+	}
+}
+
+// takeItem deletes and returns the first key and its value visited when
+// ranging over the underlying map.
+func takeItem[K comparable, V any](m *syncs.Map[K, V]) (key K, val V, ok bool) {
+	m.WithLock(func(m map[K]V) {
+		for k, v := range m {
+			key, val, ok = k, v, true
+			delete(m, k)
+			return
+		}
+	})
+	return
+}
+
+func (t *netmapLatencyTracker) checkAllVisible(p uuid.UUID) {
+	node, ok := t.unseenAll[p]
+	if !ok {
+		return
+	}
+	if t.visibility[p] == nil {
+		return
+	}
+	var latest time.Time
+	for q := range t.visibility[p] {
+		seenAt, ok := t.visibility[q][p]
+		if !ok {
+			// If p can see q, but q has not yet seen p (or has no netmap),
+			// assume that p is an older node than q. We mostly only care
+			// about p being seen by nodes older than it.
+			continue
+		}
+		if seenAt.After(latest) {
+			latest = seenAt
+		}
+	}
+
+	t.numUnseenAll.WithLabelValues(node.tagsLabel).Dec()
+	t.allSeenLatencies.WithLabelValues(node.tagsLabel).Observe(latest.Sub(node.start).Seconds())
+	delete(t.unseenAll, p)
+}
+
+// Start records a node's join time. It should be called after a new node
+// has joined the network, with that node's UUID.
+func (t *netmapLatencyTracker) Start(id uuid.UUID, tagsLabel string) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	start := nodeStart{start: time.Now(), tagsLabel: tagsLabel}
+	t.unseenFirst[id] = start
+	t.numUnseenFirst.WithLabelValues(tagsLabel).Inc()
+	t.unseenAll[id] = start
+	t.numUnseenAll.WithLabelValues(tagsLabel).Inc()
+}
+
+// Done records that the node with the given UUID has left the network.
+func (t *netmapLatencyTracker) Done(id uuid.UUID) {
+	t.sendUpdate(visibilityUpdate{
+		t:       time.Now(),
+		self:    id,
+		deleted: true,
+	})
+}
+
+// ProcessNetmap should be called every time a new netmap is received.
+func (t *netmapLatencyTracker) ProcessNetmap(self uuid.UUID, nm *netmap.NetworkMap) {
+	seen := make(map[uuid.UUID]bool)
+	for _, p := range nm.Peers {
+		id, err := hostnameToUUID(p.Hostinfo().Hostname())
+		if err != nil {
+			log.Printf("Failed to parse UUID from hostname %q: %v", p.Hostinfo().Hostname(), err)
+			errCount.WithLabelValues("tracker-parse-uuid").Inc()
+			continue
+		}
+		seen[id] = true
+	}
+	t.sendUpdate(visibilityUpdate{
+		t:     time.Now(),
+		self:  self,
+		peers: seen,
+	})
+}
+
+// uuidToHostname converts a UUID to a hostname.
+func uuidToHostname(id uuid.UUID) string {
+	return "chaos-" + id.String()
+}
+
+// hostnameToUUID converts a hostname to a UUID. It expects the hostname
+// to have the format "chaos-<uuid>".
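+// For example, the hostname "chaos-123e4567-e89b-12d3-a456-426614174000"
+// yields the UUID 123e4567-e89b-12d3-a456-426614174000.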
+func hostnameToUUID(hostname string) (uuid.UUID, error) {
+	hid, ok := strings.CutPrefix(hostname, "chaos-")
+	if !ok {
+		return uuid.Nil, errors.New("hostname does not have the expected prefix")
+	}
+	return uuid.Parse(hid)
+}
+
+// jticker is a jittered ticker that sends a message to the channel
+// at a given interval with +/- 10% jitter.
+type jticker <-chan struct{}
+
+// newJitteredTicker returns a ticker that fires roughly every `every`, each
+// interval jittered by +/- 10%. If after is larger than every, the first tick
+// is delayed until roughly after. The channel is closed when ctx is cancelled.
+func newJitteredTicker(ctx context.Context, after, every time.Duration) jticker {
+	ch := make(chan struct{})
+	go func() {
+		if after > every {
+			select {
+			case <-time.After(after - every):
+			case <-ctx.Done():
+				// Close ch here too, so callers ranging over the ticker
+				// observe cancellation.
+				close(ch)
+				return
+			}
+		}
+		for {
+			delay := time.Duration(float64(time.Second) * every.Seconds() * (0.9 + 0.2*rand.Float64()))
+			select {
+			case <-time.After(delay):
+				ch <- struct{}{}
+			case <-ctx.Done():
+				close(ch)
+				return
+			}
+		}
+	}()
+	return ch
+}
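+
+// Example use (illustrative; the durations are arbitrary): tick roughly every
+// 10s, starting about 30s after creation, until ctx is cancelled.
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+//	defer cancel()
+//	for range newJitteredTicker(ctx, 30*time.Second, 10*time.Second) {
+//		// one round of periodic work
+//	}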
diff --git a/util/chaos/util_test.go b/util/chaos/util_test.go
new file mode 100644
index 000000000..5e7448f09
--- /dev/null
+++ b/util/chaos/util_test.go
@@ -0,0 +1,195 @@
+package chaos
+
+import (
+	"cmp"
+	"slices"
+	"testing"
+	"time"
+
+	qt "github.com/frankban/quicktest"
+	"github.com/google/uuid"
+	"github.com/prometheus/client_golang/prometheus/testutil"
+
+	xmaps "golang.org/x/exp/maps"
+)
+
+func TestTagsMetricLabel(t *testing.T) {
+	tests := []struct {
+		name       string
+		tags       []string
+		fullLabels bool
+		want       string
+	}{
+		{
+			name:       "empty1",
+			tags:       []string{},
+			fullLabels: true,
+			want:       "",
+		},
+		{
+			name:       "empty2",
+			tags:       []string{},
+			fullLabels: false,
+			want:       "",
+		},
+		{
+			name:       "one_trimmed",
+			tags:       []string{"tag:foo15"},
+			fullLabels: false,
+			want:       "foo",
+		},
+		{
+			name:       "one_full",
+			tags:       []string{"tag:foo15"},
+			fullLabels: true,
+			want:       "foo15",
+		},
+		{
+			name:       "two_trimmed",
+			tags:       []string{"tag:foo15", "tag:bar"},
+			fullLabels: false,
+			want:       "foo,bar",
+		},
+		{
+			name:       "two_full",
+			tags:       []string{"tag:foo", "tag:bar0"},
+			fullLabels: true,
+			want:       "foo,bar0",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tagsMetricLabel(tt.tags, tt.fullLabels); got != tt.want {
+				t.Errorf("tagsMetricLabel(%v, %v) = %v, want %v", tt.tags, tt.fullLabels, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestUUIDConversion(t *testing.T) {
+	tests := []uuid.UUID{
+		uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
+		uuid.MustParse("00000000-0000-0000-0000-000000000000"),
+	}
+	for _, tt := range tests {
+		hostname := uuidToHostname(tt)
+		got, err := hostnameToUUID(hostname)
+		if err != nil {
+			t.Errorf("hostnameToUUID(%q) error = %v", hostname, err)
+			continue
+		}
+		if got != tt {
+			t.Errorf("RoundTrip failed: uuidToHostname(%v) -> hostnameToUUID(%q) = %v, want %v", tt, hostname, got, tt)
+		}
+	}
+}
+
+func TestLatencyTracker(t *testing.T) {
+	lt := newLatencyTracker()
+	c := qt.New(t)
+
+	u1 := uuid.UUID{0: 1 * 16}
+	u2 := uuid.UUID{0: 2 * 16}
+	u3 := uuid.UUID{0: 3 * 16}
+	u4 := uuid.UUID{0: 4 * 16}
+	u5 := uuid.UUID{0: 5 * 16}
+
+	assertNeverSeenCount := func(first, all int) {
+		t.Helper()
+		c.Assert(testutil.ToFloat64(lt.countNeverSeen), qt.Equals, float64(first))
+		c.Assert(testutil.ToFloat64(lt.countNotFullySeen), qt.Equals, float64(all))
+	}
+
+	assertUnseenCount := func(first, all int) {
+		t.Helper()
+		c.Assert(len(lt.unseenFirst), qt.Equals, first, qt.Commentf("first: %+v", lt.unseenFirst))
+		c.Assert(len(lt.unseenAll), qt.Equals, all, qt.Commentf("all: %+v", lt.unseenAll))
+		c.Assert(testutil.ToFloat64(lt.numUnseenFirst.WithLabelValues("foo")), qt.Equals, float64(first))
+		c.Assert(testutil.ToFloat64(lt.numUnseenAll.WithLabelValues("foo")), qt.Equals, float64(all))
+	}
+
+	assertUnseenFirst := func(uuids ...uuid.UUID) {
+		t.Helper()
+
+		sortUUIDS(uuids)
+		keys := xmaps.Keys(lt.unseenFirst)
+		sortUUIDS(keys)
+		c.Assert(uuids, qt.DeepEquals, keys)
+	}
+	assertUnseenAll := func(uuids ...uuid.UUID) {
+		t.Helper()
+
+		sortUUIDS(uuids)
+		got := xmaps.Keys(lt.unseenAll)
+		sortUUIDS(got)
+		c.Assert(uuids, qt.DeepEquals, got)
+	}
+
+	lt.Start(u1, "foo")
+	assertUnseenCount(1, 1)
+	assertNeverSeenCount(0, 0)
+
+	lt.Start(u2, "foo")
+	lt.Start(u3, "foo")
+	lt.Start(u4, "foo")
+	lt.Start(u5, "foo")
+	assertUnseenCount(5, 5)
+	assertNeverSeenCount(0, 0)
+
+	// u1 saw u2 and u3.
+	lt.processUpdate(visibilityUpdate{
+		t:     time.Now(),
+		self:  u1,
+		peers: map[uuid.UUID]bool{u2: true, u3: true},
+	})
+	assertUnseenCount(3, 4)
+	assertNeverSeenCount(0, 0)
+	assertUnseenFirst(u1, u4, u5)
+	assertUnseenAll(u2, u3, u4, u5)
+
+	// u2 saw u1 and u3.
+	lt.processUpdate(visibilityUpdate{
+		t:     time.Now(),
+		self:  u2,
+		peers: map[uuid.UUID]bool{u1: true, u3: true},
+	})
+	// u3 saw u1 and u2.
+	lt.processUpdate(visibilityUpdate{
+		t:     time.Now(),
+		self:  u3,
+		peers: map[uuid.UUID]bool{u1: true, u2: true},
+	})
+	assertUnseenCount(2, 2)
+	assertNeverSeenCount(0, 0)
+	assertUnseenFirst(u4, u5)
+	assertUnseenAll(u4, u5)
+
+	// u3 saw u4.
+	lt.processUpdate(visibilityUpdate{
+		t:     time.Now(),
+		self:  u3,
+		peers: map[uuid.UUID]bool{u4: true},
+	})
+	assertUnseenCount(1, 2)
+	assertNeverSeenCount(0, 0)
+	assertUnseenFirst(u5)
+	assertUnseenAll(u4, u5)
+
+	// u4 and u5 are gone.
+	lt.processUpdate(visibilityUpdate{t: time.Now(), self: u4, deleted: true})
+	lt.processUpdate(visibilityUpdate{t: time.Now(), self: u5, deleted: true})
+	assertUnseenCount(0, 0)
+	assertNeverSeenCount(1, 2)
+}
+
+func sortUUIDS(uuids []uuid.UUID) {
+	slices.SortFunc(uuids, func(a, b uuid.UUID) int {
+		for i := range len(uuid.UUID{}) {
+			if a[i] == b[i] {
+				continue
+			}
+			return cmp.Compare(a[i], b[i])
+		}
+		return 0
+	})
+}
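+
+// TestRemoveNumericSuffix is an illustrative sketch of removeNumericSuffix's
+// behaviour; the cases are derived from its implementation rather than any
+// external spec.
+func TestRemoveNumericSuffix(t *testing.T) {
+	tests := []struct {
+		in, want string
+	}{
+		{"tag:foo15", "tag:foo"},
+		{"foo", "foo"},
+		{"", ""},
+		{"123", "123"}, // an all-numeric input is returned unchanged
+	}
+	for _, tt := range tests {
+		if got := removeNumericSuffix(tt.in); got != tt.want {
+			t.Errorf("removeNumericSuffix(%q) = %q, want %q", tt.in, got, tt.want)
+		}
+	}
+}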