diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go
index 3dcc380b1..51c1a36d2 100644
--- a/tsconsensus/tsconsensus.go
+++ b/tsconsensus/tsconsensus.go
@@ -8,6 +8,7 @@ import (
 	"log"
 	"net"
 	"net/http"
+	"net/netip"
 	"slices"
 	"time"
 
@@ -82,7 +83,8 @@ func DefaultConfig() Config {
 // It does the raft interprocess communication via tailscale.
 type StreamLayer struct {
 	net.Listener
-	s *tsnet.Server
+	s   *tsnet.Server
+	tag string
 }
 
 // Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial.
@@ -91,8 +93,102 @@ func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (n
 	return sl.s.Dial(ctx, "tcp", string(address))
 }
 
+func allowedPeer(remoteAddr string, tag string, s *tsnet.Server) (bool, error) {
+	sAddr, _, err := net.SplitHostPort(remoteAddr)
+	if err != nil {
+		return false, err
+	}
+	a, err := netip.ParseAddr(sAddr)
+	if err != nil {
+		return false, err
+	}
+	ctx := context.Background() // TODO very much a sign I shouldn't be doing this here
+	peers, err := taggedNodesFromStatus(ctx, tag, s)
+	if err != nil {
+		return false, err
+	}
+	return peers.has(a), nil
+}
+
+func (sl StreamLayer) Accept() (net.Conn, error) {
+	for {
+		conn, err := sl.Listener.Accept()
+		if err != nil || conn == nil {
+			return conn, err
+		}
+		allowed, err := allowedPeer(conn.RemoteAddr().String(), sl.tag, sl.s)
+		if err != nil {
+			// TODO should we stay alive here?
+			return nil, err
+		}
+		if !allowed {
+			continue
+		}
+		return conn, err
+	}
+}
+
+type allowedPeers struct {
+	self            *ipnstate.PeerStatus
+	peers           []*ipnstate.PeerStatus
+	peerByIPAddress map[netip.Addr]*ipnstate.PeerStatus
+	clusterTag      string
+}
+
+func (ap *allowedPeers) allowed(n *ipnstate.PeerStatus) bool {
+	return n.Tags != nil && slices.Contains(n.Tags.AsSlice(), ap.clusterTag)
+}
+
+func (ap *allowedPeers) addPeerIfAllowed(p *ipnstate.PeerStatus) {
+	if !ap.allowed(p) {
+		return
+	}
+	ap.peers = append(ap.peers, p)
+	for _, addr := range p.TailscaleIPs {
+		ap.peerByIPAddress[addr] = p
+	}
+}
+
+func (ap *allowedPeers) addSelfIfAllowed(n *ipnstate.PeerStatus) {
+	if ap.allowed(n) {
+		ap.self = n
+	}
+}
+
+func (ap *allowedPeers) has(a netip.Addr) bool {
+	_, ok := ap.peerByIPAddress[a]
+	return ok
+}
+
+func taggedNodesFromStatus(ctx context.Context, clusterTag string, ts *tsnet.Server) (*allowedPeers, error) {
+	lc, err := ts.LocalClient()
+	if err != nil {
+		return nil, err
+	}
+	tStatus, err := lc.Status(ctx)
+	if err != nil {
+		return nil, err
+	}
+	ap := newAllowedPeers(clusterTag)
+	for _, v := range tStatus.Peer {
+		ap.addPeerIfAllowed(v)
+	}
+	ap.addSelfIfAllowed(tStatus.Self)
+	return ap, nil
+}
+
+func newAllowedPeers(tag string) *allowedPeers {
+	return &allowedPeers{
+		peerByIPAddress: map[netip.Addr]*ipnstate.PeerStatus{},
+		clusterTag:      tag,
+	}
+}
+
 // Start returns a pointer to a running Consensus instance.
-func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string, cfg Config) (*Consensus, error) {
+func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag string, cfg Config) (*Consensus, error) {
+	if clusterTag == "" {
+		return nil, errors.New("cluster tag must be provided")
+	}
 	v4, _ := ts.TailscaleIPs()
 	cc := commandClient{
 		port: cfg.CommandPort,
@@ -108,26 +204,15 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string
 		Config: cfg,
 	}
 
-	lc, err := ts.LocalClient()
-	if err != nil {
-		return nil, err
-	}
-	tStatus, err := lc.Status(ctx)
-	if err != nil {
-		return nil, err
-	}
-	var targets []*ipnstate.PeerStatus
-	if targetTag != "" && tStatus.Self.Tags != nil && slices.Contains(tStatus.Self.Tags.AsSlice(), targetTag) {
-		for _, v := range tStatus.Peer {
-			if v.Tags != nil && slices.Contains(v.Tags.AsSlice(), targetTag) {
-				targets = append(targets, v)
-			}
-		}
-	} else {
-		return nil, errors.New("targetTag empty, or this node is not tagged with it")
+	tnfs, err := taggedNodesFromStatus(ctx, clusterTag, ts)
+	if err != nil {
+		return nil, err
+	}
+	if tnfs.self == nil {
+		return nil, errors.New("this node is not tagged with the cluster tag")
 	}
 
-	r, err := startRaft(ts, &fsm, c.Self, cfg)
+	r, err := startRaft(ts, &fsm, c.Self, clusterTag, cfg)
 	if err != nil {
 		return nil, err
 	}
@@ -137,7 +222,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string
 		return nil, err
 	}
 	c.cmdHttpServer = srv
-	c.bootstrap(targets)
+	c.bootstrap(tnfs.peers)
 	srv, err = serveMonitor(&c, ts, addr(c.Self.Host, cfg.MonitorPort))
 	if err != nil {
 		return nil, err
@@ -146,7 +231,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string
 	return &c, nil
 }
 
-func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, cfg Config) (*raft.Raft, error) {
+func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, clusterTag string, cfg Config) (*raft.Raft, error) {
 	config := cfg.Raft
 	config.LocalID = raft.ServerID(self.ID)
 
@@ -164,6 +249,7 @@ func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, cfg Config) (
 	transport := raft.NewNetworkTransport(StreamLayer{
 		s:        ts,
 		Listener: ln,
+		tag:      clusterTag,
 	},
 		cfg.MaxConnPool,
 		cfg.ConnTimeout,
diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go
index d7a669f4f..83b6c62ae 100644
--- a/tsconsensus/tsconsensus_test.go
+++ b/tsconsensus/tsconsensus_test.go
@@ -1,15 +1,18 @@
 package tsconsensus
 
 import (
+	"bufio"
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"log"
-	"net/http"
+	"net"
 	"net/http/httptest"
 	"net/netip"
 	"os"
 	"path/filepath"
+	"slices"
 	"strings"
 	"testing"
 	"time"
@@ -25,6 +28,7 @@ import (
 	"tailscale.com/tstest/nettest"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
+	"tailscale.com/types/views"
 )
 
 type fsm struct {
@@ -119,29 +123,29 @@ func startNode(t *testing.T, ctx context.Context, controlURL, hostname string) (
 	return s, status.Self.PublicKey, status.TailscaleIPs[0]
 }
 
-func pingNode(t *testing.T, control *testcontrol.Server, nodeKey key.NodePublic) {
-	t.Helper()
-	gotPing := make(chan bool, 1)
-	waitPing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		gotPing <- true
-	}))
-	defer waitPing.Close()
-
-	for try := 0; try < 5; try++ {
-		pr := &tailcfg.PingRequest{URL: fmt.Sprintf("%s/ping-%d", waitPing.URL, try), Log: true}
-		if !control.AddPingRequest(nodeKey, pr) {
-			t.Fatalf("failed to AddPingRequest")
+func waitForNodesToBeTaggedInStatus(t *testing.T, ctx context.Context, ts *tsnet.Server, nodeKeys []key.NodePublic, tag string) {
+	waitFor(t, "nodes tagged in status", func() bool {
+		lc, err := ts.LocalClient()
+		if err != nil {
+			t.Fatal(err)
 		}
-		pingTimeout := time.NewTimer(2 * time.Second)
-		defer pingTimeout.Stop()
-		select {
-		case <-gotPing:
-			// ok! the machinery that refreshes the netmap has been nudged
-			return
-		case <-pingTimeout.C:
-			t.Logf("waiting for ping timed out: %d", try)
+		status, err := lc.Status(ctx)
+		if err != nil {
+			t.Fatalf("error getting status: %v", err)
 		}
-	}
+		for _, k := range nodeKeys {
+			var tags *views.Slice[string]
+			if k == status.Self.PublicKey {
+				tags = status.Self.Tags
+			} else {
+				tags = status.Peer[k].Tags
+			}
+			if tags == nil || !slices.Contains(tags.AsSlice(), tag) {
+				return false
+			}
+		}
+		return true
+	}, 5, 1*time.Second)
 }
 
 func tagNodes(t *testing.T, control *testcontrol.Server, nodeKeys []key.NodePublic, tag string) {
@@ -153,13 +157,6 @@ func tagNodes(t *testing.T, control *testcontrol.Server, nodeKeys []key.NodePubl
 		n.Online = &b
 		control.UpdateNode(n)
 	}
-
-	// all this ping stuff is only to prod the netmap to get updated with the tag we just added to the node
-	// ie to actually get the netmap issued to clients that represents the current state of the nodes
-	// there _must_ be a better way to do this, but I looked all day and this was the first thing I found that worked.
-	for _, key := range nodeKeys {
-		pingNode(t, control, key)
-	}
 }
 
 // TODO test start with al lthe config settings
@@ -173,6 +170,7 @@ func TestStart(t *testing.T) {
 	clusterTag := "tag:whatever"
 	// nodes must be tagged with the cluster tag, to find each other
 	tagNodes(t, control, []key.NodePublic{k}, clusterTag)
+	waitForNodesToBeTaggedInStatus(t, ctx, one, []key.NodePublic{k}, clusterTag)
 
 	sm := &fsm{}
 	r, err := Start(ctx, one, (*fsm)(sm), clusterTag, DefaultConfig())
@@ -219,6 +217,7 @@ func startNodesAndWaitForPeerStatus(t *testing.T, ctx context.Context, clusterTa
 		localClients[i] = lc
 	}
 	tagNodes(t, control, keysToTag, clusterTag)
+	waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, keysToTag, clusterTag)
 	fxCameOnline := func() bool {
 		// all the _other_ nodes see the first as online
 		for i := 1; i < nNodes; i++ {
@@ -443,6 +442,7 @@ func TestRejoin(t *testing.T) {
 	tsJoiner, keyJoiner, _ := startNode(t, ctx, controlURL, "node: joiner")
 	tagNodes(t, control, []key.NodePublic{keyJoiner}, clusterTag)
+	waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{keyJoiner}, clusterTag)
 
 	smJoiner := &fsm{}
 	cJoiner, err := Start(ctx, tsJoiner, (*fsm)(smJoiner), clusterTag, cfg)
 	if err != nil {
@@ -457,3 +457,66 @@ func TestRejoin(t *testing.T) {
 
 	assertCommandsWorkOnAnyNode(t, ps)
 }
+
+func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
+	nettest.SkipIfNoNetwork(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	clusterTag := "tag:whatever"
+	ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
+	cfg := DefaultConfig()
+	createConsensusCluster(t, ctx, clusterTag, ps, cfg)
+	for _, p := range ps {
+		defer p.c.Stop(ctx)
+	}
+	assertCommandsWorkOnAnyNode(t, ps)
+
+	untaggedNode, _, _ := startNode(t, ctx, controlURL, "untagged node")
+
+	taggedNode, taggedKey, _ := startNode(t, ctx, controlURL, "tagged node")
+	tagNodes(t, control, []key.NodePublic{taggedKey}, clusterTag)
+	waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{taggedKey}, clusterTag)
+
+	// surface area: command http, peer tcp
+	// untagged
+	ipv4, _ := ps[0].ts.TailscaleIPs()
+	sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort)
+
+	isNetTimeoutErr := func(err error) bool {
+		var netErr net.Error
+		if !errors.As(err, &netErr) {
+			return false
+		}
+		return netErr.Timeout()
+	}
+
+	getErrorFromTryingToSend := func(s *tsnet.Server) error {
+		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+		defer cancel()
+		conn, err := s.Dial(ctx, "tcp", sAddr)
+		if err != nil {
+			t.Fatalf("unexpected Dial err: %v", err)
+		}
+		conn.SetDeadline(time.Now().Add(1 * time.Second))
+		fmt.Fprintf(conn, "hellllllloooooo")
+		status, err := bufio.NewReader(conn).ReadString('\n')
+		if status != "" {
+			t.Fatalf("node sending non-raft message should get empty response, got: '%s' for: %s", status, s.Hostname)
+		}
+		if err == nil {
+			t.Fatalf("node sending non-raft message should get an error but got nil err for: %s", s.Hostname)
+		}
+		return err
+	}
+
+	err := getErrorFromTryingToSend(untaggedNode)
+	if !isNetTimeoutErr(err) {
+		t.Fatalf("untagged node trying to send should time out, got: %v", err)
+	}
+	// We still get an error trying to send, but it's EOF: the target node was happy to talk
+	// to us but couldn't understand what we said.
+	err = getErrorFromTryingToSend(taggedNode)
+	if isNetTimeoutErr(err) {
+		t.Fatalf("tagged node trying to send should not time out, got: %v", err)
+	}
+}