tsweb: swallow panics

With this change, the error handling and request logging are all done in defers
after calling inner.ServeHTTP. This ensures that any recovered values which we
want to re-panic with retain a useful stacktrace.  However, we now only
re-panic from errorHandler when there's no outside logHandler. Which if you're
using StdHandler there always is. We prefer this to ensure that we are able to
write a 500 Internal Server Error to the client. If a panic hits http.Server
then the response is not sent back.

Updates #12784

Signed-off-by: Paul Scott <paul@tailscale.com>
This commit is contained in:
Paul Scott 2024-07-16 11:32:43 +01:00 committed by Paul Scott
parent f77821fd63
commit d97cddd876
2 changed files with 231 additions and 104 deletions

View File

@ -387,15 +387,18 @@ func StdHandler(h ReturnHandler, opts HandlerOptions) http.Handler {
// LogHandler returns an http.Handler that logs to opts.Logf. // LogHandler returns an http.Handler that logs to opts.Logf.
// It logs both successful and failing requests. // It logs both successful and failing requests.
// The log line includes the first error returned to [Handler] within. // The log line includes the first error returned to [ErrorHandler] within.
// The outer-most LogHandler(LogHandler(...)) does all of the logging. // The outer-most LogHandler(LogHandler(...)) does all of the logging.
// Inner LogHandler instance do nothing. // Inner LogHandler instance do nothing.
// Panics are swallowed and their stack traces are put in the error.
func LogHandler(h http.Handler, opts LogOptions) http.Handler { func LogHandler(h http.Handler, opts LogOptions) http.Handler {
return logHandler{h, opts.withDefaults()} return logHandler{h, opts.withDefaults()}
} }
// ErrorHandler converts a [ReturnHandler] into a standard [http.Handler]. // ErrorHandler converts a [ReturnHandler] into a standard [http.Handler].
// Errors are handled as specified by the [ReturnHandler.ServeHTTPReturn] method. // Errors are handled as specified by the [ReturnHandler.ServeHTTPReturn] method.
// When wrapped in a [LogHandler], panics are added to the [AccessLogRecord];
// otherwise, panics continue up the stack.
func ErrorHandler(h ReturnHandler, opts ErrorOptions) http.Handler { func ErrorHandler(h ReturnHandler, opts ErrorOptions) http.Handler {
return errorHandler{h, opts.withDefaults()} return errorHandler{h, opts.withDefaults()}
} }
@ -433,21 +436,34 @@ func (h logHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
RequestID: RequestIDFromContext(r.Context()), RequestID: RequestIDFromContext(r.Context()),
} }
var bucket string if bs := h.opts.BucketedStats; bs != nil && bs.Started != nil && bs.Finished != nil {
var startRecorded bool bucket := bs.bucketForRequest(r)
if bs := h.opts.BucketedStats; bs != nil { var startRecorded bool
bucket = bs.bucketForRequest(r) switch v := bs.Started.Map.Get(bucket).(type) {
if bs.Started != nil { case *expvar.Int:
switch v := bs.Started.Map.Get(bucket).(type) { // If we've already seen this bucket for, count it immediately.
case *expvar.Int: // Otherwise, for newly seen paths, only count retroactively
// If we've already seen this bucket for, count it immediately. // (so started-finished doesn't go negative) so we don't fill
// Otherwise, for newly seen paths, only count retroactively // this LabelMap up with internet scanning spam.
// (so started-finished doesn't go negative) so we don't fill v.Add(1)
// this LabelMap up with internet scanning spam. startRecorded = true
v.Add(1)
startRecorded = true
}
} }
defer func() {
// Only increment metrics for buckets that result in good HTTP statuses
// or when we know the start was already counted.
// Otherwise they get full of internet scanning noise. Only filtering 404
// gets most of the way there but there are also plenty of URLs that are
// almost right but result in 400s too. Seem easier to just only ignore
// all 4xx and 5xx.
if startRecorded {
bs.Finished.Add(bucket, 1)
} else if msg.Code < 400 {
// This is the first non-error request for this bucket,
// so count it now retroactively.
bs.Started.Add(bucket, 1)
bs.Finished.Add(bucket, 1)
}
}()
} }
if fn := h.opts.OnStart; fn != nil { if fn := h.opts.OnStart; fn != nil {
@ -467,29 +483,27 @@ func (h logHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
msg.Err = e msg.Err = e
})) }))
lw := &loggingResponseWriter{ResponseWriter: w, logf: h.opts.Logf} lw := newLogResponseWriter(h.opts.Logf, w, r)
// Invoke the handler that we're logging.
var recovered any
defer func() { defer func() {
// If the handler panicked then make sure we include that in our error.
// Panics caught up errorHandler shouldn't appear here, unless the panic
// originates in one of its callbacks.
recovered := recover()
if recovered != nil { if recovered != nil {
// TODO(icio): When the panic below is eventually caught by if msg.Err == "" {
// http.Server, it cancels the inlight request and the "500 Internal msg.Err = panic2err(recovered).Error()
// Server Error" response we wrote to the client below is never } else {
// received, even if we flush it. msg.Err += "\n\nthen " + panic2err(recovered).Error()
if f, ok := w.(http.Flusher); ok {
f.Flush()
} }
panic(recovered)
} }
}() h.logRequest(r, lw, msg)
func() {
defer func() {
recovered = recover()
}()
h.h.ServeHTTP(lw, r)
}() }()
h.h.ServeHTTP(lw, r)
}
func (h logHandler) logRequest(r *http.Request, lw *loggingResponseWriter, msg AccessLogRecord) {
// Complete our access log from the loggingResponseWriter. // Complete our access log from the loggingResponseWriter.
msg.Bytes = lw.bytes msg.Bytes = lw.bytes
msg.Seconds = h.opts.Now().Sub(msg.Time).Seconds() msg.Seconds = h.opts.Now().Sub(msg.Time).Seconds()
@ -515,22 +529,6 @@ func() {
} }
// Closing metrics. // Closing metrics.
if bs := h.opts.BucketedStats; bs != nil && bs.Finished != nil {
// Only increment metrics for buckets that result in good HTTP statuses
// or when we know the start was already counted.
// Otherwise they get full of internet scanning noise. Only filtering 404
// gets most of the way there but there are also plenty of URLs that are
// almost right but result in 400s too. Seem easier to just only ignore
// all 4xx and 5xx.
if startRecorded {
bs.Finished.Add(bucket, 1)
} else if msg.Code < 400 {
// This is the first non-error request for this bucket,
// so count it now retroactively.
bs.Started.Add(bucket, 1)
bs.Finished.Add(bucket, 1)
}
}
if h.opts.StatusCodeCounters != nil { if h.opts.StatusCodeCounters != nil {
h.opts.StatusCodeCounters.Add(responseCodeString(msg.Code/100), 1) h.opts.StatusCodeCounters.Add(responseCodeString(msg.Code/100), 1)
} }
@ -573,7 +571,23 @@ type loggingResponseWriter struct {
logf logger.Logf logf logger.Logf
} }
// WriteHeader implements http.Handler. // newLogResponseWriter returns a loggingResponseWriter which uses's the logger
// from r, or falls back to logf. If a nil logger is given, the logs are
// discarded.
func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Request) *loggingResponseWriter {
if l, ok := logger.LogfKey.ValueOk(r.Context()); ok && l != nil {
logf = l
}
if logf == nil {
logf = logger.Discard
}
return &loggingResponseWriter{
ResponseWriter: w,
logf: logf,
}
}
// WriteHeader implements [http.ResponseWriter].
func (l *loggingResponseWriter) WriteHeader(statusCode int) { func (l *loggingResponseWriter) WriteHeader(statusCode int) {
if l.code != 0 { if l.code != 0 {
l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode) l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode)
@ -583,7 +597,7 @@ func (l *loggingResponseWriter) WriteHeader(statusCode int) {
l.ResponseWriter.WriteHeader(statusCode) l.ResponseWriter.WriteHeader(statusCode)
} }
// Write implements http.Handler. // Write implements [http.ResponseWriter].
func (l *loggingResponseWriter) Write(bs []byte) (int, error) { func (l *loggingResponseWriter) Write(bs []byte) (int, error) {
if l.code == 0 { if l.code == 0 {
l.code = 200 l.code = 200
@ -626,57 +640,43 @@ type errorHandler struct {
// ServeHTTP implements the http.Handler interface. // ServeHTTP implements the http.Handler interface.
func (h errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
logf := h.opts.Logf
if l := logger.LogfKey.Value(r.Context()); l != nil {
logf = l
}
// Keep track of whether a response gets written. // Keep track of whether a response gets written.
lw, ok := w.(*loggingResponseWriter) lw, ok := w.(*loggingResponseWriter)
if !ok { if !ok {
lw = &loggingResponseWriter{ lw = newLogResponseWriter(h.opts.Logf, w, r)
ResponseWriter: w,
logf: logf,
}
} }
// In case the handler panics, we want to recover and continue logging the var err error
// error before raising the panic again for the server to handle.
var panicRes any
defer func() { defer func() {
if panicRes != nil { // In case the handler panics, we want to recover and continue logging
panic(panicRes) // the error before logging it (or re-panicking if we couldn't log).
rec := recover()
if rec != nil {
err = panic2err(rec)
}
if err == nil {
return
}
if h.handleError(w, r, lw, err) {
return
}
if rec != nil {
// If we weren't able to log the panic somewhere, throw it up the
// stack to someone who can.
panic(rec)
} }
}() }()
err := func() (err error) { err = h.rh.ServeHTTPReturn(lw, r)
defer func() { }
if r := recover(); r != nil {
panicRes = r
if r == http.ErrAbortHandler {
err = http.ErrAbortHandler
} else {
// Even if r is an error, do not wrap it as an error here as
// that would allow things like panic(vizerror.New("foo"))
// which is really hard to define the behavior of.
var stack [10000]byte
n := runtime.Stack(stack[:], false)
err = fmt.Errorf("panic: %v\n\n%s", r, stack[:n])
}
}
}()
return h.rh.ServeHTTPReturn(lw, r)
}()
if err == nil {
return
}
func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *loggingResponseWriter, err error) (logged bool) {
// Extract a presentable, loggable error. // Extract a presentable, loggable error.
var hOK bool var hOK bool
var hErr HTTPError var hErr HTTPError
if errors.As(err, &hErr) { if errors.As(err, &hErr) {
hOK = true hOK = true
if hErr.Code == 0 { if hErr.Code == 0 {
logf("[unexpected] HTTPError %v did not contain an HTTP status code, sending internal server error", hErr) lw.logf("[unexpected] HTTPError %v did not contain an HTTP status code, sending internal server error", hErr)
hErr.Code = http.StatusInternalServerError hErr.Code = http.StatusInternalServerError
} }
} else if v, ok := vizerror.As(err); ok { } else if v, ok := vizerror.As(err); ok {
@ -696,23 +696,70 @@ func (h errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else if hErr.Msg != "" { } else if hErr.Msg != "" {
pb(hErr.Msg) pb(hErr.Msg)
} }
logged = true
} }
if lw.code != 0 { if lw.code != 0 {
if hOK { if hOK {
logf("[unexpected] handler returned HTTPError %v, but already sent a response with code %d", hErr, lw.code) lw.logf("[unexpected] handler returned HTTPError %v, but already sent a response with code %d", hErr, lw.code)
} }
return return
} }
// Set a default error message from the status code. Do this after we pass // Set a default error message from the status code. Do this after we pass
// the error back to the logger so that `return errors.New("oh")` logs as // the error back to the logger so that `return errors.New("oh")` logs as
// `"err": "oh"`, not `"err": "internal server error: oh"`. // `"err": "oh"`, not `"err": "Internal Server Error: oh"`.
if hErr.Msg == "" { if hErr.Msg == "" {
hErr.Msg = http.StatusText(hErr.Code) hErr.Msg = http.StatusText(hErr.Code)
} }
// If OnError panics before a response is written, write a bare 500 back.
defer func() {
if lw.code == 0 {
if rec := recover(); rec != nil {
w.WriteHeader(http.StatusInternalServerError)
panic(rec)
}
}
}()
h.opts.OnError(w, r, hErr) h.opts.OnError(w, r, hErr)
return logged
}
// panic2err converts a recovered value to an error containing the panic stack trace.
func panic2err(recovered any) error {
if recovered == nil {
return nil
}
if recovered == http.ErrAbortHandler {
return http.ErrAbortHandler
}
// Even if r is an error, do not wrap it as an error here as
// that would allow things like panic(vizerror.New("foo"))
// which is really hard to define the behavior of.
var stack [10000]byte
n := runtime.Stack(stack[:], false)
return &panicError{
rec: recovered,
stack: stack[:n],
}
}
// panicError is an error that contains a panic.
type panicError struct {
rec any
stack []byte
}
func (e *panicError) Error() string {
return fmt.Sprintf("panic: %v\n\n%s", e.rec, e.stack)
}
func (e *panicError) Unwrap() error {
err, _ := e.rec.(error)
return err
} }
// writeHTTPError is the default error response formatter. // writeHTTPError is the default error response formatter.

View File

@ -662,40 +662,120 @@ func TestStdHandler_Panic(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if <-recovered == nil { if rec := <-recovered; rec != nil {
t.Fatal("expected panic but saw none") t.Fatalf("expected no panic but saw: %v", rec)
} }
// Check that the log message contained the stack trace in the error. // Check that the log message contained the stack trace in the error.
var logerr bool var logerr bool
if p := "panic: panicked elsewhere\n\ngoroutine "; !strings.HasPrefix(r.Err, p) { if p := "panic: panicked elsewhere\n\ngoroutine "; !strings.HasPrefix(r.Err, p) {
t.Errorf("got error prefix %q, want %q", r.Err[:min(len(r.Err), len(p))], p) t.Errorf("got Err prefix %q, want %q", r.Err[:min(len(r.Err), len(p))], p)
logerr = true logerr = true
} }
if s := "\ntailscale.com/tsweb.panicElsewhere("; !strings.Contains(r.Err, s) { if s := "\ntailscale.com/tsweb.panicElsewhere("; !strings.Contains(r.Err, s) {
t.Errorf("want substr %q, not found", s) t.Errorf("want Err substr %q, not found", s)
logerr = true logerr = true
} }
if logerr { if logerr {
t.Logf("logger got error: (quoted) %q\n\n(verbatim)\n%s", r.Err, r.Err) t.Logf("logger got error: (quoted) %q\n\n(verbatim)\n%s", r.Err, r.Err)
} }
t.Run("check_response", func(t *testing.T) { // Check that the server sent an error response.
// TODO(icio): Swallow panics? tailscale/tailscale#12784 if res.StatusCode != 500 {
t.SkipNow() t.Errorf("got status code %d, want %d", res.StatusCode, 500)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("error reading body: %s", err)
} else if want := "Internal Server Error\n"; string(body) != want {
t.Errorf("got body %q, want %q", body, want)
}
res.Body.Close()
}
// Check that the server sent an error response. func TestStdHandler_OnErrorPanic(t *testing.T) {
if res.StatusCode != 500 { var r AccessLogRecord
t.Errorf("got status code %d, want %d", res.StatusCode, 500) h := StdHandler(
ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
// This response is supposed to be written by OnError, but it panics
// so nothing is written.
return Error(401, "lacking auth", nil)
}),
HandlerOptions{
Logf: t.Logf,
OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
panicElsewhere()
},
OnCompletion: func(_ *http.Request, alr AccessLogRecord) {
r = alr
},
},
)
// Run our panicking handler in a http.Server which catches and rethrows
// any panics.
recovered := make(chan any, 1)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
recovered <- recover()
}()
h.ServeHTTP(w, r)
}))
t.Cleanup(s.Close)
// Send a request to our server.
res, err := http.Get(s.URL)
if err != nil {
t.Fatal(err)
}
if rec := <-recovered; rec != nil {
t.Fatalf("expected no panic but saw: %v", rec)
}
// Check that the log message contained the stack trace in the error.
var logerr bool
if p := "lacking auth\n\nthen panic: panicked elsewhere\n\ngoroutine "; !strings.HasPrefix(r.Err, p) {
t.Errorf("got Err prefix %q, want %q", r.Err[:min(len(r.Err), len(p))], p)
logerr = true
}
if s := "\ntailscale.com/tsweb.panicElsewhere("; !strings.Contains(r.Err, s) {
t.Errorf("want Err substr %q, not found", s)
logerr = true
}
if logerr {
t.Logf("logger got error: (quoted) %q\n\n(verbatim)\n%s", r.Err, r.Err)
}
// Check that the server sent a bare 500 response.
if res.StatusCode != 500 {
t.Errorf("got status code %d, want %d", res.StatusCode, 500)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Errorf("error reading body: %s", err)
} else if want := ""; string(body) != want {
t.Errorf("got body %q, want %q", body, want)
}
res.Body.Close()
}
func TestErrorHandler_Panic(t *testing.T) {
// errorHandler should panic when not wrapped in logHandler.
defer func() {
rec := recover()
if rec == nil {
t.Fatal("expected errorHandler to panic when not wrapped in logHandler")
} }
body, err := io.ReadAll(res.Body) if want := any("uhoh"); rec != want {
if err != nil { t.Fatalf("got panic %#v, want %#v", rec, want)
t.Errorf("error reading body: %s", err)
} else if want := "internal server error\n"; string(body) != want {
t.Errorf("got body %q, want %q", body, want)
} }
res.Body.Close() }()
}) ErrorHandler(
ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
panic("uhoh")
}),
ErrorOptions{},
).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil))
} }
func panicElsewhere() { func panicElsewhere() {