centralize cmd http auth

This commit is contained in:
Fran Bull 2025-02-27 05:15:49 -08:00
parent 5afa742b06
commit ae30f58b46
3 changed files with 40 additions and 28 deletions

View File

@ -9,6 +9,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
@ -83,30 +84,35 @@ func (rac *commandClient) executeCommand(host string, bs []byte) (CommandResult,
return cr, nil
}
func authorized(auth *authorization, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
err := auth.refresh(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
a, err := addrFromServerAddress(r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
allowed := auth.allowsHost(a)
if !allowed {
http.Error(w, "peer not allowed", http.StatusBadRequest)
return
}
fx(w, r)
}
type authedHandler struct {
auth *authorization
mux *http.ServeMux
}
func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
func (h authedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
err := h.auth.refresh(r.Context())
if err != nil {
log.Printf("error authedHandler ServeHTTP refresh auth: %v", err)
http.Error(w, "", http.StatusInternalServerError)
return
}
a, err := addrFromServerAddress(r.RemoteAddr)
if err != nil {
log.Printf("error authedHandler ServeHTTP refresh auth: %v", err)
http.Error(w, "", http.StatusInternalServerError)
return
}
allowed := h.auth.allowsHost(a)
if !allowed {
http.Error(w, "peer not allowed", http.StatusUnauthorized)
return
}
h.mux.ServeHTTP(w, r)
}
func (c *Consensus) makeCommandMux() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("/join", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) {
if r.Method != httpm.POST {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
@ -132,8 +138,8 @@ func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}))
mux.HandleFunc("/executeCommand", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
})
mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) {
if r.Method != httpm.POST {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
@ -155,6 +161,13 @@ func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}))
})
return mux
}
func (c *Consensus) makeCommandHandler(auth *authorization) http.Handler {
return authedHandler{
mux: c.makeCommandMux(),
auth: auth,
}
}

View File

@ -370,8 +370,7 @@ func (c *Consensus) serveCmdHttp(ts *tsnet.Server, auth *authorization) (*http.S
if err != nil {
return nil, err
}
mux := c.makeCommandMux(auth)
srv := &http.Server{Handler: mux}
srv := &http.Server{Handler: c.makeCommandHandler(auth)}
go func() {
err := srv.Serve(ln)
log.Printf("CmdHttp stopped serving with err: %v", err)

View File

@ -682,8 +682,8 @@ func TestOnlyTaggedPeersCanJoin(t *testing.T) {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusBadRequest, resp.StatusCode)
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusUnauthorized, resp.StatusCode)
}
rBody, _ := io.ReadAll(resp.Body)
sBody := strings.TrimSpace(string(rBody))