Fran Bull
2025-01-13 13:29:41 -08:00
parent 836c01258d
commit 14dd2c9297
4 changed files with 1089 additions and 0 deletions

127 tsconsensus/http.go Normal file

@@ -0,0 +1,127 @@
package tsconsensus
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
)
type joinRequest struct {
RemoteHost string `json:"remoteAddr"`
RemoteID string `json:"remoteID"`
}
type commandClient struct {
port uint16
httpClient *http.Client
}
func (rac *commandClient) Url(host string, path string) string {
return fmt.Sprintf("http://%s:%d%s", host, rac.port, path)
}
func (rac *commandClient) Join(host string, jr joinRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
rBs, err := json.Marshal(jr)
if err != nil {
return err
}
url := rac.Url(host, "/join")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rBs))
if err != nil {
return err
}
resp, err := rac.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
respBs, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs))
}
return nil
}
func (rac *commandClient) ExecuteCommand(host string, bs []byte) (CommandResult, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
url := rac.Url(host, "/executeCommand")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bs))
if err != nil {
return CommandResult{}, err
}
resp, err := rac.httpClient.Do(req)
if err != nil {
return CommandResult{}, err
}
defer resp.Body.Close()
respBs, err := io.ReadAll(resp.Body)
if err != nil {
return CommandResult{}, err
}
if resp.StatusCode != http.StatusOK {
return CommandResult{}, fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs))
}
var cr CommandResult
if err = json.Unmarshal(respBs, &cr); err != nil {
return CommandResult{}, err
}
return cr, nil
}
func (c *Consensus) makeCommandMux() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
decoder := json.NewDecoder(r.Body)
var jr joinRequest
err := decoder.Decode(&jr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if jr.RemoteHost == "" {
http.Error(w, "Required: remoteAddr", http.StatusBadRequest)
return
}
if jr.RemoteID == "" {
http.Error(w, "Required: remoteID", http.StatusBadRequest)
return
}
err = c.handleJoin(jr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})
mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
decoder := json.NewDecoder(r.Body)
var cmd Command
err := decoder.Decode(&cmd)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
result, err := c.executeCommandLocally(cmd)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(result); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})
return mux
}

140 tsconsensus/monitor.go Normal file

@@ -0,0 +1,140 @@
package tsconsensus
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"slices"
"strings"
"tailscale.com/ipn"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tsnet"
)
type status struct {
Status *ipnstate.Status
RaftState string
}
type monitor struct {
ts *tsnet.Server
con *Consensus
}
func (m *monitor) getStatus(ctx context.Context) (status, error) {
lc, err := m.ts.LocalClient()
if err != nil {
return status{}, err
}
tStatus, err := lc.Status(ctx)
if err != nil {
return status{}, err
}
return status{Status: tStatus, RaftState: m.con.Raft.State().String()}, nil
}
func serveMonitor(c *Consensus, ts *tsnet.Server, listenAddr string) (*http.Server, error) {
ln, err := ts.Listen("tcp", listenAddr)
if err != nil {
return nil, err
}
m := &monitor{con: c, ts: ts}
mux := http.NewServeMux()
mux.HandleFunc("/full", m.handleFullStatus)
mux.HandleFunc("/", m.handleSummaryStatus)
mux.HandleFunc("/netmap", m.handleNetmap)
mux.HandleFunc("/dial", m.handleDial)
srv := &http.Server{Handler: mux}
go func() {
defer ln.Close()
err := srv.Serve(ln)
log.Printf("MonitorHTTP stopped serving with error: %v", err)
}()
return srv, nil
}
func (m *monitor) handleFullStatus(w http.ResponseWriter, r *http.Request) {
s, err := m.getStatus(r.Context())
if err != nil {
http.Error(w, err.Error(), 500)
return
}
if err := json.NewEncoder(w).Encode(s); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
func (m *monitor) handleSummaryStatus(w http.ResponseWriter, r *http.Request) {
s, err := m.getStatus(r.Context())
if err != nil {
http.Error(w, err.Error(), 500)
return
}
lines := []string{}
for _, p := range s.Status.Peer {
if p.Online {
lines = append(lines, fmt.Sprintf("%s\t\t%d\t%d\t%t", strings.Split(p.DNSName, ".")[0], p.RxBytes, p.TxBytes, p.Active))
}
}
slices.Sort(lines)
lines = append([]string{fmt.Sprintf("RaftState: %s", s.RaftState)}, lines...)
txt := strings.Join(lines, "\n") + "\n"
w.Write([]byte(txt))
}
func (m *monitor) handleNetmap(w http.ResponseWriter, r *http.Request) {
var mask ipn.NotifyWatchOpt = ipn.NotifyInitialNetMap
mask |= ipn.NotifyNoPrivateKeys
lc, err := m.ts.LocalClient()
if err != nil {
http.Error(w, err.Error(), 500)
return
}
watcher, err := lc.WatchIPNBus(r.Context(), mask)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
defer watcher.Close()
n, err := watcher.Next()
if err != nil {
http.Error(w, err.Error(), 500)
return
}
j, err := json.MarshalIndent(n.NetMap, "", "\t")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
w.Write(j)
}
func (m *monitor) handleDial(w http.ResponseWriter, r *http.Request) {
fmt.Println("FRAN handle ping")
var dialParams struct {
Addr string
}
bs, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
err = json.Unmarshal(bs, &dialParams)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Println("dialing", dialParams.Addr)
c, err := m.ts.Dial(r.Context(), "tcp", dialParams.Addr)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Println("ping success", c)
defer c.Close()
w.Write([]byte("ok\n"))
return
}
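// Illustrative sketch: another node on the tailnet could exercise the /dial
// debug endpoint like this (the monitor address and dial target shown here are
// hypothetical, and ts is that node's *tsnet.Server):
//
//	body := strings.NewReader(`{"Addr": "100.64.0.1:6270"}`)
//	resp, err := ts.HTTPClient().Post("http://100.64.0.2:8081/dial", "application/json", body)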

363 tsconsensus/tsconsensus.go Normal file

@@ -0,0 +1,363 @@
package tsconsensus
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"slices"
"time"
"github.com/hashicorp/raft"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tsnet"
)
/*
Package tsconsensus implements a consensus algorithm for a group of tsnet.Servers.
The Raft consensus algorithm relies on you implementing a state machine that will give the same
result for a given command as long as the same logs have been applied in the same order.
tsconsensus uses the hashicorp/raft library to implement leader elections and log application.
tsconsensus provides:
* cluster peer discovery based on tailscale tags
* executing a command on the leader
* communication between cluster peers over tailscale using tsnet
Users implement a state machine that satisfies the raft.FSM interface, with the business logic they desire.
When changes to state are needed, any node may:
* create a Command instance with serialized Args,
* call ExecuteCommand with the Command instance;
this will propagate the command to the leader,
and then from the leader to every node via raft,
* the state machine's raft.FSM Apply can then dispatch commands via the Command.Name,
returning a CommandResult with an Err or a serialized Result.
*/
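// The following is an illustrative sketch of how a caller might wire a raft.FSM
// to ExecuteCommand; counterFSM and the "incr" command name are hypothetical,
// and Snapshot/Restore are omitted for brevity (c is the *Consensus returned by Start):
//
//	type counterFSM struct{ n int }
//
//	func (f *counterFSM) Apply(l *raft.Log) interface{} {
//		var cmd Command
//		if err := json.Unmarshal(l.Data, &cmd); err != nil {
//			return CommandResult{Err: err}
//		}
//		switch cmd.Name {
//		case "incr":
//			f.n++
//		}
//		return CommandResult{Result: []byte{byte(f.n)}}
//	}
//
//	// on any node, after Start:
//	//	res, err := c.ExecuteCommand(Command{Name: "incr"})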
func addr(host string, port uint16) string {
return fmt.Sprintf("%s:%d", host, port)
}
func raftAddr(host string, cfg Config) string {
return addr(host, cfg.RaftPort)
}
// A SelfRaftNode is the info we need to talk to hashicorp/raft about our node.
// We specify the ID and Host on Consensus Start, and then use them later for raft
// operations such as BootstrapCluster and AddVoter.
type SelfRaftNode struct {
ID string
Host string
}
// A Config holds configurable values such as ports and timeouts.
// Use DefaultConfig to get a useful Config.
type Config struct {
CommandPort uint16
RaftPort uint16
MonitorPort uint16
Raft *raft.Config
MaxConnPool int
ConnTimeout time.Duration
}
// DefaultConfig returns a Config populated with default values ready for use.
func DefaultConfig() Config {
return Config{
CommandPort: 6271,
RaftPort: 6270,
MonitorPort: 8081,
Raft: raft.DefaultConfig(),
MaxConnPool: 5,
ConnTimeout: 5 * time.Second,
}
}
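// Illustrative sketch: a caller that needs non-default ports can adjust the
// returned Config before passing it to Start (the port values and the
// "tag:consensus" tag are arbitrary examples; ts and myFSM are assumed in scope):
//
//	cfg := DefaultConfig()
//	cfg.RaftPort = 11882
//	cfg.CommandPort = 12347
//	cfg.MonitorPort = 8798
//	c, err := Start(ctx, ts, myFSM, "tag:consensus", cfg)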
// StreamLayer implements an interface asked for by raft.NetworkTransport.
// It does the raft interprocess communication via tailscale.
type StreamLayer struct {
net.Listener
s *tsnet.Server
}
// Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial.
func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return sl.s.Dial(ctx, "tcp", string(address))
}
// Start returns a pointer to a running Consensus instance.
func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string, cfg Config) (*Consensus, error) {
v4, _ := ts.TailscaleIPs()
cc := commandClient{
port: cfg.CommandPort,
httpClient: ts.HTTPClient(),
}
self := SelfRaftNode{
ID: v4.String(),
Host: v4.String(),
}
c := Consensus{
CommandClient: &cc,
Self: self,
Config: cfg,
}
lc, err := ts.LocalClient()
if err != nil {
return nil, err
}
tStatus, err := lc.Status(ctx)
if err != nil {
return nil, err
}
var targets []*ipnstate.PeerStatus
if targetTag != "" && tStatus.Self.Tags != nil && slices.Contains(tStatus.Self.Tags.AsSlice(), targetTag) {
for _, v := range tStatus.Peer {
if v.Tags != nil && slices.Contains(v.Tags.AsSlice(), targetTag) {
targets = append(targets, v)
}
}
} else {
return nil, errors.New("targetTag empty, or this node is not tagged with it")
}
r, err := startRaft(ts, &fsm, c.Self, cfg)
if err != nil {
return nil, err
}
c.Raft = r
srv, err := c.serveCmdHttp(ts)
if err != nil {
return nil, err
}
c.cmdHttpServer = srv
if err := c.bootstrap(targets); err != nil {
return nil, err
}
srv, err = serveMonitor(&c, ts, addr(c.Self.Host, cfg.MonitorPort))
if err != nil {
return nil, err
}
c.monitorHttpServer = srv
return &c, nil
}
func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, cfg Config) (*raft.Raft, error) {
config := cfg.Raft
config.LocalID = raft.ServerID(self.ID)
// no persistence (for now?)
logStore := raft.NewInmemStore()
stableStore := raft.NewInmemStore()
snapshots := raft.NewInmemSnapshotStore()
// opens the listener on the raft port, raft will close it when it thinks it's appropriate
ln, err := ts.Listen("tcp", raftAddr(self.Host, cfg))
if err != nil {
return nil, err
}
transport := raft.NewNetworkTransport(StreamLayer{
s: ts,
Listener: ln,
},
cfg.MaxConnPool,
cfg.ConnTimeout,
nil) // TODO pass in proper logging
// after NewRaft it's possible some other raft node that has us in their configuration will get
// in contact, so by the time we do anything else we may already be a functioning member
// of a consensus
return raft.NewRaft(config, *fsm, logStore, stableStore, snapshots, transport)
}
// A Consensus is the consensus algorithm for a tsnet.Server.
// It wraps a raft.Raft instance and performs peer discovery
// and command execution on the leader.
type Consensus struct {
Raft *raft.Raft
CommandClient *commandClient
Self SelfRaftNode
Config Config
cmdHttpServer *http.Server
monitorHttpServer *http.Server
}
// bootstrap tries to join a raft cluster, or start one.
//
// We need to do the very first raft cluster configuration, but after that raft manages it.
// bootstrap is called at startup, when we are not yet aware of what the cluster config might be;
// our node may already be in it. Try to join the raft cluster of all the other nodes we know about,
// and if unsuccessful, assume we are the first and start our own.
//
// It's possible for bootstrap to return an error, or to start an errant breakaway cluster.
//
// We already have a list of expected cluster members from control (the members of the tailnet with the tag),
// so we could do the initial configuration with all servers specified.
// We choose to start with just this machine in the raft configuration instead, as:
// - We want to handle machines joining after start anyway.
// - Not all tagged nodes tailscale believes are active are necessarily actually responsive right now,
// so let each node opt in when able.
func (c *Consensus) bootstrap(targets []*ipnstate.PeerStatus) error {
log.Printf("Trying to find cluster: num targets to try: %d", len(targets))
for _, p := range targets {
if !p.Online {
log.Printf("Trying to find cluster: tailscale reports not online: %s", p.TailscaleIPs[0])
} else {
log.Printf("Trying to find cluster: trying %s", p.TailscaleIPs[0])
err := c.CommandClient.Join(p.TailscaleIPs[0].String(), joinRequest{
RemoteHost: c.Self.Host,
RemoteID: c.Self.ID,
})
if err != nil {
log.Printf("Trying to find cluster: could not join %s: %v", p.TailscaleIPs[0], err)
} else {
log.Printf("Trying to find cluster: joined %s", p.TailscaleIPs[0])
return nil
}
}
}
log.Printf("Trying to find cluster: unsuccessful, starting as leader: %s", c.Self.Host)
f := c.Raft.BootstrapCluster(
raft.Configuration{
Servers: []raft.Server{
{
ID: raft.ServerID(c.Self.ID),
Address: raft.ServerAddress(c.raftAddr(c.Self.Host)),
},
},
})
return f.Error()
}
// ExecuteCommand propagates a Command to be executed on the leader, which
// uses raft to Apply it to the followers.
func (c *Consensus) ExecuteCommand(cmd Command) (CommandResult, error) {
b, err := json.Marshal(cmd)
if err != nil {
return CommandResult{}, err
}
result, err := c.executeCommandLocally(cmd)
var leErr lookElsewhereError
for errors.As(err, &leErr) {
result, err = c.CommandClient.ExecuteCommand(leErr.where, b)
}
return result, err
}
// Stop attempts to gracefully shut down the raft instance and the command and monitor HTTP servers.
func (c *Consensus) Stop(ctx context.Context) error {
fut := c.Raft.Shutdown()
err := fut.Error()
if err != nil {
log.Printf("Stop: Error in Raft Shutdown: %v", err)
}
err = c.cmdHttpServer.Shutdown(ctx)
if err != nil {
log.Printf("Stop: Error in command HTTP Shutdown: %v", err)
}
err = c.monitorHttpServer.Shutdown(ctx)
if err != nil {
log.Printf("Stop: Error in monitor HTTP Shutdown: %v", err)
}
return nil
}
// A Command is a representation of a state machine action.
// The Name can be used to dispatch the command when received.
// The Args are serialized for transport.
type Command struct {
Name string
Args []byte
}
// A CommandResult is a representation of the result of a state
// machine action.
// Err is any error that occurred on the node that tried to execute the command,
// including any error from the underlying operation and deserialization problems etc.
// Result is serialized for transport.
type CommandResult struct {
Err error
Result []byte
}
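// Illustrative sketch: Args and Result are opaque to tsconsensus, so a caller
// typically round-trips its own types with encoding/json (argsT, resultT, the
// "doThing" name, and c, a *Consensus, are hypothetical here):
//
//	args, err := json.Marshal(argsT{N: 1})
//	if err != nil { ... }
//	res, err := c.ExecuteCommand(Command{Name: "doThing", Args: args})
//	if err != nil || res.Err != nil { ... }
//	var out resultT
//	err = json.Unmarshal(res.Result, &out)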
type lookElsewhereError struct {
where string
}
func (e lookElsewhereError) Error() string {
return fmt.Sprintf("not the leader, try: %s", e.where)
}
var ErrLeaderUnknown = errors.New("leader unknown")
func (c *Consensus) serveCmdHttp(ts *tsnet.Server) (*http.Server, error) {
ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host))
if err != nil {
return nil, err
}
mux := c.makeCommandMux()
srv := &http.Server{Handler: mux}
go func() {
defer ln.Close()
err := srv.Serve(ln)
log.Printf("CmdHttp stopped serving with err: %v", err)
}()
return srv, nil
}
func (c *Consensus) getLeader() (string, error) {
raftLeaderAddr, _ := c.Raft.LeaderWithID()
leaderAddr := (string)(raftLeaderAddr)
if leaderAddr == "" {
// Raft doesn't know who the leader is.
return "", ErrLeaderUnknown
}
// Raft gives us the address with the raft port, we don't always want that.
host, _, err := net.SplitHostPort(leaderAddr)
return host, err
}
func (c *Consensus) executeCommandLocally(cmd Command) (CommandResult, error) {
b, err := json.Marshal(cmd)
if err != nil {
return CommandResult{}, err
}
f := c.Raft.Apply(b, 10*time.Second)
err = f.Error()
result := f.Response()
if errors.Is(err, raft.ErrNotLeader) {
leader, err := c.getLeader()
if err != nil {
// we know we're not leader but we were unable to give the address of the leader
return CommandResult{}, err
}
return CommandResult{}, lookElsewhereError{where: leader}
}
if result == nil {
result = CommandResult{}
}
return result.(CommandResult), err
}
func (c *Consensus) handleJoin(jr joinRequest) error {
remoteAddr := c.raftAddr(jr.RemoteHost)
f := c.Raft.AddVoter(raft.ServerID(jr.RemoteID), raft.ServerAddress(remoteAddr), 0, 0)
if f.Error() != nil {
return f.Error()
}
return nil
}
func (c *Consensus) raftAddr(host string) string {
return raftAddr(host, c.Config)
}
func (c *Consensus) commandAddr(host string) string {
return addr(host, c.Config.CommandPort)
}

459 tsconsensus/tsconsensus_test.go Normal file

@@ -0,0 +1,459 @@
package tsconsensus
import (
"context"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"net/netip"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/hashicorp/raft"
"tailscale.com/client/tailscale"
"tailscale.com/ipn/store/mem"
"tailscale.com/net/netns"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/tstest/integration"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/tstest/nettest"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
type fsm struct {
events []map[string]interface{}
count int
}
type fsmSnapshot struct{}
func (f *fsm) Apply(l *raft.Log) interface{} {
f.count++
f.events = append(f.events, map[string]interface{}{
"type": "Apply",
"l": l,
})
return CommandResult{
Result: []byte{byte(f.count)},
}
}
func (f *fsm) Snapshot() (raft.FSMSnapshot, error) {
panic("Snapshot unexpectedly used")
}
func (f *fsm) Restore(rc io.ReadCloser) error {
panic("Restore unexpectedly used")
}
func (f *fsmSnapshot) Persist(sink raft.SnapshotSink) error {
panic("Persist unexpectedly used")
}
func (f *fsmSnapshot) Release() {
panic("Release unexpectedly used")
}
var verboseDERP = false
var verboseNodes = false
// TODO copied from sniproxy_test
func startControl(t *testing.T) (control *testcontrol.Server, controlURL string) {
// Corp#4520: don't use netns for tests.
netns.SetEnabled(false)
t.Cleanup(func() {
netns.SetEnabled(true)
})
derpLogf := logger.Discard
if verboseDERP {
derpLogf = t.Logf
}
derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1")
control = &testcontrol.Server{
DERPMap: derpMap,
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
MagicDNSDomain: "tail-scale.ts.net",
}
control.HTTPTestServer = httptest.NewUnstartedServer(control)
control.HTTPTestServer.Start()
t.Cleanup(control.HTTPTestServer.Close)
controlURL = control.HTTPTestServer.URL
t.Logf("testcontrol listening on %s", controlURL)
return control, controlURL
}
// TODO copied from sniproxy_test
func startNode(t *testing.T, ctx context.Context, controlURL, hostname string) (*tsnet.Server, key.NodePublic, netip.Addr) {
t.Helper()
tmp := filepath.Join(t.TempDir(), hostname)
os.MkdirAll(tmp, 0755)
s := &tsnet.Server{
Dir: tmp,
ControlURL: controlURL,
Hostname: hostname,
Store: new(mem.Store),
Ephemeral: true,
}
if verboseNodes {
s.Logf = log.Printf
}
t.Cleanup(func() { s.Close() })
status, err := s.Up(ctx)
if err != nil {
t.Fatal(err)
}
return s, status.Self.PublicKey, status.TailscaleIPs[0]
}
func pingNode(t *testing.T, control *testcontrol.Server, nodeKey key.NodePublic) {
t.Helper()
gotPing := make(chan bool, 1)
waitPing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPing <- true
}))
defer waitPing.Close()
for try := 0; try < 5; try++ {
pr := &tailcfg.PingRequest{URL: fmt.Sprintf("%s/ping-%d", waitPing.URL, try), Log: true}
if !control.AddPingRequest(nodeKey, pr) {
t.Fatalf("failed to AddPingRequest")
}
pingTimeout := time.NewTimer(2 * time.Second)
defer pingTimeout.Stop()
select {
case <-gotPing:
// ok! the machinery that refreshes the netmap has been nudged
return
case <-pingTimeout.C:
t.Logf("waiting for ping timed out: %d", try)
}
}
}
func tagNodes(t *testing.T, control *testcontrol.Server, nodeKeys []key.NodePublic, tag string) {
t.Helper()
for _, key := range nodeKeys {
n := control.Node(key)
n.Tags = append(n.Tags, tag)
b := true
n.Online = &b
control.UpdateNode(n)
}
// All this ping stuff is only to prod the netmap to get updated with the tag we just added to the node,
// i.e. to actually get the netmap issued to clients that represents the current state of the nodes.
// There _must_ be a better way to do this, but I looked all day and this was the first thing I found that worked.
for _, key := range nodeKeys {
pingNode(t, control, key)
}
}
// TODO: test Start with all the config settings
func TestStart(t *testing.T) {
nettest.SkipIfNoNetwork(t)
control, controlURL := startControl(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
one, k, _ := startNode(t, ctx, controlURL, "one")
clusterTag := "tag:whatever"
// nodes must be tagged with the cluster tag, to find each other
tagNodes(t, control, []key.NodePublic{k}, clusterTag)
sm := &fsm{}
r, err := Start(ctx, one, (*fsm)(sm), clusterTag, DefaultConfig())
if err != nil {
t.Fatal(err)
}
defer r.Stop(ctx)
}
func waitFor(t *testing.T, msg string, condition func() bool, nTries int, waitBetweenTries time.Duration) {
for try := 0; try < nTries; try++ {
done := condition()
if done {
t.Logf("waitFor success: %s: after %d tries", msg, try)
return
}
time.Sleep(waitBetweenTries)
}
t.Fatalf("waitFor timed out: %s, after %d tries", msg, nTries)
}
type participant struct {
c *Consensus
sm *fsm
ts *tsnet.Server
key key.NodePublic
}
// startNodesAndWaitForPeerStatus starts and tags the *tsnet.Server nodes with the control, and waits
// for the nodes to make successful LocalClient Status calls that show the first node as Online.
func startNodesAndWaitForPeerStatus(t *testing.T, ctx context.Context, clusterTag string, nNodes int) ([]*participant, *testcontrol.Server, string) {
ps := make([]*participant, nNodes)
keysToTag := make([]key.NodePublic, nNodes)
localClients := make([]*tailscale.LocalClient, nNodes)
control, controlURL := startControl(t)
for i := 0; i < nNodes; i++ {
ts, key, _ := startNode(t, ctx, controlURL, fmt.Sprintf("node: %d", i))
ps[i] = &participant{ts: ts, key: key}
keysToTag[i] = key
lc, err := ts.LocalClient()
if err != nil {
t.Fatalf("%d: error getting local client: %v", i, err)
}
localClients[i] = lc
}
tagNodes(t, control, keysToTag, clusterTag)
fxCameOnline := func() bool {
// all the _other_ nodes see the first as online
for i := 1; i < nNodes; i++ {
status, err := localClients[i].Status(ctx)
if err != nil {
t.Fatalf("%d: error getting status: %v", i, err)
}
if !status.Peer[ps[0].key].Online {
return false
}
}
return true
}
waitFor(t, "other nodes see node 1 online in ts status", fxCameOnline, 10, 2*time.Second)
return ps, control, controlURL
}
// createConsensusCluster populates participants with their consensus fields and waits for all nodes
// to show all nodes as part of the same consensus cluster. It starts the first participant and waits
// for it to become leader before adding the other nodes.
func createConsensusCluster(t *testing.T, ctx context.Context, clusterTag string, participants []*participant, cfg Config) {
participants[0].sm = &fsm{}
first, err := Start(ctx, participants[0].ts, (*fsm)(participants[0].sm), clusterTag, cfg)
if err != nil {
t.Fatal(err)
}
fxFirstIsLeader := func() bool {
return first.Raft.State() == raft.Leader
}
waitFor(t, "node 0 is leader", fxFirstIsLeader, 10, 2*time.Second)
participants[0].c = first
for i := 1; i < len(participants); i++ {
participants[i].sm = &fsm{}
c, err := Start(ctx, participants[i].ts, (*fsm)(participants[i].sm), clusterTag, cfg)
if err != nil {
t.Fatal(err)
}
participants[i].c = c
}
fxRaftConfigContainsAll := func() bool {
for i := 0; i < len(participants); i++ {
fut := participants[i].c.Raft.GetConfiguration()
err := fut.Error()
if err != nil {
t.Fatalf("%d: Getting Configuration errored: %v", i, err)
}
if len(fut.Configuration().Servers) != len(participants) {
return false
}
}
return true
}
waitFor(t, "all raft machines have all servers in their config", fxRaftConfigContainsAll, 10, time.Second*2)
}
func TestApply(t *testing.T) {
nettest.SkipIfNoNetwork(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
clusterTag := "tag:whatever"
ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 2)
cfg := DefaultConfig()
createConsensusCluster(t, ctx, clusterTag, ps, cfg)
fut := ps[0].c.Raft.Apply([]byte("woo"), 2*time.Second)
err := fut.Error()
if err != nil {
t.Fatalf("Raft Apply Error: %v", err)
}
fxBothMachinesHaveTheApply := func() bool {
return len(ps[0].sm.events) == 1 && len(ps[1].sm.events) == 1
}
waitFor(t, "the apply event made it into both state machines", fxBothMachinesHaveTheApply, 10, time.Second*1)
}
// assertCommandsWorkOnAnyNode calls ExecuteCommand on each participant and checks that all
// participants get all commands.
for i, p := range participants {
res, err := p.c.ExecuteCommand(Command{Args: []byte{byte(i)}})
if err != nil {
t.Fatalf("%d: Error ExecuteCommand: %v", i, err)
}
if res.Err != nil {
t.Fatalf("%d: Result Error ExecuteCommand: %v", i, res.Err)
}
retVal := int(res.Result[0])
// the test implementation of the fsm returns the count of events that have been received
if retVal != i+1 {
t.Fatalf("Result, want %d, got %d", i+1, retVal)
}
expectedEventsLength := i + 1
fxEventsInAll := func() bool {
for _, pOther := range participants {
if len(pOther.sm.events) != expectedEventsLength {
return false
}
}
return true
}
waitFor(t, "event makes it to all", fxEventsInAll, 10, time.Second*1)
}
}
func TestConfig(t *testing.T) {
nettest.SkipIfNoNetwork(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
clusterTag := "tag:whatever"
ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
cfg := DefaultConfig()
// test all is well with non-default ports
cfg.CommandPort = 12347
cfg.RaftPort = 11882
mp := uint16(8798)
cfg.MonitorPort = mp
createConsensusCluster(t, ctx, clusterTag, ps, cfg)
assertCommandsWorkOnAnyNode(t, ps)
url := fmt.Sprintf("http://%s:%d/", ps[0].c.Self.Host, mp)
httpClientOnTailnet := ps[1].ts.HTTPClient()
rsp, err := httpClientOnTailnet.Get(url)
if err != nil {
t.Fatal(err)
}
if rsp.StatusCode != 200 {
t.Fatalf("monitor status want %d, got %d", 200, rsp.StatusCode)
}
body, err := io.ReadAll(rsp.Body)
if err != nil {
t.Fatal(err)
}
// Not a great assertion because it relies on the format of the response.
line1 := strings.Split(string(body), "\n")[0]
if line1[:10] != "RaftState:" {
t.Fatalf("getting monitor status, first line, want something that starts with 'RaftState:', got '%s'", line1)
}
}
func TestFollowerFailover(t *testing.T) {
nettest.SkipIfNoNetwork(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
clusterTag := "tag:whatever"
ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
cfg := DefaultConfig()
createConsensusCluster(t, ctx, clusterTag, ps, cfg)
smThree := ps[2].sm
fut := ps[0].c.Raft.Apply([]byte("a"), 2*time.Second)
futTwo := ps[0].c.Raft.Apply([]byte("b"), 2*time.Second)
err := fut.Error()
if err != nil {
t.Fatalf("Apply Raft error %v", err)
}
err = futTwo.Error()
if err != nil {
t.Fatalf("Apply Raft error %v", err)
}
fxAllMachinesHaveTheApplies := func() bool {
return len(ps[0].sm.events) == 2 && len(ps[1].sm.events) == 2 && len(smThree.events) == 2
}
waitFor(t, "the apply events made it into all state machines", fxAllMachinesHaveTheApplies, 10, time.Second*1)
// a follower loses contact with the cluster
ps[2].c.Stop(ctx)
// applies still make it to one and two
futThree := ps[0].c.Raft.Apply([]byte("c"), 2*time.Second)
futFour := ps[0].c.Raft.Apply([]byte("d"), 2*time.Second)
err = futThree.Error()
if err != nil {
t.Fatalf("Apply Raft error %v", err)
}
err = futFour.Error()
if err != nil {
t.Fatalf("Apply Raft error %v", err)
}
fxAliveMachinesHaveTheApplies := func() bool {
return len(ps[0].sm.events) == 4 && len(ps[1].sm.events) == 4 && len(smThree.events) == 2
}
waitFor(t, "the apply events made it into eligible state machines", fxAliveMachinesHaveTheApplies, 10, time.Second*1)
// follower comes back
smThreeAgain := &fsm{}
rThreeAgain, err := Start(ctx, ps[2].ts, (*fsm)(smThreeAgain), clusterTag, DefaultConfig())
if err != nil {
t.Fatal(err)
}
defer rThreeAgain.Stop(ctx)
fxThreeGetsCaughtUp := func() bool {
return len(smThreeAgain.events) == 4
}
waitFor(t, "the apply events made it into the third node when it appeared with an empty state machine", fxThreeGetsCaughtUp, 20, time.Second*2)
if len(smThree.events) != 2 {
t.Fatalf("Expected smThree to remain on 2 events: got %d", len(smThree.events))
}
}
func TestRejoin(t *testing.T) {
nettest.SkipIfNoNetwork(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
clusterTag := "tag:whatever"
ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
cfg := DefaultConfig()
createConsensusCluster(t, ctx, clusterTag, ps, cfg)
for _, p := range ps {
defer p.c.Stop(ctx)
}
// 1st node gets a redundant second join request from the second node
ps[0].c.handleJoin(joinRequest{
RemoteHost: ps[1].c.Self.Host,
RemoteID: ps[1].c.Self.ID,
})
tsJoiner, keyJoiner, _ := startNode(t, ctx, controlURL, "node: joiner")
tagNodes(t, control, []key.NodePublic{keyJoiner}, clusterTag)
smJoiner := &fsm{}
cJoiner, err := Start(ctx, tsJoiner, (*fsm)(smJoiner), clusterTag, cfg)
if err != nil {
t.Fatal(err)
}
ps = append(ps, &participant{
sm: smJoiner,
c: cJoiner,
ts: tsJoiner,
key: keyJoiner,
})
assertCommandsWorkOnAnyNode(t, ps)
}