From ba7f2d129eb1a576ef970883dcf8cbbb2ae0a078 Mon Sep 17 00:00:00 2001 From: Paul Scott <408401+icio@users.noreply.github.com> Date: Wed, 24 Jul 2024 08:58:06 +0100 Subject: [PATCH] tsweb: log all cancellations as 499s (#12894) Updates #12141 Signed-off-by: Paul Scott --- tsweb/tsweb.go | 92 ++++++++++++++------- tsweb/tsweb_test.go | 191 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 252 insertions(+), 31 deletions(-) diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index 430ccd5b8..2f3b1eae3 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -338,7 +338,7 @@ func (opts ErrorOptions) withDefaults() ErrorOptions { opts.Logf = logger.Discard } if opts.OnError == nil { - opts.OnError = writeHTTPError + opts.OnError = WriteHTTPError } return opts } @@ -405,7 +405,7 @@ func ErrorHandler(h ReturnHandler, opts ErrorOptions) http.Handler { // errCallback is added to logHandler's request context so that errorHandler can // pass errors back up the stack to logHandler. -var errCallback = ctxkey.New[func(string)]("tailscale.com/tsweb.errCallback", nil) +var errCallback = ctxkey.New[func(HTTPError)]("tailscale.com/tsweb.errCallback", nil) // logHandler is a http.Handler which logs the HTTP request. // It injects an errCallback for errorHandler to augment the log message with @@ -471,9 +471,25 @@ func (h logHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Let errorHandler tell us what error it wrote to the client. - r = r.WithContext(errCallback.WithValue(ctx, func(e string) { - if msg.Err == "" { - msg.Err = e // Keep the first error. + r = r.WithContext(errCallback.WithValue(ctx, func(e HTTPError) { + // Keep the deepest error. + if msg.Err != "" { + return + } + + // Log the error. + if e.Msg != "" && e.Err != nil { + msg.Err = e.Msg + ": " + e.Err.Error() + } else if e.Err != nil { + msg.Err = e.Err.Error() + } else if e.Msg != "" { + msg.Err = e.Msg + } + + // We log the code from the loggingResponseWriter, except for + // cancellation where we override with 499. + if reqCancelled(r, e.Err) { + msg.Code = 499 } })) @@ -502,23 +518,29 @@ func (h logHandler) logRequest(r *http.Request, lw *loggingResponseWriter, msg A msg.Bytes = lw.bytes msg.Seconds = h.opts.Now().Sub(msg.Time).Seconds() switch { + case msg.Code != 0: + // Keep explicit codes from a few particular errors. case lw.hijacked: // Connection no longer belongs to us, just log that we // switched protocols away from HTTP. msg.Code = http.StatusSwitchingProtocols case lw.code == 0: - if r.Context().Err() != nil { - // We didn't write a response before the client disconnected. - msg.Code = 499 - } else { - // If the handler didn't write and didn't send a header, that still means 200. - // (See https://play.golang.org/p/4P7nx_Tap7p) - msg.Code = 200 - } + // If the handler didn't write and didn't send a header, that still means 200. + // (See https://play.golang.org/p/4P7nx_Tap7p) + msg.Code = 200 default: msg.Code = lw.code } + // Keep track of the original response code when we've overridden it. + if lw.code != 0 && msg.Code != lw.code { + if msg.Err == "" { + msg.Err = fmt.Sprintf("(original code %d)", lw.code) + } else { + msg.Err = fmt.Sprintf("%s (original code %d)", msg.Err, lw.code) + } + } + if !h.opts.QuietLoggingIfSuccessful || (msg.Code != http.StatusOK && msg.Code != http.StatusNotModified) { h.opts.Logf("%s", msg) } @@ -564,6 +586,7 @@ func responseCodeString(code int) string { // response code that gets sent, if any. type loggingResponseWriter struct { http.ResponseWriter + ctx context.Context code int bytes int hijacked bool @@ -582,6 +605,7 @@ func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Reque } return &loggingResponseWriter{ ResponseWriter: w, + ctx: r.Context(), logf: logf, } } @@ -592,7 +616,9 @@ func (l *loggingResponseWriter) WriteHeader(statusCode int) { l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode) return } - l.code = statusCode + if l.ctx.Err() == nil { + l.code = statusCode + } l.ResponseWriter.WriteHeader(statusCode) } @@ -682,8 +708,13 @@ func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *lo } } else if v, ok := vizerror.As(err); ok { hErr = Error(http.StatusInternalServerError, v.Error(), nil) - } else if errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) { - hErr = Error(499, "", err) // Nginx convention + } else if reqCancelled(r, err) { + // 499 is the Nginx convention meaning "Client Closed Connection". + if errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) { + hErr = Error(499, "", err) + } else { + hErr = Error(499, "", fmt.Errorf("%w: %w", context.Canceled, err)) + } } else { // Omit the friendly message so HTTP logs show the bare error that was // returned and we know it's not a HTTPError. @@ -692,13 +723,7 @@ func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *lo // Tell the logger what error we wrote back to the client. if pb := errCallback.Value(r.Context()); pb != nil { - if hErr.Msg != "" && hErr.Err != nil { - pb(hErr.Msg + ": " + hErr.Err.Error()) - } else if hErr.Err != nil { - pb(hErr.Err.Error()) - } else if hErr.Msg != "" { - pb(hErr.Msg) - } + pb(hErr) logged = true } @@ -775,21 +800,32 @@ func (e *panicError) Unwrap() error { return err } -// writeHTTPError is the default error response formatter. -func writeHTTPError(w http.ResponseWriter, r *http.Request, hErr HTTPError) { +// reqCancelled returns true if err is http.ErrAbortHandler or r.Context.Err() +// is context.Canceled. +func reqCancelled(r *http.Request, err error) bool { + return errors.Is(err, http.ErrAbortHandler) || r.Context().Err() == context.Canceled +} + +// WriteHTTPError is the default error response formatter. +func WriteHTTPError(w http.ResponseWriter, r *http.Request, e HTTPError) { + // Don't write a response if we've hit a cancellation/abort. + if r.Context().Err() != nil || errors.Is(e.Err, http.ErrAbortHandler) { + return + } + // Default headers set by http.Error. h := w.Header() h.Set("Content-Type", "text/plain; charset=utf-8") h.Set("X-Content-Type-Options", "nosniff") // Custom headers from the error. - for k, vs := range hErr.Header { + for k, vs := range e.Header { h[k] = vs } // Write the msg back to the user. - w.WriteHeader(hErr.Code) - fmt.Fprint(w, hErr.Msg) + w.WriteHeader(e.Code) + fmt.Fprint(w, e.Msg) // If it's a plaintext message, add line breaks and RequestID. if strings.HasPrefix(h.Get("Content-Type"), "text/plain") { diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 2633534aa..2bf2b7341 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -13,6 +13,7 @@ "net" "net/http" "net/http/httptest" + "net/http/httputil" "net/url" "strings" "testing" @@ -493,6 +494,25 @@ func TestStdHandler(t *testing.T) { wantBody: "not found with request ID " + exampleRequestID + "\n", }, + { + name: "inner_cancelled", + rh: handlerErr(0, context.Canceled), // return canceled error, but the request was not cancelled + r: req(bgCtx, "http://example.com/"), + wantCode: 500, + wantLog: AccessLogRecord{ + Time: startTime, + Seconds: 1.0, + Proto: "HTTP/1.1", + TLS: false, + Host: "example.com", + Method: "GET", + Code: 500, + Err: "context canceled", + RequestURI: "/", + }, + wantBody: "Internal Server Error\n", + }, + { name: "nested", rh: ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error { @@ -705,6 +725,7 @@ func TestStdHandler_Canceled(t *testing.T) { close(handlerOpen) ctx := r.Context() <-ctx.Done() + w.WriteHeader(200) // Ignored. return ctx.Err() }), HandlerOptions{ @@ -718,6 +739,8 @@ func TestStdHandler_Canceled(t *testing.T) { }, }, ) + s := httptest.NewServer(h) + t.Cleanup(s.Close) // Create a context which gets canceled after the handler starts processing // the request. @@ -727,9 +750,6 @@ func TestStdHandler_Canceled(t *testing.T) { cancelReq() }() - s := httptest.NewServer(h) - t.Cleanup(s.Close) - // Send a request to our server. req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil) if err != nil { @@ -766,7 +786,172 @@ func TestStdHandler_Canceled(t *testing.T) { if e != nil { t.Errorf("got OnError callback with %#v, want no callback", e) } +} +func TestStdHandler_CanceledAfterHeader(t *testing.T) { + now := time.Now() + + r := make(chan AccessLogRecord) + var e *HTTPError + handlerOpen := make(chan struct{}) + h := StdHandler( + ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + w.WriteHeader(http.StatusNoContent) + close(handlerOpen) + ctx := r.Context() + <-ctx.Done() + return ctx.Err() + }), + HandlerOptions{ + Logf: t.Logf, + Now: func() time.Time { return now }, + OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) { + e = &h + }, + OnCompletion: func(_ *http.Request, alr AccessLogRecord) { + r <- alr + }, + }, + ) + s := httptest.NewServer(h) + t.Cleanup(s.Close) + + // Create a context which gets canceled after the handler starts processing + // the request. + ctx, cancelReq := context.WithCancel(context.Background()) + go func() { + <-handlerOpen + cancelReq() + }() + + // Send a request to our server. + req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil) + if err != nil { + t.Fatalf("making request: %s", err) + } + res, err := http.DefaultClient.Do(req) + if !errors.Is(err, context.Canceled) { + t.Errorf("got error %v, want context.Canceled", err) + } + if res != nil { + t.Errorf("got response %#v, want nil", res) + } + + // Check that we got the expected log record. + got := <-r + got.Seconds = 0 + got.RemoteAddr = "" + got.Host = "" + got.UserAgent = "" + want := AccessLogRecord{ + Time: now, + Code: 499, + Method: "GET", + Err: "context canceled (original code 204)", + Proto: "HTTP/1.1", + RequestURI: "/", + } + if d := cmp.Diff(want, got); d != "" { + t.Errorf("AccessLogRecord wrong (-want +got)\n%s", d) + } + + // Check that we rendered no response to the client after + // logHandler.OnCompletion has been called. + if e != nil { + t.Errorf("got OnError callback with %#v, want no callback", e) + } +} + +func TestStdHandler_ConnectionClosedDuringBody(t *testing.T) { + now := time.Now() + + // Start a HTTP server that returns 1MB of data. + // We next put a reverse-proxy in front of this server. + rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for range 1024 { + w.Write(make([]byte, 1024)) + } + })) + defer rs.Close() + + r := make(chan AccessLogRecord) + var e *HTTPError + responseStarted := make(chan struct{}) + + // Create another server which proxies our 1MB server. + // The [httputil.ReverseProxy] will panic with [http.ErrAbortHandler] when + // it fails to copy the response to the client. + h := StdHandler( + ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + (&httputil.ReverseProxy{ + Director: func(r *http.Request) { + r.URL = must.Get(url.Parse(rs.URL)) + }, + ModifyResponse: func(r *http.Response) error { + close(responseStarted) + return nil + }, + }).ServeHTTP(w, r) + return nil + }), + HandlerOptions{ + Logf: t.Logf, + Now: func() time.Time { return now }, + OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) { + e = &h + }, + OnCompletion: func(_ *http.Request, alr AccessLogRecord) { + r <- alr + }, + }, + ) + s := httptest.NewServer(h) + t.Cleanup(s.Close) + + // Create a context which gets canceled after the handler starts processing + // the request. + ctx, cancelReq := context.WithCancel(context.Background()) + go func() { + <-responseStarted + cancelReq() + }() + + // Send a request to our server. + req, err := http.NewRequestWithContext(ctx, httpm.GET, s.URL, nil) + if err != nil { + t.Fatalf("making request: %s", err) + } + res, err := http.DefaultClient.Do(req) + if !errors.Is(err, context.Canceled) { + t.Errorf("got error %v, want context.Canceled", err) + } + if res != nil { + t.Errorf("got response %#v, want nil", res) + } + + // Check that we got the expected log record. + got := <-r + got.Seconds = 0 + got.RemoteAddr = "" + got.Host = "" + got.UserAgent = "" + want := AccessLogRecord{ + Time: now, + Code: 499, + Method: "GET", + Err: "net/http: abort Handler (original code 200)", + Proto: "HTTP/1.1", + RequestURI: "/", + } + if d := cmp.Diff(want, got, cmpopts.IgnoreFields(AccessLogRecord{}, "Bytes")); d != "" { + t.Errorf("AccessLogRecord wrong (-want +got)\n%s", d) + } + + // Check that we rendered no response to the client after + // logHandler.OnCompletion has been called. + if e != nil { + t.Errorf("got OnError callback with %#v, want no callback", e) + } } func TestStdHandler_OnErrorPanic(t *testing.T) {