don't http with non allowed peers either

This commit is contained in:
Fran Bull 2025-02-12 11:12:18 -08:00
parent 027b2ca840
commit 95131102df
3 changed files with 72 additions and 9 deletions

View File

@ -9,6 +9,8 @@ import (
"io" "io"
"net/http" "net/http"
"time" "time"
"tailscale.com/tsnet"
) )
type joinRequest struct { type joinRequest struct {
@ -77,9 +79,24 @@ func (rac *commandClient) ExecuteCommand(host string, bs []byte) (CommandResult,
return cr, nil return cr, nil
} }
func (c *Consensus) makeCommandMux() *http.ServeMux { func taggedOnly(ts *tsnet.Server, tag string, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
allowed, err := allowedPeer(r.RemoteAddr, tag, ts)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if !allowed {
http.Error(w, "peer not allowed", http.StatusBadRequest)
return
}
fx(w, r)
}
}
func (c *Consensus) makeCommandMux(ts *tsnet.Server, tag string) *http.ServeMux {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/join", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest) http.Error(w, "Bad Request", http.StatusBadRequest)
return return
@ -104,8 +121,8 @@ func (c *Consensus) makeCommandMux() *http.ServeMux {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
}) }))
mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/executeCommand", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest) http.Error(w, "Bad Request", http.StatusBadRequest)
return return
@ -122,6 +139,6 @@ func (c *Consensus) makeCommandMux() *http.ServeMux {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
}) }))
return mux return mux
} }

View File

@ -221,7 +221,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin
return nil, err return nil, err
} }
c.Raft = r c.Raft = r
srv, err := c.serveCmdHttp(ts) srv, err := c.serveCmdHttp(ts, clusterTag)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -386,12 +386,12 @@ func (e lookElsewhereError) Error() string {
var ErrLeaderUnknown = errors.New("Leader Unknown") var ErrLeaderUnknown = errors.New("Leader Unknown")
func (c *Consensus) serveCmdHttp(ts *tsnet.Server) (*http.Server, error) { func (c *Consensus) serveCmdHttp(ts *tsnet.Server, tag string) (*http.Server, error) {
ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host)) ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host))
if err != nil { if err != nil {
return nil, err return nil, err
} }
mux := c.makeCommandMux() mux := c.makeCommandMux(ts, tag)
srv := &http.Server{Handler: mux} srv := &http.Server{Handler: mux}
go func() { go func() {
err := srv.Serve(ln) err := srv.Serve(ln)

View File

@ -2,12 +2,15 @@ package tsconsensus
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
"os" "os"
@ -585,7 +588,6 @@ func TestOnlyTaggedPeersCanBeDialed(t *testing.T) {
} }
fxOneEventSent := func() bool { fxOneEventSent := func() bool {
fmt.Println(len(ps[0].sm.events), len(ps[1].sm.events), len(ps[2].sm.events))
return len(ps[0].sm.events) == 4 && len(ps[1].sm.events) == 4 && len(ps[2].sm.events) == 3 return len(ps[0].sm.events) == 4 && len(ps[1].sm.events) == 4 && len(ps[2].sm.events) == 3
} }
waitFor(t, "after untagging first and second node get events, but third does not", fxOneEventSent, 10, time.Second*1) waitFor(t, "after untagging first and second node get events, but third does not", fxOneEventSent, 10, time.Second*1)
@ -603,3 +605,47 @@ func TestOnlyTaggedPeersCanBeDialed(t *testing.T) {
} }
waitFor(t, "after untagging first and second node get events, but third does not", fxTwoEventsSent, 10, time.Second*1) waitFor(t, "after untagging first and second node get events, but third does not", fxTwoEventsSent, 10, time.Second*1)
} }
func TestOnlyTaggedPeersCanJoin(t *testing.T) {
nettest.SkipIfNoNetwork(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
clusterTag := "tag:whatever"
ps, _, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3)
cfg := DefaultConfig()
createConsensusCluster(t, ctx, clusterTag, ps, cfg)
for _, p := range ps {
defer p.c.Stop(ctx)
}
tsJoiner, _, _ := startNode(t, ctx, controlURL, "joiner node")
ipv4, _ := tsJoiner.TailscaleIPs()
url := fmt.Sprintf("http://%s/join", ps[0].c.commandAddr(ps[0].c.Self.Host))
payload, err := json.Marshal(joinRequest{
RemoteHost: ipv4.String(),
RemoteID: "node joiner",
})
if err != nil {
t.Fatal(err)
}
body := bytes.NewBuffer(payload)
req, err := http.NewRequest("POST", url, body)
if err != nil {
t.Fatal(err)
}
resp, err := tsJoiner.HTTPClient().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusBadRequest, resp.StatusCode)
}
rBody, _ := io.ReadAll(resp.Body)
sBody := strings.TrimSpace(string(rBody))
expected := "peer not allowed"
if sBody != expected {
t.Fatalf("join req when not tagged, expected body: %s, got: %s", expected, sBody)
}
}