mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-31 16:23:44 +00:00
centralize cmd http auth
This commit is contained in:
parent
5afa742b06
commit
ae30f58b46
@ -9,6 +9,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@ -83,30 +84,35 @@ func (rac *commandClient) executeCommand(host string, bs []byte) (CommandResult,
|
||||
return cr, nil
|
||||
}
|
||||
|
||||
func authorized(auth *authorization, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
err := auth.refresh(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
a, err := addrFromServerAddress(r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
allowed := auth.allowsHost(a)
|
||||
if !allowed {
|
||||
http.Error(w, "peer not allowed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fx(w, r)
|
||||
}
|
||||
type authedHandler struct {
|
||||
auth *authorization
|
||||
mux *http.ServeMux
|
||||
}
|
||||
|
||||
func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
|
||||
func (h authedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
err := h.auth.refresh(r.Context())
|
||||
if err != nil {
|
||||
log.Printf("error authedHandler ServeHTTP refresh auth: %v", err)
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
a, err := addrFromServerAddress(r.RemoteAddr)
|
||||
if err != nil {
|
||||
log.Printf("error authedHandler ServeHTTP refresh auth: %v", err)
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
allowed := h.auth.allowsHost(a)
|
||||
if !allowed {
|
||||
http.Error(w, "peer not allowed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h.mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (c *Consensus) makeCommandMux() *http.ServeMux {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/join", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != httpm.POST {
|
||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
@ -132,8 +138,8 @@ func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
mux.HandleFunc("/executeCommand", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != httpm.POST {
|
||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
@ -155,6 +161,13 @@ func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
})
|
||||
return mux
|
||||
}
|
||||
|
||||
func (c *Consensus) makeCommandHandler(auth *authorization) http.Handler {
|
||||
return authedHandler{
|
||||
mux: c.makeCommandMux(),
|
||||
auth: auth,
|
||||
}
|
||||
}
|
||||
|
@ -370,8 +370,7 @@ func (c *Consensus) serveCmdHttp(ts *tsnet.Server, auth *authorization) (*http.S
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mux := c.makeCommandMux(auth)
|
||||
srv := &http.Server{Handler: mux}
|
||||
srv := &http.Server{Handler: c.makeCommandHandler(auth)}
|
||||
go func() {
|
||||
err := srv.Serve(ln)
|
||||
log.Printf("CmdHttp stopped serving with err: %v", err)
|
||||
|
@ -682,8 +682,8 @@ func TestOnlyTaggedPeersCanJoin(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusBadRequest, resp.StatusCode)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
rBody, _ := io.ReadAll(resp.Body)
|
||||
sBody := strings.TrimSpace(string(rBody))
|
||||
|
Loading…
x
Reference in New Issue
Block a user