From 95131102df3c800a2328cf0658d5f298085f7afd Mon Sep 17 00:00:00 2001 From: Fran Bull Date: Wed, 12 Feb 2025 11:12:18 -0800 Subject: [PATCH] don't http with non allowed peers either --- tsconsensus/http.go | 27 +++++++++++++++---- tsconsensus/tsconsensus.go | 6 ++--- tsconsensus/tsconsensus_test.go | 48 ++++++++++++++++++++++++++++++++- 3 files changed, 72 insertions(+), 9 deletions(-) diff --git a/tsconsensus/http.go b/tsconsensus/http.go index 4e19f1aac..301127687 100644 --- a/tsconsensus/http.go +++ b/tsconsensus/http.go @@ -9,6 +9,8 @@ import ( "io" "net/http" "time" + + "tailscale.com/tsnet" ) type joinRequest struct { @@ -77,9 +79,24 @@ func (rac *commandClient) ExecuteCommand(host string, bs []byte) (CommandResult, return cr, nil } -func (c *Consensus) makeCommandMux() *http.ServeMux { +func taggedOnly(ts *tsnet.Server, tag string, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + allowed, err := allowedPeer(r.RemoteAddr, tag, ts) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !allowed { + http.Error(w, "peer not allowed", http.StatusBadRequest) + return + } + fx(w, r) + } +} + +func (c *Consensus) makeCommandMux(ts *tsnet.Server, tag string) *http.ServeMux { mux := http.NewServeMux() - mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/join", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Bad Request", http.StatusBadRequest) return @@ -104,8 +121,8 @@ func (c *Consensus) makeCommandMux() *http.ServeMux { http.Error(w, err.Error(), http.StatusInternalServerError) return } - }) - mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) { + })) + mux.HandleFunc("/executeCommand", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Bad Request", http.StatusBadRequest) return @@ -122,6 +139,6 @@ func (c *Consensus) makeCommandMux() *http.ServeMux { http.Error(w, err.Error(), http.StatusInternalServerError) return } - }) + })) return mux } diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go index e9f335bf3..c6fb4f30b 100644 --- a/tsconsensus/tsconsensus.go +++ b/tsconsensus/tsconsensus.go @@ -221,7 +221,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin return nil, err } c.Raft = r - srv, err := c.serveCmdHttp(ts) + srv, err := c.serveCmdHttp(ts, clusterTag) if err != nil { return nil, err } @@ -386,12 +386,12 @@ func (e lookElsewhereError) Error() string { var ErrLeaderUnknown = errors.New("Leader Unknown") -func (c *Consensus) serveCmdHttp(ts *tsnet.Server) (*http.Server, error) { +func (c *Consensus) serveCmdHttp(ts *tsnet.Server, tag string) (*http.Server, error) { ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host)) if err != nil { return nil, err } - mux := c.makeCommandMux() + mux := c.makeCommandMux(ts, tag) srv := &http.Server{Handler: mux} go func() { err := srv.Serve(ln) diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go index 1011a2648..d8c182980 100644 --- a/tsconsensus/tsconsensus_test.go +++ b/tsconsensus/tsconsensus_test.go @@ -2,12 +2,15 @@ package tsconsensus import ( "bufio" + "bytes" "context" + "encoding/json" "errors" "fmt" "io" "log" "net" + "net/http" "net/http/httptest" "net/netip" "os" @@ -585,7 +588,6 @@ func TestOnlyTaggedPeersCanBeDialed(t *testing.T) { } fxOneEventSent := func() bool { - fmt.Println(len(ps[0].sm.events), len(ps[1].sm.events), len(ps[2].sm.events)) return len(ps[0].sm.events) == 4 && len(ps[1].sm.events) == 4 && len(ps[2].sm.events) == 3 } waitFor(t, "after untagging first and second node get events, but third does not", fxOneEventSent, 10, time.Second*1) @@ -603,3 +605,47 @@ func TestOnlyTaggedPeersCanBeDialed(t *testing.T) { } waitFor(t, "after untagging first and second node get events, but third does not", fxTwoEventsSent, 10, time.Second*1) } + +func TestOnlyTaggedPeersCanJoin(t *testing.T) { + nettest.SkipIfNoNetwork(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + clusterTag := "tag:whatever" + ps, _, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := DefaultConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + tsJoiner, _, _ := startNode(t, ctx, controlURL, "joiner node") + + ipv4, _ := tsJoiner.TailscaleIPs() + url := fmt.Sprintf("http://%s/join", ps[0].c.commandAddr(ps[0].c.Self.Host)) + payload, err := json.Marshal(joinRequest{ + RemoteHost: ipv4.String(), + RemoteID: "node joiner", + }) + if err != nil { + t.Fatal(err) + } + body := bytes.NewBuffer(payload) + req, err := http.NewRequest("POST", url, body) + if err != nil { + t.Fatal(err) + } + resp, err := tsJoiner.HTTPClient().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusBadRequest, resp.StatusCode) + } + rBody, _ := io.ReadAll(resp.Body) + sBody := strings.TrimSpace(string(rBody)) + expected := "peer not allowed" + if sBody != expected { + t.Fatalf("join req when not tagged, expected body: %s, got: %s", expected, sBody) + } +}