Mirror of https://github.com/tailscale/tailscale.git, synced 2025-08-05 02:16:27 +00:00
don't accept conns from untagged peers

StreamLayer.Accept now drops connections from peers that are not tagged with the cluster tag, Start requires a cluster tag and derives its peer list from tagged nodes in the tailnet status, and a new test verifies that only tagged peers can dial the raft port.
@@ -8,6 +8,7 @@ import (
 	"log"
 	"net"
 	"net/http"
+	"net/netip"
 	"slices"
 	"time"
 
@@ -82,7 +83,8 @@ func DefaultConfig() Config {
 // It does the raft interprocess communication via tailscale.
 type StreamLayer struct {
 	net.Listener
 	s *tsnet.Server
+	tag string
 }
 
 // Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial.
@@ -91,8 +93,102 @@ func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (n
 	return sl.s.Dial(ctx, "tcp", string(address))
 }
 
+func allowedPeer(remoteAddr string, tag string, s *tsnet.Server) (bool, error) {
+	sAddr, _, err := net.SplitHostPort(remoteAddr)
+	if err != nil {
+		return false, err
+	}
+	a, err := netip.ParseAddr(sAddr)
+	if err != nil {
+		return false, err
+	}
+	ctx := context.Background() // TODO very much a sign I shouldn't be doing this here
+	peers, err := taggedNodesFromStatus(ctx, tag, s)
+	if err != nil {
+		return false, err
+	}
+	return peers.has(a), nil
+}
+
+func (sl StreamLayer) Accept() (net.Conn, error) {
+	for {
+		conn, err := sl.Listener.Accept()
+		if err != nil || conn == nil {
+			return conn, err
+		}
+		allowed, err := allowedPeer(conn.RemoteAddr().String(), sl.tag, sl.s)
+		if err != nil {
+			// TODO should we stay alive here?
+			return nil, err
+		}
+		if !allowed {
+			continue
+		}
+		return conn, err
+	}
+}
+
+type allowedPeers struct {
+	self *ipnstate.PeerStatus
+	peers []*ipnstate.PeerStatus
+	peerByIPAddress map[netip.Addr]*ipnstate.PeerStatus
+	clusterTag string
+}
+
+func (ap *allowedPeers) allowed(n *ipnstate.PeerStatus) bool {
+	return n.Tags != nil && slices.Contains(n.Tags.AsSlice(), ap.clusterTag)
+}
+
+func (ap *allowedPeers) addPeerIfAllowed(p *ipnstate.PeerStatus) {
+	if !ap.allowed(p) {
+		return
+	}
+	ap.peers = append(ap.peers, p)
+	for _, addr := range p.TailscaleIPs {
+		ap.peerByIPAddress[addr] = p
+	}
+}
+
+func (ap *allowedPeers) addSelfIfAllowed(n *ipnstate.PeerStatus) {
+	if ap.allowed(n) {
+		ap.self = n
+	}
+}
+
+func (ap *allowedPeers) has(a netip.Addr) bool {
+	_, ok := ap.peerByIPAddress[a]
+	return ok
+}
+
+func taggedNodesFromStatus(ctx context.Context, clusterTag string, ts *tsnet.Server) (*allowedPeers, error) {
+	lc, err := ts.LocalClient()
+	if err != nil {
+		return nil, err
+	}
+	tStatus, err := lc.Status(ctx)
+	if err != nil {
+		return nil, err
+	}
+	ap := newAllowedPeers(clusterTag)
+	for _, v := range tStatus.Peer {
+		ap.addPeerIfAllowed(v)
+	}
+	ap.addSelfIfAllowed(tStatus.Self)
+	return ap, nil
+}
+
+func newAllowedPeers(tag string) *allowedPeers {
+	return &allowedPeers{
+		peerByIPAddress: map[netip.Addr]*ipnstate.PeerStatus{},
+		clusterTag: tag,
+	}
+}
+
 // Start returns a pointer to a running Consensus instance.
-func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string, cfg Config) (*Consensus, error) {
+func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag string, cfg Config) (*Consensus, error) {
+	if clusterTag == "" {
+		return nil, errors.New("cluster tag must be provided")
+	}
 	v4, _ := ts.TailscaleIPs()
 	cc := commandClient{
 		port: cfg.CommandPort,
@@ -108,26 +204,12 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string
 		Config: cfg,
 	}
 
-	lc, err := ts.LocalClient()
-	if err != nil {
-		return nil, err
-	}
-	tStatus, err := lc.Status(ctx)
-	if err != nil {
-		return nil, err
-	}
-	var targets []*ipnstate.PeerStatus
-	if targetTag != "" && tStatus.Self.Tags != nil && slices.Contains(tStatus.Self.Tags.AsSlice(), targetTag) {
-		for _, v := range tStatus.Peer {
-			if v.Tags != nil && slices.Contains(v.Tags.AsSlice(), targetTag) {
-				targets = append(targets, v)
-			}
-		}
-	} else {
-		return nil, errors.New("targetTag empty, or this node is not tagged with it")
-	}
+	tnfs, err := taggedNodesFromStatus(ctx, clusterTag, ts)
+	if tnfs.self == nil {
+		return nil, errors.New("this node is not tagged with the cluster tag")
 	}
 
-	r, err := startRaft(ts, &fsm, c.Self, cfg)
+	r, err := startRaft(ts, &fsm, c.Self, clusterTag, cfg)
 	if err != nil {
 		return nil, err
 	}
@@ -137,7 +219,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string
 		return nil, err
 	}
 	c.cmdHttpServer = srv
-	c.bootstrap(targets)
+	c.bootstrap(tnfs.peers)
 	srv, err = serveMonitor(&c, ts, addr(c.Self.Host, cfg.MonitorPort))
 	if err != nil {
 		return nil, err
@@ -146,7 +228,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, targetTag string
 	return &c, nil
 }
 
-func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, cfg Config) (*raft.Raft, error) {
+func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, clusterTag string, cfg Config) (*raft.Raft, error) {
 	config := cfg.Raft
 	config.LocalID = raft.ServerID(self.ID)
 
@@ -164,6 +246,7 @@ func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, cfg Config) (
 	transport := raft.NewNetworkTransport(StreamLayer{
 		s: ts,
 		Listener: ln,
+		tag: clusterTag,
 	},
 		cfg.MaxConnPool,
 		cfg.ConnTimeout,
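The effect of the new Accept path is that a connection from a peer without the cluster tag is never handed to raft; the dialer is simply left hanging until it hits its own deadline, which is what the new test below asserts. Here is a minimal, self-contained sketch of that accept-side filtering over plain net (no tsnet); the isAllowed callback is a hypothetical stand-in for the allowedPeer tag check, and the rest is illustration only.

package main

import (
	"fmt"
	"log"
	"net"
	"time"
)

// acceptFiltered mirrors the shape of StreamLayer.Accept above: connections
// whose remote address fails the check are skipped rather than returned to
// the caller, so a rejected dialer blocks until its own deadline expires.
func acceptFiltered(ln net.Listener, isAllowed func(remoteAddr string) bool) (net.Conn, error) {
	for {
		conn, err := ln.Accept()
		if err != nil {
			return nil, err
		}
		if !isAllowed(conn.RemoteAddr().String()) {
			continue // ignore the connection and keep accepting
		}
		return conn, nil
	}
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()

	// Reject everything: a stand-in for "peer is not tagged with the cluster tag".
	go acceptFiltered(ln, func(string) bool { return false })

	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	conn.SetReadDeadline(time.Now().Add(1 * time.Second))
	buf := make([]byte, 1)
	_, err = conn.Read(buf)
	fmt.Println("read error:", err) // an i/o timeout, never application data
}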
@@ -1,15 +1,18 @@
 package tsconsensus
 
 import (
+	"bufio"
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"log"
-	"net/http"
+	"net"
 	"net/http/httptest"
 	"net/netip"
 	"os"
 	"path/filepath"
+	"slices"
 	"strings"
 	"testing"
 	"time"
@@ -25,6 +28,7 @@ import (
 	"tailscale.com/tstest/nettest"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
+	"tailscale.com/types/views"
 )
 
 type fsm struct {
@@ -119,29 +123,29 @@ func startNode(t *testing.T, ctx context.Context, controlURL, hostname string) (
 	return s, status.Self.PublicKey, status.TailscaleIPs[0]
 }
 
-func pingNode(t *testing.T, control *testcontrol.Server, nodeKey key.NodePublic) {
-	t.Helper()
-	gotPing := make(chan bool, 1)
-	waitPing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		gotPing <- true
-	}))
-	defer waitPing.Close()
-
-	for try := 0; try < 5; try++ {
-		pr := &tailcfg.PingRequest{URL: fmt.Sprintf("%s/ping-%d", waitPing.URL, try), Log: true}
-		if !control.AddPingRequest(nodeKey, pr) {
-			t.Fatalf("failed to AddPingRequest")
+func waitForNodesToBeTaggedInStatus(t *testing.T, ctx context.Context, ts *tsnet.Server, nodeKeys []key.NodePublic, tag string) {
+	waitFor(t, "nodes tagged in status", func() bool {
+		lc, err := ts.LocalClient()
+		if err != nil {
+			t.Fatal(err)
 		}
-		pingTimeout := time.NewTimer(2 * time.Second)
-		defer pingTimeout.Stop()
-		select {
-		case <-gotPing:
-			// ok! the machinery that refreshes the netmap has been nudged
-			return
-		case <-pingTimeout.C:
-			t.Logf("waiting for ping timed out: %d", try)
+		status, err := lc.Status(ctx)
+		if err != nil {
+			t.Fatalf("error getting status: %v", err)
 		}
-	}
+		for _, k := range nodeKeys {
+			var tags *views.Slice[string]
+			if k == status.Self.PublicKey {
+				tags = status.Self.Tags
+			} else {
+				tags = status.Peer[k].Tags
+			}
+			if tags == nil || !slices.Contains(tags.AsSlice(), tag) {
+				return false
+			}
+		}
+		return true
+	}, 5, 1*time.Second)
 }
 
 func tagNodes(t *testing.T, control *testcontrol.Server, nodeKeys []key.NodePublic, tag string) {
@@ -153,13 +157,6 @@ func tagNodes(t *testing.T, control *testcontrol.Server, nodeKeys []key.NodePubl
 		n.Online = &b
 		control.UpdateNode(n)
 	}
-
-	// all this ping stuff is only to prod the netmap to get updated with the tag we just added to the node
-	// ie to actually get the netmap issued to clients that represents the current state of the nodes
-	// there _must_ be a better way to do this, but I looked all day and this was the first thing I found that worked.
-	for _, key := range nodeKeys {
-		pingNode(t, control, key)
-	}
 }
 
 // TODO test start with al lthe config settings
@@ -173,6 +170,7 @@ func TestStart(t *testing.T) {
 	clusterTag := "tag:whatever"
 	// nodes must be tagged with the cluster tag, to find each other
 	tagNodes(t, control, []key.NodePublic{k}, clusterTag)
+	waitForNodesToBeTaggedInStatus(t, ctx, one, []key.NodePublic{k}, clusterTag)
 
 	sm := &fsm{}
 	r, err := Start(ctx, one, (*fsm)(sm), clusterTag, DefaultConfig())
@@ -219,6 +217,7 @@ func startNodesAndWaitForPeerStatus(t *testing.T, ctx context.Context, clusterTa
 		localClients[i] = lc
 	}
 	tagNodes(t, control, keysToTag, clusterTag)
+	waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, keysToTag, clusterTag)
 	fxCameOnline := func() bool {
 		// all the _other_ nodes see the first as online
 		for i := 1; i < nNodes; i++ {
@@ -443,6 +442,7 @@ func TestRejoin(t *testing.T) {
 
 	tsJoiner, keyJoiner, _ := startNode(t, ctx, controlURL, "node: joiner")
 	tagNodes(t, control, []key.NodePublic{keyJoiner}, clusterTag)
+	waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{keyJoiner}, clusterTag)
 	smJoiner := &fsm{}
 	cJoiner, err := Start(ctx, tsJoiner, (*fsm)(smJoiner), clusterTag, cfg)
 	if err != nil {
@@ -457,3 +457,66 @@ func TestRejoin(t *testing.T) {
 
 	assertCommandsWorkOnAnyNode(t, ps)
 }
+
+func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
+	nettest.SkipIfNoNetwork(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+	clusterTag := "tag:whatever"
+	ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
+	cfg := DefaultConfig()
+	createConsensusCluster(t, ctx, clusterTag, ps, cfg)
+	for _, p := range ps {
+		defer p.c.Stop(ctx)
+	}
+	assertCommandsWorkOnAnyNode(t, ps)
+
+	untaggedNode, _, _ := startNode(t, ctx, controlURL, "untagged node")
+
+	taggedNode, taggedKey, _ := startNode(t, ctx, controlURL, "untagged node")
+	tagNodes(t, control, []key.NodePublic{taggedKey}, clusterTag)
+	waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{taggedKey}, clusterTag)
+
+	// surface area: command http, peer tcp
+	//untagged
+	ipv4, _ := ps[0].ts.TailscaleIPs()
+	sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort)
+
+	isNetTimeoutErr := func(err error) bool {
+		var netErr net.Error
+		if !errors.As(err, &netErr) {
+			return false
+		}
+		return netErr.Timeout()
+	}
+
+	getErrorFromTryingToSend := func(s *tsnet.Server) error {
+		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+		defer cancel()
+		conn, err := s.Dial(ctx, "tcp", sAddr)
+		if err != nil {
+			t.Fatalf("unexpected Dial err: %v", err)
+		}
+		conn.SetDeadline(time.Now().Add(1 * time.Second))
+		fmt.Fprintf(conn, "hellllllloooooo")
+		status, err := bufio.NewReader(conn).ReadString('\n')
+		if status != "" {
+			t.Fatalf("node sending non-raft message should get empty response, got: '%s' for: %s", status, s.Hostname)
+		}
+		if err == nil {
+			t.Fatalf("node sending non-raft message should get an error but got nil err for: %s", s.Hostname)
+		}
+		return err
+	}
+
+	err := getErrorFromTryingToSend(untaggedNode)
+	if !isNetTimeoutErr(err) {
+		t.Fatalf("untagged node trying to send should time out, got: %v", err)
+	}
+	// we still get an error trying to send but it's EOF the target node was happy to talk
+	// to us but couldn't understand what we said.
+	err = getErrorFromTryingToSend(taggedNode)
+	if isNetTimeoutErr(err) {
+		t.Fatalf("tagged node trying to send should not time out, got: %v", err)
+	}
+}
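For context, a minimal sketch of how a caller might use Start after this change. The import path tailscale.com/tsconsensus, the hostname, the placeholder FSM, and the tag value are assumptions for illustration; only Start, Stop, DefaultConfig, and the cluster-tag requirement come from the diff above, and the tsnet node is assumed to already be authorized on the tailnet and tagged with the cluster tag (e.g. via an auth key).

package main

import (
	"context"
	"errors"
	"io"
	"log"

	"github.com/hashicorp/raft"

	"tailscale.com/tsconsensus"
	"tailscale.com/tsnet"
)

// placeholderFSM is an illustrative no-op raft.FSM; a real caller supplies its own state machine.
type placeholderFSM struct{}

func (placeholderFSM) Apply(*raft.Log) any                 { return nil }
func (placeholderFSM) Snapshot() (raft.FSMSnapshot, error) { return nil, errors.New("not implemented") }
func (placeholderFSM) Restore(io.ReadCloser) error         { return nil }

func main() {
	ctx := context.Background()
	ts := &tsnet.Server{Hostname: "consensus-node"} // assumed to join the tailnet already tagged with the cluster tag
	defer ts.Close()

	// Start now requires a non-empty cluster tag and refuses to run on a node
	// that is not itself tagged with it; only tagged peers are dialed or accepted.
	c, err := tsconsensus.Start(ctx, ts, placeholderFSM{}, "tag:whatever", tsconsensus.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}
	defer c.Stop(ctx)

	// ... issue commands via the consensus instance here ...
}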