diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index b0a226c7f..0221dacb9 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -387,15 +387,18 @@ func StdHandler(h ReturnHandler, opts HandlerOptions) http.Handler { // LogHandler returns an http.Handler that logs to opts.Logf. // It logs both successful and failing requests. -// The log line includes the first error returned to [Handler] within. +// The log line includes the first error returned to [ErrorHandler] within. // The outer-most LogHandler(LogHandler(...)) does all of the logging. // Inner LogHandler instance do nothing. +// Panics are swallowed and their stack traces are put in the error. func LogHandler(h http.Handler, opts LogOptions) http.Handler { return logHandler{h, opts.withDefaults()} } // ErrorHandler converts a [ReturnHandler] into a standard [http.Handler]. // Errors are handled as specified by the [ReturnHandler.ServeHTTPReturn] method. +// When wrapped in a [LogHandler], panics are added to the [AccessLogRecord]; +// otherwise, panics continue up the stack. func ErrorHandler(h ReturnHandler, opts ErrorOptions) http.Handler { return errorHandler{h, opts.withDefaults()} } @@ -433,21 +436,34 @@ func (h logHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { RequestID: RequestIDFromContext(r.Context()), } - var bucket string - var startRecorded bool - if bs := h.opts.BucketedStats; bs != nil { - bucket = bs.bucketForRequest(r) - if bs.Started != nil { - switch v := bs.Started.Map.Get(bucket).(type) { - case *expvar.Int: - // If we've already seen this bucket for, count it immediately. - // Otherwise, for newly seen paths, only count retroactively - // (so started-finished doesn't go negative) so we don't fill - // this LabelMap up with internet scanning spam. - v.Add(1) - startRecorded = true - } + if bs := h.opts.BucketedStats; bs != nil && bs.Started != nil && bs.Finished != nil { + bucket := bs.bucketForRequest(r) + var startRecorded bool + switch v := bs.Started.Map.Get(bucket).(type) { + case *expvar.Int: + // If we've already seen this bucket for, count it immediately. + // Otherwise, for newly seen paths, only count retroactively + // (so started-finished doesn't go negative) so we don't fill + // this LabelMap up with internet scanning spam. + v.Add(1) + startRecorded = true } + defer func() { + // Only increment metrics for buckets that result in good HTTP statuses + // or when we know the start was already counted. + // Otherwise they get full of internet scanning noise. Only filtering 404 + // gets most of the way there but there are also plenty of URLs that are + // almost right but result in 400s too. Seem easier to just only ignore + // all 4xx and 5xx. + if startRecorded { + bs.Finished.Add(bucket, 1) + } else if msg.Code < 400 { + // This is the first non-error request for this bucket, + // so count it now retroactively. + bs.Started.Add(bucket, 1) + bs.Finished.Add(bucket, 1) + } + }() } if fn := h.opts.OnStart; fn != nil { @@ -467,29 +483,27 @@ func (h logHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { msg.Err = e })) - lw := &loggingResponseWriter{ResponseWriter: w, logf: h.opts.Logf} + lw := newLogResponseWriter(h.opts.Logf, w, r) - // Invoke the handler that we're logging. - var recovered any defer func() { + // If the handler panicked then make sure we include that in our error. + // Panics caught up errorHandler shouldn't appear here, unless the panic + // originates in one of its callbacks. + recovered := recover() if recovered != nil { - // TODO(icio): When the panic below is eventually caught by - // http.Server, it cancels the inlight request and the "500 Internal - // Server Error" response we wrote to the client below is never - // received, even if we flush it. - if f, ok := w.(http.Flusher); ok { - f.Flush() + if msg.Err == "" { + msg.Err = panic2err(recovered).Error() + } else { + msg.Err += "\n\nthen " + panic2err(recovered).Error() } - panic(recovered) } - }() - func() { - defer func() { - recovered = recover() - }() - h.h.ServeHTTP(lw, r) + h.logRequest(r, lw, msg) }() + h.h.ServeHTTP(lw, r) +} + +func (h logHandler) logRequest(r *http.Request, lw *loggingResponseWriter, msg AccessLogRecord) { // Complete our access log from the loggingResponseWriter. msg.Bytes = lw.bytes msg.Seconds = h.opts.Now().Sub(msg.Time).Seconds() @@ -515,22 +529,6 @@ func() { } // Closing metrics. - if bs := h.opts.BucketedStats; bs != nil && bs.Finished != nil { - // Only increment metrics for buckets that result in good HTTP statuses - // or when we know the start was already counted. - // Otherwise they get full of internet scanning noise. Only filtering 404 - // gets most of the way there but there are also plenty of URLs that are - // almost right but result in 400s too. Seem easier to just only ignore - // all 4xx and 5xx. - if startRecorded { - bs.Finished.Add(bucket, 1) - } else if msg.Code < 400 { - // This is the first non-error request for this bucket, - // so count it now retroactively. - bs.Started.Add(bucket, 1) - bs.Finished.Add(bucket, 1) - } - } if h.opts.StatusCodeCounters != nil { h.opts.StatusCodeCounters.Add(responseCodeString(msg.Code/100), 1) } @@ -573,7 +571,23 @@ type loggingResponseWriter struct { logf logger.Logf } -// WriteHeader implements http.Handler. +// newLogResponseWriter returns a loggingResponseWriter which uses's the logger +// from r, or falls back to logf. If a nil logger is given, the logs are +// discarded. +func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Request) *loggingResponseWriter { + if l, ok := logger.LogfKey.ValueOk(r.Context()); ok && l != nil { + logf = l + } + if logf == nil { + logf = logger.Discard + } + return &loggingResponseWriter{ + ResponseWriter: w, + logf: logf, + } +} + +// WriteHeader implements [http.ResponseWriter]. func (l *loggingResponseWriter) WriteHeader(statusCode int) { if l.code != 0 { l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode) @@ -583,7 +597,7 @@ func (l *loggingResponseWriter) WriteHeader(statusCode int) { l.ResponseWriter.WriteHeader(statusCode) } -// Write implements http.Handler. +// Write implements [http.ResponseWriter]. func (l *loggingResponseWriter) Write(bs []byte) (int, error) { if l.code == 0 { l.code = 200 @@ -626,57 +640,43 @@ type errorHandler struct { // ServeHTTP implements the http.Handler interface. func (h errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - logf := h.opts.Logf - if l := logger.LogfKey.Value(r.Context()); l != nil { - logf = l - } - // Keep track of whether a response gets written. lw, ok := w.(*loggingResponseWriter) if !ok { - lw = &loggingResponseWriter{ - ResponseWriter: w, - logf: logf, - } + lw = newLogResponseWriter(h.opts.Logf, w, r) } - // In case the handler panics, we want to recover and continue logging the - // error before raising the panic again for the server to handle. - var panicRes any + var err error defer func() { - if panicRes != nil { - panic(panicRes) + // In case the handler panics, we want to recover and continue logging + // the error before logging it (or re-panicking if we couldn't log). + rec := recover() + if rec != nil { + err = panic2err(rec) + } + if err == nil { + return + } + if h.handleError(w, r, lw, err) { + return + } + if rec != nil { + // If we weren't able to log the panic somewhere, throw it up the + // stack to someone who can. + panic(rec) } }() - err := func() (err error) { - defer func() { - if r := recover(); r != nil { - panicRes = r - if r == http.ErrAbortHandler { - err = http.ErrAbortHandler - } else { - // Even if r is an error, do not wrap it as an error here as - // that would allow things like panic(vizerror.New("foo")) - // which is really hard to define the behavior of. - var stack [10000]byte - n := runtime.Stack(stack[:], false) - err = fmt.Errorf("panic: %v\n\n%s", r, stack[:n]) - } - } - }() - return h.rh.ServeHTTPReturn(lw, r) - }() - if err == nil { - return - } + err = h.rh.ServeHTTPReturn(lw, r) +} +func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *loggingResponseWriter, err error) (logged bool) { // Extract a presentable, loggable error. var hOK bool var hErr HTTPError if errors.As(err, &hErr) { hOK = true if hErr.Code == 0 { - logf("[unexpected] HTTPError %v did not contain an HTTP status code, sending internal server error", hErr) + lw.logf("[unexpected] HTTPError %v did not contain an HTTP status code, sending internal server error", hErr) hErr.Code = http.StatusInternalServerError } } else if v, ok := vizerror.As(err); ok { @@ -696,23 +696,70 @@ func (h errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else if hErr.Msg != "" { pb(hErr.Msg) } + logged = true } if lw.code != 0 { if hOK { - logf("[unexpected] handler returned HTTPError %v, but already sent a response with code %d", hErr, lw.code) + lw.logf("[unexpected] handler returned HTTPError %v, but already sent a response with code %d", hErr, lw.code) } return } // Set a default error message from the status code. Do this after we pass // the error back to the logger so that `return errors.New("oh")` logs as - // `"err": "oh"`, not `"err": "internal server error: oh"`. + // `"err": "oh"`, not `"err": "Internal Server Error: oh"`. if hErr.Msg == "" { hErr.Msg = http.StatusText(hErr.Code) } + // If OnError panics before a response is written, write a bare 500 back. + defer func() { + if lw.code == 0 { + if rec := recover(); rec != nil { + w.WriteHeader(http.StatusInternalServerError) + panic(rec) + } + } + }() + h.opts.OnError(w, r, hErr) + return logged +} + +// panic2err converts a recovered value to an error containing the panic stack trace. +func panic2err(recovered any) error { + if recovered == nil { + return nil + } + if recovered == http.ErrAbortHandler { + return http.ErrAbortHandler + } + + // Even if r is an error, do not wrap it as an error here as + // that would allow things like panic(vizerror.New("foo")) + // which is really hard to define the behavior of. + var stack [10000]byte + n := runtime.Stack(stack[:], false) + return &panicError{ + rec: recovered, + stack: stack[:n], + } +} + +// panicError is an error that contains a panic. +type panicError struct { + rec any + stack []byte +} + +func (e *panicError) Error() string { + return fmt.Sprintf("panic: %v\n\n%s", e.rec, e.stack) +} + +func (e *panicError) Unwrap() error { + err, _ := e.rec.(error) + return err } // writeHTTPError is the default error response formatter. diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index fff7cc805..18bb7e48d 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -662,40 +662,120 @@ func TestStdHandler_Panic(t *testing.T) { if err != nil { t.Fatal(err) } - if <-recovered == nil { - t.Fatal("expected panic but saw none") + if rec := <-recovered; rec != nil { + t.Fatalf("expected no panic but saw: %v", rec) } // Check that the log message contained the stack trace in the error. var logerr bool if p := "panic: panicked elsewhere\n\ngoroutine "; !strings.HasPrefix(r.Err, p) { - t.Errorf("got error prefix %q, want %q", r.Err[:min(len(r.Err), len(p))], p) + t.Errorf("got Err prefix %q, want %q", r.Err[:min(len(r.Err), len(p))], p) logerr = true } if s := "\ntailscale.com/tsweb.panicElsewhere("; !strings.Contains(r.Err, s) { - t.Errorf("want substr %q, not found", s) + t.Errorf("want Err substr %q, not found", s) logerr = true } if logerr { t.Logf("logger got error: (quoted) %q\n\n(verbatim)\n%s", r.Err, r.Err) } - t.Run("check_response", func(t *testing.T) { - // TODO(icio): Swallow panics? tailscale/tailscale#12784 - t.SkipNow() + // Check that the server sent an error response. + if res.StatusCode != 500 { + t.Errorf("got status code %d, want %d", res.StatusCode, 500) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("error reading body: %s", err) + } else if want := "Internal Server Error\n"; string(body) != want { + t.Errorf("got body %q, want %q", body, want) + } + res.Body.Close() +} - // Check that the server sent an error response. - if res.StatusCode != 500 { - t.Errorf("got status code %d, want %d", res.StatusCode, 500) +func TestStdHandler_OnErrorPanic(t *testing.T) { + var r AccessLogRecord + h := StdHandler( + ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + // This response is supposed to be written by OnError, but it panics + // so nothing is written. + return Error(401, "lacking auth", nil) + }), + HandlerOptions{ + Logf: t.Logf, + OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) { + panicElsewhere() + }, + OnCompletion: func(_ *http.Request, alr AccessLogRecord) { + r = alr + }, + }, + ) + + // Run our panicking handler in a http.Server which catches and rethrows + // any panics. + recovered := make(chan any, 1) + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + recovered <- recover() + }() + h.ServeHTTP(w, r) + })) + t.Cleanup(s.Close) + + // Send a request to our server. + res, err := http.Get(s.URL) + if err != nil { + t.Fatal(err) + } + if rec := <-recovered; rec != nil { + t.Fatalf("expected no panic but saw: %v", rec) + } + + // Check that the log message contained the stack trace in the error. + var logerr bool + if p := "lacking auth\n\nthen panic: panicked elsewhere\n\ngoroutine "; !strings.HasPrefix(r.Err, p) { + t.Errorf("got Err prefix %q, want %q", r.Err[:min(len(r.Err), len(p))], p) + logerr = true + } + if s := "\ntailscale.com/tsweb.panicElsewhere("; !strings.Contains(r.Err, s) { + t.Errorf("want Err substr %q, not found", s) + logerr = true + } + if logerr { + t.Logf("logger got error: (quoted) %q\n\n(verbatim)\n%s", r.Err, r.Err) + } + + // Check that the server sent a bare 500 response. + if res.StatusCode != 500 { + t.Errorf("got status code %d, want %d", res.StatusCode, 500) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("error reading body: %s", err) + } else if want := ""; string(body) != want { + t.Errorf("got body %q, want %q", body, want) + } + res.Body.Close() +} + +func TestErrorHandler_Panic(t *testing.T) { + // errorHandler should panic when not wrapped in logHandler. + defer func() { + rec := recover() + if rec == nil { + t.Fatal("expected errorHandler to panic when not wrapped in logHandler") } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("error reading body: %s", err) - } else if want := "internal server error\n"; string(body) != want { - t.Errorf("got body %q, want %q", body, want) + if want := any("uhoh"); rec != want { + t.Fatalf("got panic %#v, want %#v", rec, want) } - res.Body.Close() - }) + }() + ErrorHandler( + ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + panic("uhoh") + }), + ErrorOptions{}, + ).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil)) } func panicElsewhere() {