mirror of
https://github.com/tailscale/tailscale.git
synced 2025-02-18 02:48:40 +00:00
tsweb: accept a function to call before request handling
To complement the existing `onCompletion` callback, which is called after request handler. Updates tailscale/corp#17075 Signed-off-by: Anton Tolchanov <anton@tailscale.com>
This commit is contained in:
parent
6e55d8f6a1
commit
787ead835f
@ -250,19 +250,25 @@ type HandlerOptions struct {
|
||||
// for each bucket based on the contained parameters.
|
||||
BucketedStats *BucketedStatsOptions
|
||||
|
||||
// OnStart is called inline before ServeHTTP is called. Optional.
|
||||
OnStart OnStartFunc
|
||||
|
||||
// OnError is called if the handler returned a HTTPError. This
|
||||
// is intended to be used to present pretty error pages if
|
||||
// the user agent is determined to be a browser.
|
||||
OnError ErrorHandlerFunc
|
||||
|
||||
// OnCompletion is called when ServeHTTP is finished and gets
|
||||
// useful data that the implementor can use for metrics.
|
||||
// OnCompletion is called inline when ServeHTTP is finished and gets
|
||||
// useful data that the implementor can use for metrics. Optional.
|
||||
OnCompletion OnCompletionFunc
|
||||
}
|
||||
|
||||
// ErrorHandlerFunc is called to present a error response.
|
||||
type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, HTTPError)
|
||||
|
||||
// OnStartFunc is called before ServeHTTP is called.
|
||||
type OnStartFunc func(*http.Request, AccessLogRecord)
|
||||
|
||||
// OnCompletionFunc is called when ServeHTTP is finished and gets
|
||||
// useful data that the implementor can use for metrics.
|
||||
type OnCompletionFunc func(*http.Request, AccessLogRecord)
|
||||
@ -336,6 +342,10 @@ func (h retHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if fn := h.opts.OnStart; fn != nil {
|
||||
fn(r, msg)
|
||||
}
|
||||
|
||||
lw := &loggingResponseWriter{ResponseWriter: w, logf: h.opts.Logf}
|
||||
|
||||
// In case the handler panics, we want to recover and continue logging the
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/vizerror"
|
||||
@ -485,8 +486,15 @@ func TestStdHandler(t *testing.T) {
|
||||
Step: time.Second,
|
||||
})
|
||||
|
||||
var onStartRecord, onCompletionRecord AccessLogRecord
|
||||
rec := noopHijacker{httptest.NewRecorder(), false}
|
||||
h := StdHandler(test.rh, HandlerOptions{Logf: logf, Now: clock.Now, OnError: test.errHandler})
|
||||
h := StdHandler(test.rh, HandlerOptions{
|
||||
Logf: logf,
|
||||
Now: clock.Now,
|
||||
OnError: test.errHandler,
|
||||
OnStart: func(r *http.Request, alr AccessLogRecord) { onStartRecord = alr },
|
||||
OnCompletion: func(r *http.Request, alr AccessLogRecord) { onCompletionRecord = alr },
|
||||
})
|
||||
h.ServeHTTP(&rec, test.r)
|
||||
res := rec.Result()
|
||||
if res.StatusCode != test.wantCode {
|
||||
@ -502,6 +510,13 @@ func TestStdHandler(t *testing.T) {
|
||||
}
|
||||
return e.Error()
|
||||
})
|
||||
if diff := cmp.Diff(onStartRecord, test.wantLog, errTransform, cmpopts.IgnoreFields(
|
||||
AccessLogRecord{}, "Time", "Seconds", "Code", "Err")); diff != "" {
|
||||
t.Errorf("onStart callback returned unexpected request log (-got+want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(onCompletionRecord, test.wantLog, errTransform); diff != "" {
|
||||
t.Errorf("onCompletion callback returned incorrect request log (-got+want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" {
|
||||
t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user