diff --git a/tsconsensus/http.go b/tsconsensus/http.go index 2293a3f0a..a1f06cf5a 100644 --- a/tsconsensus/http.go +++ b/tsconsensus/http.go @@ -31,6 +31,12 @@ func (rac *commandClient) url(host string, path string) string { return fmt.Sprintf("http://%s:%d%s", host, rac.port, path) } +const maxBodyBytes = 1024 * 1024 + +func readAllMaxBytes(r io.Reader) ([]byte, error) { + return io.ReadAll(io.LimitReader(r, maxBodyBytes)) +} + func (rac *commandClient) join(host string, jr joinRequest) error { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -48,7 +54,7 @@ func (rac *commandClient) join(host string, jr joinRequest) error { return err } defer resp.Body.Close() - respBs, err := io.ReadAll(resp.Body) + respBs, err := readAllMaxBytes(resp.Body) if err != nil { return err } @@ -71,7 +77,7 @@ func (rac *commandClient) executeCommand(host string, bs []byte) (CommandResult, return CommandResult{}, err } defer resp.Body.Close() - respBs, err := io.ReadAll(resp.Body) + respBs, err := readAllMaxBytes(resp.Body) if err != nil { return CommandResult{}, err } @@ -113,7 +119,7 @@ func (h authedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (c *Consensus) handleJoinHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() - decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1024*1024)) + decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBodyBytes)) var jr joinRequest err := decoder.Decode(&jr) if err != nil { diff --git a/tsconsensus/monitor.go b/tsconsensus/monitor.go index a475b1e71..e337e78be 100644 --- a/tsconsensus/monitor.go +++ b/tsconsensus/monitor.go @@ -138,7 +138,7 @@ func (m *monitor) handleDial(w http.ResponseWriter, r *http.Request) { Addr string } defer r.Body.Close() - bs, err := io.ReadAll(r.Body) + bs, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes)) if err != nil { log.Printf("monitor: error reading body: %v", err) http.Error(w, "", http.StatusInternalServerError)