diff --git a/prober/prober.go b/prober/prober.go index fec571e61..333661a38 100644 --- a/prober/prober.go +++ b/prober/prober.go @@ -9,6 +9,7 @@ import ( "container/ring" "context" + "encoding/json" "fmt" "hash/fnv" "log" @@ -449,6 +450,13 @@ func (probe *Probe) probeInfoLocked() ProbeInfo { return inf } +// RunHandlerResponse is the JSON response format for the RunHandler. +type RunHandlerResponse struct { + ProbeInfo ProbeInfo + PreviousSuccessRatio float64 + PreviousMedianLatency time.Duration +} + // RunHandler runs a probe by name and returns the result as an HTTP response. func (p *Prober) RunHandler(w http.ResponseWriter, r *http.Request) error { // Look up prober by name. @@ -458,16 +466,40 @@ func (p *Prober) RunHandler(w http.ResponseWriter, r *http.Request) error { } p.mu.Lock() probe, ok := p.probes[name] - prevInfo := probe.probeInfoLocked() p.mu.Unlock() if !ok { return tsweb.Error(http.StatusNotFound, fmt.Sprintf("unknown probe %q", name), nil) } + + probe.mu.Lock() + prevInfo := probe.probeInfoLocked() + probe.mu.Unlock() + info, err := probe.run() + respStatus := http.StatusOK + if err != nil { + respStatus = http.StatusFailedDependency + } + + // Return serialized JSON response if the client requested JSON + if r.Header.Get("Accept") == "application/json" { + resp := &RunHandlerResponse{ + ProbeInfo: info, + PreviousSuccessRatio: prevInfo.RecentSuccessRatio(), + PreviousMedianLatency: prevInfo.RecentMedianLatency(), + } + w.WriteHeader(respStatus) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + return tsweb.Error(http.StatusInternalServerError, "error encoding JSON response", err) + } + return nil + } + stats := fmt.Sprintf("Previous runs: success rate %d%%, median latency %v", int(prevInfo.RecentSuccessRatio()*100), prevInfo.RecentMedianLatency()) if err != nil { - return tsweb.Error(http.StatusFailedDependency, fmt.Sprintf("Probe failed: %s\n%s", err.Error(), stats), err) + return tsweb.Error(respStatus, fmt.Sprintf("Probe failed: %s\n%s", err.Error(), stats), err) } w.WriteHeader(respStatus) w.Write([]byte(fmt.Sprintf("Probe succeeded in %v\n%s", info.Latency, stats))) diff --git a/prober/prober_test.go b/prober/prober_test.go index a194e6f5c..742a914b2 100644 --- a/prober/prober_test.go +++ b/prober/prober_test.go @@ -5,8 +5,11 @@ import ( "context" + "encoding/json" "errors" "fmt" + "io" + "net/http/httptest" "strings" "sync" "sync/atomic" @@ -17,6 +20,7 @@ "github.com/google/go-cmp/cmp/cmpopts" "github.com/prometheus/client_golang/prometheus/testutil" "tailscale.com/tstest" + "tailscale.com/tsweb" ) const ( @@ -461,6 +465,87 @@ type probeResult struct { } } +func TestProberRunHandler(t *testing.T) { + clk := newFakeTime() + + tests := []struct { + name string + probeFunc func(context.Context) error + wantResponseCode int + wantJSONResponse RunHandlerResponse + wantPlaintextResponse string + }{ + { + name: "success", + probeFunc: func(context.Context) error { return nil }, + wantResponseCode: 200, + wantJSONResponse: RunHandlerResponse{ + ProbeInfo: ProbeInfo{ + Name: "success", + Interval: probeInterval, + Result: true, + RecentResults: []bool{true, true}, + }, + PreviousSuccessRatio: 1, + }, + wantPlaintextResponse: "Probe succeeded", + }, + { + name: "failure", + probeFunc: func(context.Context) error { return fmt.Errorf("error123") }, + wantResponseCode: 424, + wantJSONResponse: RunHandlerResponse{ + ProbeInfo: ProbeInfo{ + Name: "failure", + Interval: probeInterval, + Result: false, + Error: "error123", + RecentResults: []bool{false, false}, + }, + }, + wantPlaintextResponse: "Probe failed", + }, + } + + for _, tt := range tests { + for _, reqJSON := range []bool{true, false} { + t.Run(fmt.Sprintf("%s_json-%v", tt.name, reqJSON), func(t *testing.T) { + p := newForTest(clk.Now, clk.NewTicker).WithOnce(true) + probe := p.Run(tt.name, probeInterval, nil, FuncProbe(tt.probeFunc)) + defer probe.Close() + <-probe.stopped // wait for the first run. + + w := httptest.NewRecorder() + + req := httptest.NewRequest("GET", "/prober/run/?name="+tt.name, nil) + if reqJSON { + req.Header.Set("Accept", "application/json") + } + tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{}).ServeHTTP(w, req) + if w.Result().StatusCode != tt.wantResponseCode { + t.Errorf("unexpected response code: got %d, want %d", w.Code, tt.wantResponseCode) + } + + if reqJSON { + var gotJSON RunHandlerResponse + if err := json.Unmarshal(w.Body.Bytes(), &gotJSON); err != nil { + t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, w.Body.String()) + } + if diff := cmp.Diff(tt.wantJSONResponse, gotJSON, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Labels", "RecentLatencies")); diff != "" { + t.Errorf("unexpected JSON response (-want +got):\n%s", diff) + } + } else { + body, _ := io.ReadAll(w.Result().Body) + if !strings.Contains(string(body), tt.wantPlaintextResponse) { + t.Errorf("unexpected response body: got %q, want to contain %q", body, tt.wantPlaintextResponse) + } + } + }) + } + } + +} + type fakeTicker struct { ch chan time.Time interval time.Duration