to squash: restrict communication to tagged nodes

Fran Bull 2025-02-20 15:14:22 -08:00
parent 6ebb0c749d
commit 7c539e3d2f
7 changed files with 396 additions and 75 deletions
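For context, a minimal sketch (not part of this commit) of how a caller might stand up a consensus node once communication is restricted to a tagged cluster. The import path for the package, the hostname, and the no-op FSM are assumptions for illustration; auth key / control setup for the tsnet.Server is omitted.

package main

import (
	"context"
	"errors"
	"io"
	"log"

	"github.com/hashicorp/raft"
	"tailscale.com/tsconsensus" // assumed import path for the package in this diff
	"tailscale.com/tsnet"
)

// noopFSM is a placeholder raft.FSM used only for this sketch.
type noopFSM struct{}

func (*noopFSM) Apply(*raft.Log) any                 { return nil }
func (*noopFSM) Snapshot() (raft.FSMSnapshot, error) { return nil, errors.New("snapshots not supported in this sketch") }
func (*noopFSM) Restore(io.ReadCloser) error         { return nil }

func main() {
	ctx := context.Background()
	ts := &tsnet.Server{Hostname: "consensus-node"}
	defer ts.Close()
	if _, err := ts.Up(ctx); err != nil {
		log.Fatal(err)
	}
	// After this change, only peers tagged with the cluster tag can join the
	// cluster, be dialed for raft traffic, or call the command HTTP endpoints.
	c, err := tsconsensus.Start(ctx, ts, &noopFSM{}, "tag:whatever", tsconsensus.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}
	defer c.Stop(ctx)
}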

go.mod

@@ -127,11 +127,7 @@ require (
github.com/OpenPeeDeeP/depguard/v2 v2.2.0 // indirect
github.com/alecthomas/go-check-sumtype v0.1.4 // indirect
github.com/alexkohler/nakedret/v2 v2.0.4 // indirect
<<<<<<< HEAD
=======
github.com/armon/go-metrics v0.4.1 // indirect
github.com/bits-and-blooms/bitset v1.13.0 // indirect
>>>>>>> 348d01d82 (tsconsensus: add a tsconsensus package)
github.com/bombsimon/wsl/v4 v4.2.1 // indirect
github.com/butuzov/mirror v1.1.0 // indirect
github.com/catenacyber/perfsprint v0.7.1 // indirect

go.sum

@@ -293,14 +293,9 @@ github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0
github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ=
github.com/evanw/esbuild v0.19.11 h1:mbPO1VJ/df//jjUd+p/nRLYCpizXxXb2w/zZMShxa2k=
github.com/evanw/esbuild v0.19.11/go.mod h1:D2vIQZqV/vIf/VRHtViaUtViZmG7o+kKmlBfVQuRi48=
<<<<<<< HEAD
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
=======
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4=
github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI=
>>>>>>> 348d01d82 (tsconsensus: add a tsconsensus package)
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
@@ -536,12 +531,8 @@ github.com/gostaticanalysis/testutil v0.3.1-0.20210208050101-bfb5c8eec0e4/go.mod
github.com/gostaticanalysis/testutil v0.4.0 h1:nhdCmubdmDF6VEatUNjgUZBJKWRqugoISdUv3PPQgHY=
github.com/gostaticanalysis/testutil v0.4.0/go.mod h1:bLIoPefWXrRi/ssLFWX1dx7Repi5x3CuviD3dgAZaBU=
github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo=
<<<<<<< HEAD
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
=======
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg=
github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
github.com/hashicorp/go-hclog v1.6.2 h1:NOtoftovWkDheyUM/8JW3QMiXyxJK3uHRK7wV04nD2I=
github.com/hashicorp/go-hclog v1.6.2/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
@@ -554,7 +545,6 @@ github.com/hashicorp/go-msgpack/v2 v2.1.2 h1:4Ee8FTp834e+ewB71RDrQ0VKpyFdrKOjvYt
github.com/hashicorp/go-msgpack/v2 v2.1.2/go.mod h1:upybraOAblm4S7rx0+jeNy+CWWhzywQsSRV5033mMu4=
github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
>>>>>>> 348d01d82 (tsconsensus: add a tsconsensus package)
github.com/hashicorp/go-version v1.2.1/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=


@@ -0,0 +1,80 @@
package tsconsensus
import (
"context"
"net/netip"
"slices"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tsnet"
)
type authorization struct {
ts *tsnet.Server
tag string
peers *peers
}
func (a *authorization) refresh(ctx context.Context) error {
lc, err := a.ts.LocalClient()
if err != nil {
return err
}
tStatus, err := lc.Status(ctx)
if err != nil {
return err
}
a.peers = newPeers(tStatus)
return nil
}
func (a *authorization) allowsHost(addr netip.Addr) bool {
return a.peers.peerExists(addr, a.tag)
}
func (a *authorization) selfAllowed() bool {
return a.peers.status.Self.Tags != nil && slices.Contains(a.peers.status.Self.Tags.AsSlice(), a.tag)
}
func (a *authorization) allowedPeers() []*ipnstate.PeerStatus {
if a.peers.allowedPeers == nil {
return []*ipnstate.PeerStatus{}
}
return a.peers.allowedPeers
}
type peers struct {
status *ipnstate.Status
peerByIPAddressAndTag map[netip.Addr]map[string]*ipnstate.PeerStatus
allowedPeers []*ipnstate.PeerStatus
}
func (ps *peers) peerExists(a netip.Addr, tag string) bool {
byTag, ok := ps.peerByIPAddressAndTag[a]
if !ok {
return false
}
_, ok = byTag[tag]
return ok
}
func newPeers(status *ipnstate.Status) *peers {
ps := &peers{
peerByIPAddressAndTag: map[netip.Addr]map[string]*ipnstate.PeerStatus{},
status: status,
}
for _, p := range status.Peer {
for _, addr := range p.TailscaleIPs {
if ps.peerByIPAddressAndTag[addr] == nil {
ps.peerByIPAddressAndTag[addr] = map[string]*ipnstate.PeerStatus{}
}
if p.Tags != nil {
for _, tag := range p.Tags.AsSlice() {
ps.peerByIPAddressAndTag[addr][tag] = p
ps.allowedPeers = append(ps.allowedPeers, p)
}
}
}
}
return ps
}
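For orientation, a hypothetical in-package test (not part of this commit; the address and tags below are made up) showing how newPeers indexes peers by Tailscale address and tag, which is what allowsHost relies on:

package tsconsensus

import (
	"net/netip"
	"testing"

	"tailscale.com/ipn/ipnstate"
	"tailscale.com/types/key"
	"tailscale.com/types/views"
)

func TestPeersIndexSketch(t *testing.T) {
	tags := views.SliceOf([]string{"tag:whatever"})
	st := &ipnstate.Status{
		Self: &ipnstate.PeerStatus{Tags: &tags},
		Peer: map[key.NodePublic]*ipnstate.PeerStatus{
			{}: {
				TailscaleIPs: []netip.Addr{netip.MustParseAddr("100.101.102.103")},
				Tags:         &tags,
			},
		},
	}
	ps := newPeers(st)
	if !ps.peerExists(netip.MustParseAddr("100.101.102.103"), "tag:whatever") {
		t.Fatal("address of a peer carrying the tag should be recognized")
	}
	if ps.peerExists(netip.MustParseAddr("100.101.102.103"), "tag:other") {
		t.Fatal("a tag the peer does not carry should not match")
	}
}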


@@ -78,9 +78,30 @@ func (rac *commandClient) executeCommand(host string, bs []byte) (CommandResult,
return cr, nil
}
func (c *Consensus) makeCommandMux() *http.ServeMux {
func authorized(auth *authorization, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
err := auth.refresh(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
a, err := addrFromServerAddress(r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
allowed := auth.allowsHost(a)
if !allowed {
http.Error(w, "peer not allowed", http.StatusBadRequest)
return
}
fx(w, r)
}
}
func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/join", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
@@ -106,8 +127,8 @@ func (c *Consensus) makeCommandMux() *http.ServeMux {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})
mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) {
}))
mux.HandleFunc("/executeCommand", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
@@ -125,6 +146,6 @@ func (c *Consensus) makeCommandMux() *http.ServeMux {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})
}))
return mux
}
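The authorized wrapper is generic over handlers, so any future command endpoint gets the same tag check. A hypothetical in-package sketch (the /health endpoint below does not exist in this commit and is shown only for illustration):

func (c *Consensus) makeCommandMuxWithHealth(auth *authorization) *http.ServeMux {
	mux := c.makeCommandMux(auth)
	mux.HandleFunc("/health", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
		// Only reached when the caller's Tailscale address belongs to a peer
		// carrying the cluster tag; otherwise authorized replies 400.
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("ok"))
	}))
	return mux
}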


@@ -50,7 +50,6 @@ func serveMonitor(c *Consensus, ts *tsnet.Server, listenAddr string) (*http.Serv
mux.HandleFunc("/dial", m.handleDial)
srv := &http.Server{Handler: mux}
go func() {
defer ln.Close()
err := srv.Serve(ln)
log.Printf("MonitorHTTP stopped serving with error: %v", err)
}()


@@ -8,7 +8,7 @@ import (
"log"
"net"
"net/http"
"slices"
"net/netip"
"time"
"github.com/hashicorp/raft"
@@ -47,6 +47,14 @@ func raftAddr(host string, cfg Config) string {
return addr(host, cfg.RaftPort)
}
func addrFromServerAddress(sa string) (netip.Addr, error) {
sAddr, _, err := net.SplitHostPort(sa)
if err != nil {
return netip.Addr{}, err
}
return netip.ParseAddr(sAddr)
}
// A selfRaftNode is the info we need to talk to hashicorp/raft about our node.
// We specify the ID and Addr on Consensus Start, and then use it later for raft
// operations such as BootstrapCluster and AddVoter.
@@ -82,18 +90,61 @@ func DefaultConfig() Config {
// It does the raft interprocess communication via tailscale.
type StreamLayer struct {
net.Listener
s *tsnet.Server
s *tsnet.Server
auth *authorization
}
// Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial.
func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
err := sl.auth.refresh(ctx)
if err != nil {
return nil, err
}
addr, err := addrFromServerAddress(string(address))
if err != nil {
return nil, err
}
if !sl.auth.allowsHost(addr) {
return nil, errors.New("peer is not allowed")
}
return sl.s.Dial(ctx, "tcp", string(address))
}
func (sl StreamLayer) Accept() (net.Conn, error) {
for {
conn, err := sl.Listener.Accept()
if err != nil || conn == nil {
return conn, err
}
ctx := context.Background() // TODO
err = sl.auth.refresh(ctx)
if err != nil {
// TODO should we stay alive here?
return nil, err
}
addr, err := addrFromServerAddress(conn.RemoteAddr().String())
if err != nil {
// TODO should we stay alive here?
return nil, err
}
if !sl.auth.allowsHost(addr) {
continue
}
return conn, err
}
}
// Start returns a pointer to a running Consensus instance.
func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string, cfg Config) (*Consensus, error) {
func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag string, cfg Config) (*Consensus, error) {
if clusterTag == "" {
return nil, errors.New("cluster tag must be provided")
}
v4, _ := ts.TailscaleIPs()
cc := commandClient{
port: cfg.CommandPort,
@@ -109,36 +160,29 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string
config: cfg,
}
lc, err := ts.LocalClient()
auth := &authorization{
tag: clusterTag,
ts: ts,
}
err := auth.refresh(ctx)
if err != nil {
return nil, err
}
tStatus, err := lc.Status(ctx)
if err != nil {
return nil, err
}
var targets []*ipnstate.PeerStatus
if targetTag != "" && tStatus.Self.Tags != nil && slices.Contains(tStatus.Self.Tags.AsSlice(), targetTag) {
for _, v := range tStatus.Peer {
if v.Tags != nil && slices.Contains(v.Tags.AsSlice(), targetTag) {
targets = append(targets, v)
}
}
} else {
return nil, errors.New("targetTag empty, or this node is not tagged with it")
if !auth.selfAllowed() {
return nil, errors.New("this node is not tagged with the cluster tag")
}
r, err := startRaft(ts, &fsm, c.self, cfg)
r, err := startRaft(ts, &fsm, c.self, auth, cfg)
if err != nil {
return nil, err
}
c.raft = r
srv, err := c.serveCmdHttp(ts)
srv, err := c.serveCmdHttp(ts, auth)
if err != nil {
return nil, err
}
c.cmdHttpServer = srv
c.bootstrap(targets)
c.bootstrap(auth.allowedPeers())
srv, err = serveMonitor(&c, ts, addr(c.self.host, cfg.MonitorPort))
if err != nil {
return nil, err
@@ -147,7 +191,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string
return &c, nil
}
func startRaft(ts *tsnet.Server, fsm *raft.FSM, self selfRaftNode, cfg Config) (*raft.Raft, error) {
func startRaft(ts *tsnet.Server, fsm *raft.FSM, self selfRaftNode, auth *authorization, cfg Config) (*raft.Raft, error) {
config := cfg.Raft
config.LocalID = raft.ServerID(self.id)
@@ -165,6 +209,7 @@ func startRaft(ts *tsnet.Server, fsm *raft.FSM, self selfRaftNode, cfg Config) (
transport := raft.NewNetworkTransport(StreamLayer{
s: ts,
Listener: ln,
auth: auth,
},
cfg.MaxConnPool,
cfg.ConnTimeout,
@@ -297,15 +342,14 @@ func (e lookElsewhereError) Error() string {
var errLeaderUnknown = errors.New("Leader Unknown")
func (c *Consensus) serveCmdHttp(ts *tsnet.Server) (*http.Server, error) {
func (c *Consensus) serveCmdHttp(ts *tsnet.Server, auth *authorization) (*http.Server, error) {
ln, err := ts.Listen("tcp", c.commandAddr(c.self.host))
if err != nil {
return nil, err
}
mux := c.makeCommandMux()
mux := c.makeCommandMux(auth)
srv := &http.Server{Handler: mux}
go func() {
defer ln.Close()
err := srv.Serve(ln)
log.Printf("CmdHttp stopped serving with err: %v", err)
}()


@@ -1,10 +1,15 @@
package tsconsensus
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"net/netip"
@@ -25,6 +30,7 @@ import (
"tailscale.com/tstest/nettest"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/views"
)
type fsm struct {
@@ -112,48 +118,62 @@ func startNode(t *testing.T, ctx context.Context, controlURL, hostname string) (
return s, status.Self.PublicKey, status.TailscaleIPs[0]
}
// pingNode sends a tailscale ping between two nodes. That's not really relevant here, but
// doing so has a side effect of causing the testcontrol.Server to recalculate and reissue
// netmaps.
func pingNode(t *testing.T, control *testcontrol.Server, nodeKey key.NodePublic) {
t.Helper()
gotPing := make(chan bool, 1)
waitPing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPing <- true
}))
defer waitPing.Close()
for try := 0; try < 5; try++ {
pr := &tailcfg.PingRequest{URL: fmt.Sprintf("%s/ping-%d", waitPing.URL, try), Log: true}
if !control.AddPingRequest(nodeKey, pr) {
t.Fatalf("failed to AddPingRequest")
func waitForNodesToBeTaggedInStatus(t *testing.T, ctx context.Context, ts *tsnet.Server, nodeKeys []key.NodePublic, tag string) {
waitFor(t, "nodes tagged in status", func() bool {
lc, err := ts.LocalClient()
if err != nil {
t.Fatal(err)
}
pingTimeout := time.NewTimer(2 * time.Second)
defer pingTimeout.Stop()
select {
case <-gotPing:
// ok! the machinery that refreshes the netmap has been nudged
return
case <-pingTimeout.C:
t.Logf("waiting for ping timed out: %d", try)
status, err := lc.Status(ctx)
if err != nil {
t.Fatalf("error getting status: %v", err)
}
}
for _, k := range nodeKeys {
var tags *views.Slice[string]
if k == status.Self.PublicKey {
tags = status.Self.Tags
} else {
tags = status.Peer[k].Tags
}
if tag == "" {
if tags != nil && tags.Len() != 0 {
return false
}
} else {
if tags == nil {
return false
}
sliceTags := tags.AsSlice()
if len(sliceTags) != 1 || sliceTags[0] != tag {
return false
}
}
}
return true
}, 10, 2*time.Second)
}
func tagNodes(t *testing.T, control *testcontrol.Server, nodeKeys []key.NodePublic, tag string) {
t.Helper()
for _, key := range nodeKeys {
n := control.Node(key)
n.Tags = append(n.Tags, tag)
if tag == "" {
if len(n.Tags) != 1 {
t.Fatalf("expected tags to have one tag")
}
n.Tags = nil
} else {
if len(n.Tags) != 0 {
// if we want this to work with multiple tags we'll have to change the logic
// for checking if a tag got removed yet.
t.Fatalf("expected tags to be empty")
}
n.Tags = append(n.Tags, tag)
}
b := true
n.Online = &b
control.UpdateNode(n)
}
// Cause the netmap to be recalculated and reissued, so we don't have to wait for it.
for _, key := range nodeKeys {
pingNode(t, control, key)
}
}
func TestStart(t *testing.T) {
@@ -166,6 +186,7 @@ func TestStart(t *testing.T) {
clusterTag := "tag:whatever"
// nodes must be tagged with the cluster tag, to find each other
tagNodes(t, control, []key.NodePublic{k}, clusterTag)
waitForNodesToBeTaggedInStatus(t, ctx, one, []key.NodePublic{k}, clusterTag)
sm := &fsm{}
r, err := Start(ctx, one, (*fsm)(sm), clusterTag, DefaultConfig())
@@ -212,6 +233,7 @@ func startNodesAndWaitForPeerStatus(t *testing.T, ctx context.Context, clusterTa
localClients[i] = lc
}
tagNodes(t, control, keysToTag, clusterTag)
waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, keysToTag, clusterTag)
fxCameOnline := func() bool {
// all the _other_ nodes see the first as online
for i := 1; i < nNodes; i++ {
@@ -437,6 +459,7 @@ func TestRejoin(t *testing.T) {
tsJoiner, keyJoiner, _ := startNode(t, ctx, controlURL, "node: joiner")
tagNodes(t, control, []key.NodePublic{keyJoiner}, clusterTag)
waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{keyJoiner}, clusterTag)
smJoiner := &fsm{}
cJoiner, err := Start(ctx, tsJoiner, (*fsm)(smJoiner), clusterTag, cfg)
if err != nil {
@@ -451,3 +474,171 @@ func TestRejoin(t *testing.T) {
assertCommandsWorkOnAnyNode(t, ps)
}
func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
nettest.SkipIfNoNetwork(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
clusterTag := "tag:whatever"
ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
cfg := DefaultConfig()
createConsensusCluster(t, ctx, clusterTag, ps, cfg)
for _, p := range ps {
defer p.c.Stop(ctx)
}
assertCommandsWorkOnAnyNode(t, ps)
untaggedNode, _, _ := startNode(t, ctx, controlURL, "untagged node")
taggedNode, taggedKey, _ := startNode(t, ctx, controlURL, "tagged node")
tagNodes(t, control, []key.NodePublic{taggedKey}, clusterTag)
waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{taggedKey}, clusterTag)
// surface area: command http, peer tcp
// untagged
ipv4, _ := ps[0].ts.TailscaleIPs()
sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort)
isNetTimeoutErr := func(err error) bool {
var netErr net.Error
if !errors.As(err, &netErr) {
return false
}
return netErr.Timeout()
}
getErrorFromTryingToSend := func(s *tsnet.Server) error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
conn, err := s.Dial(ctx, "tcp", sAddr)
if err != nil {
t.Fatalf("unexpected Dial err: %v", err)
}
conn.SetDeadline(time.Now().Add(1 * time.Second))
fmt.Fprintf(conn, "hellllllloooooo")
status, err := bufio.NewReader(conn).ReadString('\n')
if status != "" {
t.Fatalf("node sending non-raft message should get empty response, got: '%s' for: %s", status, s.Hostname)
}
if err == nil {
t.Fatalf("node sending non-raft message should get an error but got nil err for: %s", s.Hostname)
}
return err
}
err := getErrorFromTryingToSend(untaggedNode)
if !isNetTimeoutErr(err) {
t.Fatalf("untagged node trying to send should time out, got: %v", err)
}
// we still get an error trying to send, but it's EOF: the target node was happy to talk
// to us but couldn't understand what we said.
err = getErrorFromTryingToSend(taggedNode)
if isNetTimeoutErr(err) {
t.Fatalf("tagged node trying to send should not time out, got: %v", err)
}
}
func TestOnlyTaggedPeersCanBeDialed(t *testing.T) {
t.Skip("flaky test, need to figure out how to actually cause a Dial if we want to test this")
nettest.SkipIfNoNetwork(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
clusterTag := "tag:whatever"
ps, control, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
cfg := DefaultConfig()
createConsensusCluster(t, ctx, clusterTag, ps, cfg)
for _, p := range ps {
defer p.c.Stop(ctx)
}
assertCommandsWorkOnAnyNode(t, ps)
tagNodes(t, control, []key.NodePublic{ps[2].key}, "")
waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{ps[2].key}, "")
// now when we try to communicate there's still an open conn we can talk over, but
// we won't dial a fresh one
// get Raft to redial by removing and readding
// TODO although this doesn't actually cause redialing apparently, at least not for the command rpc stuff.
fut := ps[0].c.raft.RemoveServer(raft.ServerID(ps[2].c.self.id), 0, 5*time.Second)
err := fut.Error()
if err != nil {
t.Fatal(err)
}
fut = ps[0].c.raft.AddVoter(raft.ServerID(ps[2].c.self.id), raft.ServerAddress(raftAddr(ps[2].c.self.host, cfg)), 0, 5*time.Second)
err = fut.Error()
if err != nil {
t.Fatal(err)
}
// ps[2] doesn't get updates any more
res, err := ps[0].c.ExecuteCommand(Command{Args: []byte{byte(1)}})
if err != nil {
t.Fatalf("Error ExecuteCommand: %v", err)
}
if res.Err != nil {
t.Fatalf("Result Error ExecuteCommand: %v", res.Err)
}
fxOneEventSent := func() bool {
return len(ps[0].sm.events) == 4 && len(ps[1].sm.events) == 4 && len(ps[2].sm.events) == 3
}
waitFor(t, "after untagging, first and second nodes get events, but third does not", fxOneEventSent, 10, time.Second*1)
res, err = ps[1].c.ExecuteCommand(Command{Args: []byte{byte(1)}})
if err != nil {
t.Fatalf("Error ExecuteCommand: %v", err)
}
if res.Err != nil {
t.Fatalf("Result Error ExecuteCommand: %v", res.Err)
}
fxTwoEventsSent := func() bool {
return len(ps[0].sm.events) == 5 && len(ps[1].sm.events) == 5 && len(ps[2].sm.events) == 3
}
waitFor(t, "after untagging, first and second nodes get events, but third does not", fxTwoEventsSent, 10, time.Second*1)
}
func TestOnlyTaggedPeersCanJoin(t *testing.T) {
nettest.SkipIfNoNetwork(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
clusterTag := "tag:whatever"
ps, _, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
cfg := DefaultConfig()
createConsensusCluster(t, ctx, clusterTag, ps, cfg)
for _, p := range ps {
defer p.c.Stop(ctx)
}
tsJoiner, _, _ := startNode(t, ctx, controlURL, "joiner node")
ipv4, _ := tsJoiner.TailscaleIPs()
url := fmt.Sprintf("http://%s/join", ps[0].c.commandAddr(ps[0].c.self.host))
payload, err := json.Marshal(joinRequest{
RemoteHost: ipv4.String(),
RemoteID: "node joiner",
})
if err != nil {
t.Fatal(err)
}
body := bytes.NewBuffer(payload)
req, err := http.NewRequest("POST", url, body)
if err != nil {
t.Fatal(err)
}
resp, err := tsJoiner.HTTPClient().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusBadRequest, resp.StatusCode)
}
rBody, _ := io.ReadAll(resp.Body)
sBody := strings.TrimSpace(string(rBody))
expected := "peer not allowed"
if sBody != expected {
t.Fatalf("join req when not tagged, expected body: %s, got: %s", expected, sBody)
}
}