mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-05 04:11:59 +00:00
tsweb: log once per request
StdHandler/retHandler would previously emit one log line for each request. If there were multiple StdHandler in the chain, there would be one log line per instance of retHandler. With this change, only the outermost StdHandler/logHandler actually logs the request or invokes OnStart or OnCompletion callbacks. The error-rendering part of retHandler lives on in errorHandler, and errorHandler passes those errors up the stack to logHandler through a callback that logHandler places in the request.Context(). Updates tailscale/corp#19999 Signed-off-by: Paul Scott <paul@tailscale.com>
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"tailscale.com/metrics"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/vizerror"
|
||||
@@ -60,11 +62,7 @@ func TestStdHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
req = func(ctx context.Context, url string) *http.Request {
|
||||
ret, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ret
|
||||
return httptest.NewRequest("GET", url, nil).WithContext(ctx)
|
||||
}
|
||||
|
||||
testErr = errors.New("test error")
|
||||
@@ -278,6 +276,29 @@ func TestStdHandler(t *testing.T) {
|
||||
wantBody: "visible error\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns JSON-formatted HTTPError",
|
||||
rh: ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
h := Error(http.StatusBadRequest, `{"isjson": true}`, errors.New("uh"))
|
||||
h.Header = http.Header{"Content-Type": {"application/json"}}
|
||||
return h
|
||||
}),
|
||||
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
|
||||
wantCode: 400,
|
||||
wantLog: AccessLogRecord{
|
||||
Time: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
RequestURI: "/foo",
|
||||
Err: `{"isjson": true}: uh`,
|
||||
Code: 400,
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: `{"isjson": true}`,
|
||||
},
|
||||
|
||||
{
|
||||
name: "handler returns user-visible error wrapped by private error with request ID",
|
||||
rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
|
||||
@@ -312,7 +333,7 @@ func TestStdHandler(t *testing.T) {
|
||||
Err: testErr.Error(),
|
||||
Code: 500,
|
||||
},
|
||||
wantBody: "internal server error\n",
|
||||
wantBody: "Internal Server Error\n",
|
||||
},
|
||||
|
||||
{
|
||||
@@ -331,7 +352,7 @@ func TestStdHandler(t *testing.T) {
|
||||
Code: 500,
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: "internal server error\n" + exampleRequestID + "\n",
|
||||
wantBody: "Internal Server Error\n" + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
@@ -440,7 +461,7 @@ func TestStdHandler(t *testing.T) {
|
||||
TLS: false,
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
Code: 404,
|
||||
Code: 200,
|
||||
Err: "not found",
|
||||
RequestURI: "/",
|
||||
},
|
||||
@@ -463,66 +484,148 @@ func TestStdHandler(t *testing.T) {
|
||||
TLS: false,
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
Code: 404,
|
||||
Code: 200,
|
||||
Err: "not found",
|
||||
RequestURI: "/",
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: "not found with request ID " + exampleRequestID + "\n",
|
||||
},
|
||||
|
||||
{
|
||||
name: "nested",
|
||||
rh: ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
// Here we completely handle the web response with an
|
||||
// independent StdHandler that is unaware of the outer
|
||||
// StdHandler and its logger.
|
||||
StdHandler(ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return Error(501, "Not Implemented", errors.New("uhoh"))
|
||||
}), HandlerOptions{
|
||||
OnError: func(w http.ResponseWriter, r *http.Request, h HTTPError) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(h.Code)
|
||||
fmt.Fprintf(w, `{"error": %q}`, h.Msg)
|
||||
},
|
||||
}).ServeHTTP(w, r)
|
||||
return nil
|
||||
}),
|
||||
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/"),
|
||||
wantCode: 501,
|
||||
wantLog: AccessLogRecord{
|
||||
Time: startTime,
|
||||
Seconds: 1.0,
|
||||
Proto: "HTTP/1.1",
|
||||
TLS: false,
|
||||
Host: "example.com",
|
||||
Method: "GET",
|
||||
Code: 501,
|
||||
Err: "Not Implemented: uhoh",
|
||||
RequestURI: "/",
|
||||
RequestID: exampleRequestID,
|
||||
},
|
||||
wantBody: `{"error": "Not Implemented"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var logs []AccessLogRecord
|
||||
clock := tstest.NewClock(tstest.ClockOpts{
|
||||
Start: startTime,
|
||||
Step: time.Second,
|
||||
})
|
||||
|
||||
// Callbacks to track the emitted AccessLogRecords.
|
||||
var (
|
||||
logs []AccessLogRecord
|
||||
starts []AccessLogRecord
|
||||
comps []AccessLogRecord
|
||||
)
|
||||
logf := func(fmt string, args ...any) {
|
||||
if fmt == "%s" {
|
||||
logs = append(logs, args[0].(AccessLogRecord))
|
||||
}
|
||||
t.Logf(fmt, args...)
|
||||
}
|
||||
oncomp := func(r *http.Request, msg AccessLogRecord) {
|
||||
comps = append(comps, msg)
|
||||
}
|
||||
onstart := func(r *http.Request, msg AccessLogRecord) {
|
||||
starts = append(starts, msg)
|
||||
}
|
||||
|
||||
clock := tstest.NewClock(tstest.ClockOpts{
|
||||
Start: startTime,
|
||||
Step: time.Second,
|
||||
})
|
||||
bucket := func(r *http.Request) string { return r.URL.RequestURI() }
|
||||
|
||||
// Build the request handler.
|
||||
opts := HandlerOptions{
|
||||
Now: clock.Now,
|
||||
|
||||
var onStartRecord, onCompletionRecord AccessLogRecord
|
||||
rec := noopHijacker{httptest.NewRecorder(), false}
|
||||
h := StdHandler(test.rh, HandlerOptions{
|
||||
Logf: logf,
|
||||
Now: clock.Now,
|
||||
OnError: test.errHandler,
|
||||
OnStart: func(r *http.Request, alr AccessLogRecord) { onStartRecord = alr },
|
||||
OnCompletion: func(r *http.Request, alr AccessLogRecord) { onCompletionRecord = alr },
|
||||
})
|
||||
Logf: logf,
|
||||
OnStart: onstart,
|
||||
OnCompletion: oncomp,
|
||||
|
||||
StatusCodeCounters: &expvar.Map{},
|
||||
StatusCodeCountersFull: &expvar.Map{},
|
||||
BucketedStats: &BucketedStatsOptions{
|
||||
Bucket: bucket,
|
||||
Started: &metrics.LabelMap{},
|
||||
Finished: &metrics.LabelMap{},
|
||||
},
|
||||
}
|
||||
h := StdHandler(test.rh, opts)
|
||||
|
||||
// Pre-create the BucketedStats.{Started,Finished} metric for the
|
||||
// test request's bucket so that even non-200 status codes get
|
||||
// recorded immediately. logHandler tries to avoid counting unknown
|
||||
// paths, so here we're marking them known.
|
||||
opts.BucketedStats.Started.Get(bucket(test.r))
|
||||
opts.BucketedStats.Finished.Get(bucket(test.r))
|
||||
|
||||
// Perform the request.
|
||||
rec := noopHijacker{httptest.NewRecorder(), false}
|
||||
h.ServeHTTP(&rec, test.r)
|
||||
|
||||
// Validate the client received the expected response.
|
||||
res := rec.Result()
|
||||
if res.StatusCode != test.wantCode {
|
||||
t.Errorf("HTTP code = %v, want %v", res.StatusCode, test.wantCode)
|
||||
}
|
||||
if len(logs) != 1 {
|
||||
t.Errorf("handler didn't write a request log")
|
||||
return
|
||||
}
|
||||
errTransform := cmp.Transformer("err", func(e error) string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.Error()
|
||||
})
|
||||
if diff := cmp.Diff(onStartRecord, test.wantLog, errTransform, cmpopts.IgnoreFields(
|
||||
AccessLogRecord{}, "Time", "Seconds", "Code", "Err")); diff != "" {
|
||||
t.Errorf("onStart callback returned unexpected request log (-got+want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(onCompletionRecord, test.wantLog, errTransform); diff != "" {
|
||||
t.Errorf("onCompletion callback returned incorrect request log (-got+want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" {
|
||||
t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(rec.Body.String(), test.wantBody); diff != "" {
|
||||
t.Errorf("handler wrote incorrect body (-got+want):\n%s", diff)
|
||||
t.Errorf("handler wrote incorrect body (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
// Fields we want to check for in tests but not repeat on every case.
|
||||
test.wantLog.RemoteAddr = "192.0.2.1:1234" // Hard-coded by httptest.NewRequest.
|
||||
test.wantLog.Bytes = len(test.wantBody)
|
||||
|
||||
// Validate the AccessLogRecords written to logf and sent back to
|
||||
// the OnCompletion handler.
|
||||
checkOutput := func(src string, msgs []AccessLogRecord, opts ...cmp.Option) {
|
||||
t.Helper()
|
||||
if len(msgs) != 1 {
|
||||
t.Errorf("%s: expected 1 msg, got: %#v", src, msgs)
|
||||
} else if diff := cmp.Diff(msgs[0], test.wantLog, opts...); diff != "" {
|
||||
t.Errorf("%s: wrong access log (-got +want):\n%s", src, diff)
|
||||
}
|
||||
}
|
||||
checkOutput("hander wrote logs", logs)
|
||||
checkOutput("start msgs", starts, cmpopts.IgnoreFields(AccessLogRecord{}, "Time", "Seconds", "Code", "Err", "Bytes"))
|
||||
checkOutput("completion msgs", comps)
|
||||
|
||||
// Validate the code counters.
|
||||
if got, want := opts.StatusCodeCounters.String(), fmt.Sprintf(`{"%dxx": 1}`, test.wantLog.Code/100); got != want {
|
||||
t.Errorf("StatusCodeCounters: got %s, want %s", got, want)
|
||||
}
|
||||
if got, want := opts.StatusCodeCountersFull.String(), fmt.Sprintf(`{"%d": 1}`, test.wantLog.Code); got != want {
|
||||
t.Errorf("StatusCodeCountersFull: got %s, want %s", got, want)
|
||||
}
|
||||
|
||||
// Validate the bucketed counters.
|
||||
if got, want := opts.BucketedStats.Started.String(), fmt.Sprintf("{%q: 1}", bucket(test.r)); got != want {
|
||||
t.Errorf("BucketedStats.Started: got %q, want %q", got, want)
|
||||
}
|
||||
if got, want := opts.BucketedStats.Finished.String(), fmt.Sprintf("{%q: 1}", bucket(test.r)); got != want {
|
||||
t.Errorf("BucketedStats.Finished: got %s, want %s", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user