max bytes reader

This commit is contained in:
Fran Bull 2025-03-17 12:05:30 -07:00
parent b880227d41
commit c4664c20b8
2 changed files with 10 additions and 4 deletions

View File

@ -31,6 +31,12 @@ func (rac *commandClient) url(host string, path string) string {
return fmt.Sprintf("http://%s:%d%s", host, rac.port, path) return fmt.Sprintf("http://%s:%d%s", host, rac.port, path)
} }
const maxBodyBytes = 1024 * 1024
func readAllMaxBytes(r io.Reader) ([]byte, error) {
return io.ReadAll(io.LimitReader(r, maxBodyBytes))
}
func (rac *commandClient) join(host string, jr joinRequest) error { func (rac *commandClient) join(host string, jr joinRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
@ -48,7 +54,7 @@ func (rac *commandClient) join(host string, jr joinRequest) error {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
respBs, err := io.ReadAll(resp.Body) respBs, err := readAllMaxBytes(resp.Body)
if err != nil { if err != nil {
return err return err
} }
@ -71,7 +77,7 @@ func (rac *commandClient) executeCommand(host string, bs []byte) (CommandResult,
return CommandResult{}, err return CommandResult{}, err
} }
defer resp.Body.Close() defer resp.Body.Close()
respBs, err := io.ReadAll(resp.Body) respBs, err := readAllMaxBytes(resp.Body)
if err != nil { if err != nil {
return CommandResult{}, err return CommandResult{}, err
} }
@ -113,7 +119,7 @@ func (h authedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (c *Consensus) handleJoinHTTP(w http.ResponseWriter, r *http.Request) { func (c *Consensus) handleJoinHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close() defer r.Body.Close()
decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1024*1024)) decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBodyBytes))
var jr joinRequest var jr joinRequest
err := decoder.Decode(&jr) err := decoder.Decode(&jr)
if err != nil { if err != nil {

View File

@ -138,7 +138,7 @@ func (m *monitor) handleDial(w http.ResponseWriter, r *http.Request) {
Addr string Addr string
} }
defer r.Body.Close() defer r.Body.Close()
bs, err := io.ReadAll(r.Body) bs, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
if err != nil { if err != nil {
log.Printf("monitor: error reading body: %v", err) log.Printf("monitor: error reading body: %v", err)
http.Error(w, "", http.StatusInternalServerError) http.Error(w, "", http.StatusInternalServerError)