diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index f1ba6d965..4ea0748e0 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -250,19 +250,25 @@ type HandlerOptions struct { // for each bucket based on the contained parameters. BucketedStats *BucketedStatsOptions + // OnStart is called inline before ServeHTTP is called. Optional. + OnStart OnStartFunc + // OnError is called if the handler returned a HTTPError. This // is intended to be used to present pretty error pages if // the user agent is determined to be a browser. OnError ErrorHandlerFunc - // OnCompletion is called when ServeHTTP is finished and gets - // useful data that the implementor can use for metrics. + // OnCompletion is called inline when ServeHTTP is finished and gets + // useful data that the implementor can use for metrics. Optional. OnCompletion OnCompletionFunc } // ErrorHandlerFunc is called to present a error response. type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, HTTPError) +// OnStartFunc is called before ServeHTTP is called. +type OnStartFunc func(*http.Request, AccessLogRecord) + // OnCompletionFunc is called when ServeHTTP is finished and gets // useful data that the implementor can use for metrics. type OnCompletionFunc func(*http.Request, AccessLogRecord) @@ -336,6 +342,10 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } + if fn := h.opts.OnStart; fn != nil { + fn(r, msg) + } + lw := &loggingResponseWriter{ResponseWriter: w, logf: h.opts.Logf} // In case the handler panics, we want to recover and continue logging the diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 0f1d82114..30036ae1c 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/tstest" "tailscale.com/util/must" "tailscale.com/util/vizerror" @@ -485,8 +486,15 @@ func TestStdHandler(t *testing.T) { Step: time.Second, }) + var onStartRecord, onCompletionRecord AccessLogRecord rec := noopHijacker{httptest.NewRecorder(), false} - h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler}) + h := StdHandler(test.rh, HandlerOptions{ + Logf: logf, + Now: clock.Now, + OnError: test.errHandler, + OnStart: func(r *http.Request, alr AccessLogRecord) { onStartRecord = alr }, + OnCompletion: func(r *http.Request, alr AccessLogRecord) { onCompletionRecord = alr }, + }) h.ServeHTTP(&rec, test.r) res := rec.Result() if res.StatusCode != test.wantCode { @@ -502,6 +510,13 @@ func TestStdHandler(t *testing.T) { } return e.Error() }) + if diff := cmp.Diff(onStartRecord, test.wantLog, errTransform, cmpopts.IgnoreFields( + AccessLogRecord{}, "Time", "Seconds", "Code", "Err")); diff != "" { + t.Errorf("onStart callback returned unexpected request log (-got+want):\n%s", diff) + } + if diff := cmp.Diff(onCompletionRecord, test.wantLog, errTransform); diff != "" { + t.Errorf("onCompletion callback returned incorrect request log (-got+want):\n%s", diff) + } if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" { t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff) }