diff --git a/cmd/derpprobe/derpprobe.go b/cmd/derpprobe/derpprobe.go
index f2c156bfb..56e08ae64 100644
--- a/cmd/derpprobe/derpprobe.go
+++ b/cmd/derpprobe/derpprobe.go
@@ -2,72 +2,45 @@
// SPDX-License-Identifier: BSD-3-Clause
// The derpprobe binary probes derpers.
-package main // import "tailscale.com/cmd/derper/derpprobe"
+package main
import (
- "bytes"
- "context"
- crand "crypto/rand"
- "crypto/x509"
- "encoding/json"
- "errors"
+ "expvar"
"flag"
"fmt"
"html"
"io"
"log"
- "net"
"net/http"
- "os"
"sort"
- "strings"
- "sync"
"time"
- "tailscale.com/derp"
- "tailscale.com/derp/derphttp"
- "tailscale.com/net/stun"
- "tailscale.com/tailcfg"
- "tailscale.com/types/key"
+ "tailscale.com/prober"
+ "tailscale.com/tsweb"
)
var (
derpMapURL = flag.String("derp-map", "https://login.tailscale.com/derpmap/default", "URL to DERP map (https:// or file://)")
listen = flag.String("listen", ":8030", "HTTP listen address")
probeOnce = flag.Bool("once", false, "probe once and print results, then exit; ignores the listen flag")
-)
-
-// certReissueAfter is the time after which we expect all certs to be
-// reissued, at minimum.
-//
-// This is currently set to the date of the LetsEncrypt ALPN revocation event of Jan 2022:
-// https://community.letsencrypt.org/t/questions-about-renewing-before-tls-alpn-01-revocations/170449
-//
-// If there's another revocation event, bump this again.
-var certReissueAfter = time.Unix(1643226768, 0)
-
-var (
- mu sync.Mutex
- state = map[nodePair]pairStatus{}
- lastDERPMap *tailcfg.DERPMap
- lastDERPMapAt time.Time
- certs = map[string]*x509.Certificate{}
+ interval = flag.Duration("interval", 15*time.Second, "probe interval")
)
func main() {
flag.Parse()
- // proactively load the DERP map. Nothing terrible happens if this fails, so we ignore
- // the error. The Slack bot will print a notification that the DERP map was empty.
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- _, _ = getDERPMap(ctx)
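+ // Build the prober: spread adds a random delay before each probe's
+ // first run, and once mode runs every probe a single time.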
+ p := prober.New().WithSpread(true).WithOnce(*probeOnce)
+ dp, err := prober.DERP(p, *derpMapURL, *interval, *interval, *interval)
+ if err != nil {
+ log.Fatal(err)
+ }
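+ // The derpmap probe periodically re-fetches the DERP map itself.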
+ p.Run("derpmap-probe", *interval, nil, dp.ProbeMap)
if *probeOnce {
- log.Printf("Starting probe (may take up to 1m)")
- probe()
- log.Printf("Probe results:")
- st := getOverallStatus()
+ log.Printf("Waiting for all probes (may take up to 1m)")
+ p.Wait()
+
+ st := getOverallStatus(p)
for _, s := range st.good {
log.Printf("good: %s", s)
}
@@ -77,15 +50,11 @@ func main() {
return
}
- go probeLoop()
- go slackLoop()
- log.Fatal(http.ListenAndServe(*listen, http.HandlerFunc(serve)))
-}
-
-func setCert(name string, cert *x509.Certificate) {
- mu.Lock()
- defer mu.Unlock()
- certs[name] = cert
+ mux := http.NewServeMux()
+ tsweb.Debugger(mux)
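+ // Publish probe state so it is visible via the expvar debug endpoints.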
+ expvar.Publish("derpprobe", p.Expvar())
+ mux.HandleFunc("/", http.HandlerFunc(serveFunc(p)))
+ log.Fatal(http.ListenAndServe(*listen, mux))
}
type overallStatus struct {
@@ -100,471 +69,43 @@ func (st *overallStatus) addGoodf(format string, a ...any) {
st.good = append(st.good, fmt.Sprintf(format, a...))
}
-func getOverallStatus() (o overallStatus) {
- mu.Lock()
- defer mu.Unlock()
- if lastDERPMap == nil {
- o.addBadf("no DERP map")
- return
- }
- now := time.Now()
- if age := now.Sub(lastDERPMapAt); age > time.Minute {
- o.addBadf("DERPMap hasn't been successfully refreshed in %v", age.Round(time.Second))
- }
-
- addPairMeta := func(pair nodePair) {
- st, ok := state[pair]
- age := now.Sub(st.at).Round(time.Second)
- switch {
- case !ok:
- o.addBadf("no state for %v", pair)
- case st.err != nil:
- o.addBadf("%v: %v", pair, st.err)
- case age > 90*time.Second:
- o.addBadf("%v: update is %v old", pair, age)
- default:
- o.addGoodf("%v: %v, %v ago", pair, st.latency.Round(time.Millisecond), age)
- }
- }
-
- for _, reg := range sortedRegions(lastDERPMap) {
- for _, from := range reg.Nodes {
- addPairMeta(nodePair{"UDP", from.Name})
- for _, to := range reg.Nodes {
- addPairMeta(nodePair{from.Name, to.Name})
- }
- }
- }
-
- var subjs []string
- for k := range certs {
- subjs = append(subjs, k)
- }
- sort.Strings(subjs)
-
- soon := time.Now().Add(14 * 24 * time.Hour) // in 2 weeks; autocert does 30 days by default
- for _, s := range subjs {
- cert := certs[s]
- if cert.NotBefore.Before(certReissueAfter) {
- o.addBadf("cert %q needs reissuing; NotBefore=%v", s, cert.NotBefore.Format(time.RFC3339))
+func getOverallStatus(p *prober.Prober) (o overallStatus) {
+ for p, i := range p.ProbeInfo() {
+ if i.End.IsZero() {
+ // Do not show probes that have not finished yet.
continue
}
- if cert.NotAfter.Before(soon) {
- o.addBadf("cert %q expiring soon (%v); wasn't auto-refreshed", s, cert.NotAfter.Format(time.RFC3339))
- continue
+ if i.Result {
+ o.addGoodf("%s: %s", p, i.Latency)
+ } else {
+ o.addBadf("%s: %s", p, i.Error)
}
- o.addGoodf("cert %q good %v - %v", s, cert.NotBefore.Format(time.RFC3339), cert.NotAfter.Format(time.RFC3339))
}
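+ // Sort so the output ordering is stable across requests.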
+ sort.Strings(o.bad)
+ sort.Strings(o.good)
return
}
-func serve(w http.ResponseWriter, r *http.Request) {
- st := getOverallStatus()
- summary := "All good"
- if (float64(len(st.bad)) / float64(len(st.bad)+len(st.good))) > 0.25 {
- // This will generate an alert and page a human.
- // It also ends up in Slack, but as part of the alert handling pipeline not
- // because we generated a Slack notification from here.
- w.WriteHeader(500)
- summary = fmt.Sprintf("%d problems", len(st.bad))
- }
-
- io.WriteString(w, "<html><head><style>.bad { font-weight: bold; color: #700; }</style></head>\n")
- fmt.Fprintf(w, "<body><h1>derp probe</h1>\n%s:<ul>", summary)
- for _, s := range st.bad {
- fmt.Fprintf(w, "<li class=bad>%s</li>\n", html.EscapeString(s))
- }
- for _, s := range st.good {
- fmt.Fprintf(w, "<li>%s</li>\n", html.EscapeString(s))
- }
- io.WriteString(w, "</ul></body></html>\n")
-}
-
-func notifySlack(text string) error {
- type SlackRequestBody struct {
- Text string `json:"text"`
- }
-
- slackBody, err := json.Marshal(SlackRequestBody{Text: text})
- if err != nil {
- return err
- }
-
- webhookUrl := os.Getenv("SLACK_WEBHOOK")
- if webhookUrl == "" {
- return errors.New("No SLACK_WEBHOOK configured")
- }
-
- req, err := http.NewRequest("POST", webhookUrl, bytes.NewReader(slackBody))
- if err != nil {
- return err
- }
- req.Header.Add("Content-Type", "application/json")
-
- client := &http.Client{Timeout: 10 * time.Second}
- resp, err := client.Do(req)
- if err != nil {
- return err
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- return errors.New(resp.Status)
- }
-
- body, _ := io.ReadAll(resp.Body)
- if string(body) != "ok" {
- return errors.New("Non-ok response returned from Slack")
- }
- return nil
-}
-
-// We only page a human if it looks like there is a significant outage across multiple regions.
-// To Slack, we report all failures great and small.
-func slackLoop() {
- inBadState := false
- for {
- time.Sleep(time.Second * 30)
- st := getOverallStatus()
-
- if len(st.bad) > 0 && !inBadState {
- err := notifySlack(strings.Join(st.bad, "\n"))
- if err == nil {
- inBadState = true
- } else {
- log.Printf("%d problems, notify Slack failed: %v", len(st.bad), err)
- }
+func serveFunc(p *prober.Prober) func(w http.ResponseWriter, r *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ st := getOverallStatus(p)
+ summary := "All good"
+ if (float64(len(st.bad)) / float64(len(st.bad)+len(st.good))) > 0.25 {
+ // Returning a 500 allows monitoring this server externally and configuring
+ // an alert on HTTP response code.
+ w.WriteHeader(500)
+ summary = fmt.Sprintf("%d problems", len(st.bad))
}
- if len(st.bad) == 0 && inBadState {
- err := notifySlack("All DERPs recovered.")
- if err == nil {
- inBadState = false
- }
+ io.WriteString(w, "<html><head><style>.bad { font-weight: bold; color: #700; }</style></head>\n")
+ fmt.Fprintf(w, "<body><h1>derp probe</h1>\n%s:<ul>", summary)
+ for _, s := range st.bad {
+ fmt.Fprintf(w, "<li class=bad>%s</li>\n", html.EscapeString(s))
}
+ for _, s := range st.good {
+ fmt.Fprintf(w, "<li>%s</li>\n", html.EscapeString(s))
+ }
+ io.WriteString(w, "</ul></body></html>\n")
}
}
-
-func sortedRegions(dm *tailcfg.DERPMap) []*tailcfg.DERPRegion {
- ret := make([]*tailcfg.DERPRegion, 0, len(dm.Regions))
- for _, r := range dm.Regions {
- ret = append(ret, r)
- }
- sort.Slice(ret, func(i, j int) bool { return ret[i].RegionID < ret[j].RegionID })
- return ret
-}
-
-type nodePair struct {
- from string // DERPNode.Name, or "UDP" for a STUN query to 'to'
- to string // DERPNode.Name
-}
-
-func (p nodePair) String() string { return fmt.Sprintf("(%s→%s)", p.from, p.to) }
-
-type pairStatus struct {
- err error
- latency time.Duration
- at time.Time
-}
-
-func setDERPMap(dm *tailcfg.DERPMap) {
- mu.Lock()
- defer mu.Unlock()
- lastDERPMap = dm
- lastDERPMapAt = time.Now()
-}
-
-func setState(p nodePair, latency time.Duration, err error) {
- mu.Lock()
- defer mu.Unlock()
- st := pairStatus{
- err: err,
- latency: latency,
- at: time.Now(),
- }
- state[p] = st
- if err != nil {
- log.Printf("%+v error: %v", p, err)
- } else {
- log.Printf("%+v: %v", p, latency.Round(time.Millisecond))
- }
-}
-
-func probeLoop() {
- ticker := time.NewTicker(15 * time.Second)
- for {
- err := probe()
- if err != nil {
- log.Printf("probe: %v", err)
- }
- <-ticker.C
- }
-}
-
-func probe() error {
- ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
- defer cancel()
- dm, err := getDERPMap(ctx)
- if err != nil {
- return err
- }
-
- var wg sync.WaitGroup
- wg.Add(len(dm.Regions))
- for _, reg := range dm.Regions {
- reg := reg
- go func() {
- defer wg.Done()
- for _, from := range reg.Nodes {
- latency, err := probeUDP(ctx, dm, from)
- setState(nodePair{"UDP", from.Name}, latency, err)
- for _, to := range reg.Nodes {
- latency, err := probeNodePair(ctx, dm, from, to)
- setState(nodePair{from.Name, to.Name}, latency, err)
- }
- }
- }()
- }
-
- wg.Wait()
- return ctx.Err()
-}
-
-func probeUDP(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode) (latency time.Duration, err error) {
- pc, err := net.ListenPacket("udp", ":0")
- if err != nil {
- return 0, err
- }
- defer pc.Close()
- uc := pc.(*net.UDPConn)
-
- tx := stun.NewTxID()
- req := stun.Request(tx)
-
- for _, ipStr := range []string{n.IPv4, n.IPv6} {
- if ipStr == "" {
- continue
- }
- port := n.STUNPort
- if port == -1 {
- continue
- }
- if port == 0 {
- port = 3478
- }
- for {
- ip := net.ParseIP(ipStr)
- _, err := uc.WriteToUDP(req, &net.UDPAddr{IP: ip, Port: port})
- if err != nil {
- return 0, err
- }
- buf := make([]byte, 1500)
- uc.SetReadDeadline(time.Now().Add(2 * time.Second))
- t0 := time.Now()
- n, _, err := uc.ReadFromUDP(buf)
- d := time.Since(t0)
- if err != nil {
- if ctx.Err() != nil {
- return 0, fmt.Errorf("timeout reading from %v: %v", ip, err)
- }
- if d < time.Second {
- return 0, fmt.Errorf("error reading from %v: %v", ip, err)
- }
- time.Sleep(100 * time.Millisecond)
- continue
- }
- txBack, _, err := stun.ParseResponse(buf[:n])
- if err != nil {
- return 0, fmt.Errorf("parsing STUN response from %v: %v", ip, err)
- }
- if txBack != tx {
- return 0, fmt.Errorf("read wrong tx back from %v", ip)
- }
- if latency == 0 || d < latency {
- latency = d
- }
- break
- }
- }
- return latency, nil
-}
-
-func probeNodePair(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode) (latency time.Duration, err error) {
- // The passed in context is a minute for the whole region. The
- // idea is that each node pair in the region will be done
- // serially and regularly in the future, reusing connections
- // (at least in the happy path). For now they don't reuse
- // connections and probe at most once every 15 seconds. We
- // bound the duration of a single node pair within a region
- // so one bad one can't starve others.
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
-
- fromc, err := newConn(ctx, dm, from)
- if err != nil {
- return 0, err
- }
- defer fromc.Close()
- toc, err := newConn(ctx, dm, to)
- if err != nil {
- return 0, err
- }
- defer toc.Close()
-
- // Wait a bit for from's node to hear about to existing on the
- // other node in the region, in the case where the two nodes
- // are different.
- if from.Name != to.Name {
- time.Sleep(100 * time.Millisecond) // pretty arbitrary
- }
-
- // Make a random packet
- pkt := make([]byte, 8)
- crand.Read(pkt)
-
- t0 := time.Now()
-
- // Send the random packet.
- sendc := make(chan error, 1)
- go func() {
- sendc <- fromc.Send(toc.SelfPublicKey(), pkt)
- }()
- select {
- case <-ctx.Done():
- return 0, fmt.Errorf("timeout sending via %q: %w", from.Name, ctx.Err())
- case err := <-sendc:
- if err != nil {
- return 0, fmt.Errorf("error sending via %q: %w", from.Name, err)
- }
- }
-
- // Receive the random packet.
- recvc := make(chan any, 1) // either derp.ReceivedPacket or error
- go func() {
- for {
- m, err := toc.Recv()
- if err != nil {
- recvc <- err
- return
- }
- switch v := m.(type) {
- case derp.ReceivedPacket:
- recvc <- v
- default:
- log.Printf("%v: ignoring Recv frame type %T", to.Name, v)
- // Loop.
- }
- }
- }()
- select {
- case <-ctx.Done():
- return 0, fmt.Errorf("timeout receiving from %q: %w", to.Name, ctx.Err())
- case v := <-recvc:
- if err, ok := v.(error); ok {
- return 0, fmt.Errorf("error receiving from %q: %w", to.Name, err)
- }
- p := v.(derp.ReceivedPacket)
- if p.Source != fromc.SelfPublicKey() {
- return 0, fmt.Errorf("got data packet from unexpected source, %v", p.Source)
- }
- if !bytes.Equal(p.Data, pkt) {
- return 0, fmt.Errorf("unexpected data packet %q", p.Data)
- }
- }
- return time.Since(t0), nil
-}
-
-func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode) (*derphttp.Client, error) {
- priv := key.NewNode()
- dc := derphttp.NewRegionClient(priv, log.Printf, func() *tailcfg.DERPRegion {
- rid := n.RegionID
- return &tailcfg.DERPRegion{
- RegionID: rid,
- RegionCode: fmt.Sprintf("%s-%s", dm.Regions[rid].RegionCode, n.Name),
- RegionName: dm.Regions[rid].RegionName,
- Nodes: []*tailcfg.DERPNode{n},
- }
- })
- dc.IsProber = true
- err := dc.Connect(ctx)
- if err != nil {
- return nil, err
- }
- cs, ok := dc.TLSConnectionState()
- if !ok {
- dc.Close()
- return nil, errors.New("no TLS state")
- }
- if len(cs.PeerCertificates) == 0 {
- dc.Close()
- return nil, errors.New("no peer certificates")
- }
- if cs.ServerName != n.HostName {
- dc.Close()
- return nil, fmt.Errorf("TLS server name %q != derp hostname %q", cs.ServerName, n.HostName)
- }
- setCert(cs.ServerName, cs.PeerCertificates[0])
-
- errc := make(chan error, 1)
- go func() {
- m, err := dc.Recv()
- if err != nil {
- errc <- err
- return
- }
- switch m.(type) {
- case derp.ServerInfoMessage:
- errc <- nil
- default:
- errc <- fmt.Errorf("unexpected first message type %T", errc)
- }
- }()
- select {
- case err := <-errc:
- if err != nil {
- go dc.Close()
- return nil, err
- }
- case <-ctx.Done():
- go dc.Close()
- return nil, fmt.Errorf("timeout waiting for ServerInfoMessage: %w", ctx.Err())
- }
- return dc, nil
-}
-
-var httpOrFileClient = &http.Client{Transport: httpOrFileTransport()}
-
-func httpOrFileTransport() http.RoundTripper {
- tr := http.DefaultTransport.(*http.Transport).Clone()
- tr.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
- return tr
-}
-
-func getDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) {
- req, err := http.NewRequestWithContext(ctx, "GET", *derpMapURL, nil)
- if err != nil {
- return nil, err
- }
- res, err := httpOrFileClient.Do(req)
- if err != nil {
- mu.Lock()
- defer mu.Unlock()
- if lastDERPMap != nil && time.Since(lastDERPMapAt) < 10*time.Minute {
- // Assume that control is restarting and use
- // the same one for a bit.
- return lastDERPMap, nil
- }
- return nil, err
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return nil, fmt.Errorf("fetching %s: %s", *derpMapURL, res.Status)
- }
- dm := new(tailcfg.DERPMap)
- if err := json.NewDecoder(res.Body).Decode(dm); err != nil {
- return nil, fmt.Errorf("decoding %s JSON: %v", *derpMapURL, err)
- }
- setDERPMap(dm)
- return dm, nil
-}
diff --git a/prober/derp.go b/prober/derp.go
index e5c07a8d2..2609dba73 100644
--- a/prober/derp.go
+++ b/prober/derp.go
@@ -157,7 +157,7 @@ func (d *derpProber) updateMap(ctx context.Context) error {
if err != nil {
return nil
}
- res, err := http.DefaultClient.Do(req)
+ res, err := httpOrFileClient.Do(req)
if err != nil {
d.Lock()
defer d.Unlock()
@@ -389,3 +389,11 @@ func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode) (*de
}
return dc, nil
}
+
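+// httpOrFileClient is an HTTP client that also understands file:// URLs,
+// so the DERP map can be loaded from local disk as well as over HTTPS.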
+var httpOrFileClient = &http.Client{Transport: httpOrFileTransport()}
+
+func httpOrFileTransport() http.RoundTripper {
+ tr := http.DefaultTransport.(*http.Transport).Clone()
+ tr.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
+ return tr
+}
diff --git a/prober/prober.go b/prober/prober.go
index 46956b4e5..84ee17900 100644
--- a/prober/prober.go
+++ b/prober/prober.go
@@ -33,6 +33,9 @@ type Prober struct {
// random delay before the first probe run.
spread bool
+ // Whether to run all probes once instead of running them in a loop.
+ once bool
+
// Time-related functions that get faked out during tests.
now func() time.Time
newTicker func(time.Duration) ticker
@@ -59,6 +62,11 @@ func (p *Prober) Expvar() expvar.Var {
return varExporter{p}
}
+// ProbeInfo returns information about the most recent probe runs.
+func (p *Prober) ProbeInfo() map[string]ProbeInfo {
+ return varExporter{p}.probeInfo()
+}
+
// Run executes fun every interval, and exports probe results under probeName.
//
// Registering a probe under an already-registered name panics.
@@ -101,7 +109,37 @@ func (p *Prober) WithSpread(s bool) *Prober {
return p
}
-// Reports the number of registered probes. For tests only.
+// WithOnce sets whether the prober should run all configured probes just
+// once rather than on a schedule.
+func (p *Prober) WithOnce(s bool) *Prober {
+ p.once = s
+ return p
+}
+
+// Wait blocks until all probes have finished execution. It is typically
+// used in once mode, to wait for probes to finish before collecting their
+// results.
+func (p *Prober) Wait() {
+ for {
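+ // Snapshot the stopped channels under the lock, then wait on them
+ // with the lock released.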
+ chans := make([]chan struct{}, 0)
+ p.mu.Lock()
+ for _, p := range p.probes {
+ chans = append(chans, p.stopped)
+ }
+ p.mu.Unlock()
+ for _, c := range chans {
+ <-c
+ }
+
+ // Since probes can add other probes, retry if the number of probes has changed.
+ if p.activeProbes() != len(chans) {
+ continue
+ }
+ return
+ }
+}
+
+// Reports the number of registered probes.
func (p *Prober) activeProbes() int {
p.mu.Lock()
defer p.mu.Unlock()
@@ -123,10 +161,11 @@ type Probe struct {
tick ticker
labels map[string]string
- mu sync.Mutex
- start time.Time // last time doProbe started
- end time.Time // last time doProbe returned
- result bool // whether the last doProbe call succeeded
+ mu sync.Mutex
+ start time.Time // last time doProbe started
+ end time.Time // last time doProbe returned
+ result bool // whether the last doProbe call succeeded
+ lastErr error // error from the last doProbe call, or nil on success
}
// Close shuts down the Probe and unregisters it from its Prober.
@@ -157,6 +196,10 @@ func (p *Probe) loop() {
p.run()
}
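+ // In once mode, the initial run above is the only one; skip the ticker loop.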
+ if p.prober.once {
+ return
+ }
+
p.tick = p.prober.newTicker(p.interval)
defer p.tick.Stop()
for {
@@ -212,26 +255,26 @@ func (p *Probe) recordEnd(start time.Time, err error) {
defer p.mu.Unlock()
p.end = end
p.result = err == nil
+ p.lastErr = err
}
type varExporter struct {
p *Prober
}
-// probeInfo is the state of a Probe. Used in expvar-format debug
+// ProbeInfo is the state of a Probe. Used in expvar-format debug
// data.
-type probeInfo struct {
+type ProbeInfo struct {
Labels map[string]string
Start time.Time
End time.Time
Latency string // as a string because time.Duration doesn't encode readably to JSON
Result bool
+ Error string // error, if any, from the last probe run
}
-// String implements expvar.Var, returning the prober's state as an
-// encoded JSON map of probe name to its probeInfo.
-func (v varExporter) String() string {
- out := map[string]probeInfo{}
+func (v varExporter) probeInfo() map[string]ProbeInfo {
+ out := map[string]ProbeInfo{}
v.p.mu.Lock()
probes := make([]*Probe, 0, len(v.p.probes))
@@ -242,20 +285,28 @@ func (v varExporter) String() string {
for _, probe := range probes {
probe.mu.Lock()
- inf := probeInfo{
+ inf := ProbeInfo{
Labels: probe.labels,
Start: probe.start,
End: probe.end,
Result: probe.result,
}
+ if probe.lastErr != nil {
+ inf.Error = probe.lastErr.Error()
+ }
if probe.end.After(probe.start) {
inf.Latency = probe.end.Sub(probe.start).String()
}
out[probe.name] = inf
probe.mu.Unlock()
}
+ return out
+}
- bs, err := json.Marshal(out)
+// String implements expvar.Var, returning the prober's state as an
+// encoded JSON map of probe name to its ProbeInfo.
+func (v varExporter) String() string {
+ bs, err := json.Marshal(v.probeInfo())
if err != nil {
return fmt.Sprintf(`{"error": %q}`, err)
}
diff --git a/prober/prober_test.go b/prober/prober_test.go
index 29b6473c8..5bb6f0460 100644
--- a/prober/prober_test.go
+++ b/prober/prober_test.go
@@ -214,7 +214,7 @@ func TestExpvar(t *testing.T) {
waitActiveProbes(t, p, clk, 1)
- check := func(name string, want probeInfo) {
+ check := func(name string, want ProbeInfo) {
t.Helper()
err := tstest.WaitFor(convergenceTimeout, func() error {
vars := probeExpvar(t, p)
@@ -236,19 +236,20 @@ func TestExpvar(t *testing.T) {
}
}
- check("probe", probeInfo{
+ check("probe", ProbeInfo{
Labels: map[string]string{"label": "value"},
Start: epoch,
End: epoch.Add(aFewMillis),
Latency: aFewMillis.String(),
Result: false,
+ Error: "failing, as instructed by test",
})
succeed.Store(true)
clk.Advance(probeInterval + halfProbeInterval)
st := epoch.Add(probeInterval + halfProbeInterval + aFewMillis)
- check("probe", probeInfo{
+ check("probe", ProbeInfo{
Labels: map[string]string{"label": "value"},
Start: st,
End: st.Add(aFewMillis),
@@ -316,6 +317,31 @@ func TestPrometheus(t *testing.T) {
}
}
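+
+// TestOnceMode verifies that Wait returns only after all probes have
+// finished, including probes registered from within other probes.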
+func TestOnceMode(t *testing.T) {
+ clk := newFakeTime()
+ p := newForTest(clk.Now, clk.NewTicker).WithOnce(true)
+
+ p.Run("probe1", probeInterval, nil, func(context.Context) error { return nil })
+ p.Run("probe2", probeInterval, nil, func(context.Context) error { return fmt.Errorf("error2") })
+ p.Run("probe3", probeInterval, nil, func(context.Context) error {
+ p.Run("probe4", probeInterval, nil, func(context.Context) error {
+ return fmt.Errorf("error4")
+ })
+ return nil
+ })
+
+ p.Wait()
+ info := p.ProbeInfo()
+ if len(info) != 4 {
+ t.Errorf("expected 4 probe results, got %+v", info)
+ }
+ for _, p := range info {
+ if p.End.IsZero() {
+ t.Errorf("expected all probes to finish; got %+v", p)
+ }
+ }
+}
+
type fakeTicker struct {
ch chan time.Time
interval time.Duration
@@ -409,10 +435,10 @@ func (t *fakeTime) activeTickers() (count int) {
return
}
-func probeExpvar(t *testing.T, p *Prober) map[string]*probeInfo {
+func probeExpvar(t *testing.T, p *Prober) map[string]*ProbeInfo {
t.Helper()
s := p.Expvar().String()
- ret := map[string]*probeInfo{}
+ ret := map[string]*ProbeInfo{}
if err := json.Unmarshal([]byte(s), &ret); err != nil {
t.Fatalf("expvar json decode failed: %v", err)
}