fix(OTEL): reduce high cardinality in traces and metrics (#9286)

# Which Problems Are Solved

There were multiple issues in the OpenTelemetry (OTEL) implementation
and usage for tracing and metrics, which lead to high cardinality and
potential memory leaks:
- wrongly initiated tracing interceptors
- high cardinality in traces:
  - HTTP/1.1 endpoints containing host names
- HTTP/1.1 endpoints containing object IDs like userID (e.g.
`/management/v1/users/2352839823/`)
- high amount of traces from internal processes (spooler)
- high cardinality in metrics endpoint:
  - GRPC entries containing host names
  - notification metrics containing instanceIDs and error messages

# How the Problems Are Solved

- Properly initialize the interceptors once and update them to use the
grpc stats handler (unary interceptors were deprecated).
- Remove host names from HTTP/1.1 span names and use path as default.
- Set / overwrite the uri for spans on the grpc-gateway with the uri
pattern (`/management/v1/users/{user_id}`). This is used for spans in
traces and metric entries.
- Created a new sampler which will only sample spans in the following
cases:
  - remote was already sampled
- remote was not sampled, root span is of kind `Server` and based on
fraction set in the runtime configuration
- This will prevent having a lot of spans from the spooler back ground
jobs if they were not started by a client call querying an object (e.g.
UserByID).
- Filter out host names and alike from OTEL generated metrics (using a
`view`).
- Removed instance and error messages from notification metrics.

# Additional Changes

Fixed the middleware handling for serving Console. Telemetry and
instance selection are only used for the environment.json, but not on
statically served files.

# Additional Context

- closes #8096 
- relates to #9074
- back ports to at least 2.66.x, 2.67.x and 2.68.x
This commit is contained in:
Livio Spring 2025-02-04 09:55:26 +01:00 committed by GitHub
parent 04b9e9b144
commit 990e1982c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 237 additions and 60 deletions

View File

@ -1,36 +1,29 @@
package middleware package middleware
import ( import (
"context"
"strings" "strings"
grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc" "google.golang.org/grpc/stats"
grpc_utils "github.com/zitadel/zitadel/internal/api/grpc" grpc_utils "github.com/zitadel/zitadel/internal/api/grpc"
) )
type GRPCMethod string type GRPCMethod string
func DefaultTracingClient() grpc.UnaryClientInterceptor { func DefaultTracingClient() stats.Handler {
return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation) return TracingClient(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation)
} }
func TracingServer(ignoredMethods ...GRPCMethod) grpc.UnaryClientInterceptor { func TracingClient(ignoredMethods ...GRPCMethod) stats.Handler {
return func( return grpc_trace.NewClientHandler(grpc_trace.WithFilter(
ctx context.Context, func(info *stats.RPCTagInfo) bool {
method string, for _, ignoredMethod := range ignoredMethods {
req, reply interface{}, if strings.HasSuffix(info.FullMethodName, string(ignoredMethod)) {
cc *grpc.ClientConn, return false
invoker grpc.UnaryInvoker, }
opts ...grpc.CallOption,
) error {
for _, ignoredMethod := range ignoredMethods {
if strings.HasSuffix(method, string(ignoredMethod)) {
return invoker(ctx, method, req, reply, cc, opts...)
} }
} return true
return grpc_trace.UnaryClientInterceptor()(ctx, method, req, reply, cc, invoker, opts...) },
} ))
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
@ -56,6 +57,13 @@ var (
}, },
) )
// we need the errorHandler to set the request URI pattern in case of an error
errorHandler = runtime.ErrorHandlerFunc(
func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) {
setRequestURIPattern(ctx)
runtime.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, err)
})
serveMuxOptions = func(hostHeaders []string) []runtime.ServeMuxOption { serveMuxOptions = func(hostHeaders []string) []runtime.ServeMuxOption {
return []runtime.ServeMuxOption{ return []runtime.ServeMuxOption{
runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler), runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler),
@ -65,6 +73,7 @@ var (
runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher), runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher),
runtime.WithForwardResponseOption(responseForwarder), runtime.WithForwardResponseOption(responseForwarder),
runtime.WithRoutingErrorHandler(httpErrorHandler), runtime.WithRoutingErrorHandler(httpErrorHandler),
runtime.WithErrorHandler(errorHandler),
} }
} }
@ -81,6 +90,7 @@ var (
} }
responseForwarder = func(ctx context.Context, w http.ResponseWriter, resp proto.Message) error { responseForwarder = func(ctx context.Context, w http.ResponseWriter, resp proto.Message) error {
setRequestURIPattern(ctx)
t, ok := resp.(CustomHTTPResponse) t, ok := resp.(CustomHTTPResponse)
if ok { if ok {
// TODO: find a way to return a location header if needed w.Header().Set("location", t.Location()) // TODO: find a way to return a location header if needed w.Header().Set("location", t.Location())
@ -118,9 +128,9 @@ func CreateGatewayWithPrefix(
opts := []grpc.DialOption{ opts := []grpc.DialOption{
grpc.WithTransportCredentials(grpcCredentials(tlsConfig)), grpc.WithTransportCredentials(grpcCredentials(tlsConfig)),
grpc.WithChainUnaryInterceptor( grpc.WithChainUnaryInterceptor(
client_middleware.DefaultTracingClient(),
client_middleware.UnaryActivityClientInterceptor(), client_middleware.UnaryActivityClientInterceptor(),
), ),
grpc.WithStatsHandler(client_middleware.DefaultTracingClient()),
} }
connection, err := dial(ctx, port, opts) connection, err := dial(ctx, port, opts)
if err != nil { if err != nil {
@ -145,9 +155,9 @@ func CreateGateway(
[]grpc.DialOption{ []grpc.DialOption{
grpc.WithTransportCredentials(grpcCredentials(tlsConfig)), grpc.WithTransportCredentials(grpcCredentials(tlsConfig)),
grpc.WithChainUnaryInterceptor( grpc.WithChainUnaryInterceptor(
client_middleware.DefaultTracingClient(),
client_middleware.UnaryActivityClientInterceptor(), client_middleware.UnaryActivityClientInterceptor(),
), ),
grpc.WithStatsHandler(client_middleware.DefaultTracingClient()),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -260,3 +270,13 @@ func grpcCredentials(tlsConfig *tls.Config) credentials.TransportCredentials {
} }
return creds return creds
} }
func setRequestURIPattern(ctx context.Context) {
pattern, ok := runtime.HTTPPathPattern(ctx)
if !ok {
return
}
span := trace.SpanFromContext(ctx)
span.SetName(pattern)
metrics.SetRequestURIPattern(ctx, pattern)
}

View File

@ -1,34 +1,29 @@
package middleware package middleware
import ( import (
"context"
"strings" "strings"
grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc" "google.golang.org/grpc/stats"
grpc_utils "github.com/zitadel/zitadel/internal/api/grpc" grpc_utils "github.com/zitadel/zitadel/internal/api/grpc"
) )
type GRPCMethod string type GRPCMethod string
func DefaultTracingServer() grpc.UnaryServerInterceptor { func DefaultTracingServer() stats.Handler {
return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation) return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation)
} }
func TracingServer(ignoredMethods ...GRPCMethod) grpc.UnaryServerInterceptor { func TracingServer(ignoredMethods ...GRPCMethod) stats.Handler {
return func( return grpc_trace.NewServerHandler(grpc_trace.WithFilter(
ctx context.Context, func(info *stats.RPCTagInfo) bool {
req interface{}, for _, ignoredMethod := range ignoredMethods {
info *grpc.UnaryServerInfo, if strings.HasSuffix(info.FullMethodName, string(ignoredMethod)) {
handler grpc.UnaryHandler, return false
) (interface{}, error) { }
for _, ignoredMethod := range ignoredMethods {
if strings.HasSuffix(info.FullMethod, string(ignoredMethod)) {
return handler(ctx, req)
} }
} return true
return grpc_trace.UnaryServerInterceptor()(ctx, req, info, handler) },
} ))
} }

View File

@ -47,7 +47,6 @@ func CreateServer(
grpc.UnaryInterceptor( grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer( grpc_middleware.ChainUnaryServer(
middleware.CallDurationHandler(), middleware.CallDurationHandler(),
middleware.DefaultTracingServer(),
middleware.MetricsHandler(metricTypes, grpc_api.Probes...), middleware.MetricsHandler(metricTypes, grpc_api.Probes...),
middleware.NoCacheInterceptor(), middleware.NoCacheInterceptor(),
middleware.InstanceInterceptor(queries, externalDomain, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName), middleware.InstanceInterceptor(queries, externalDomain, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName),
@ -63,6 +62,7 @@ func CreateServer(
middleware.ActivityInterceptor(), middleware.ActivityInterceptor(),
), ),
), ),
grpc.StatsHandler(middleware.DefaultTracingServer()),
} }
if tlsConfig != nil { if tlsConfig != nil {
serverOptions = append(serverOptions, grpc.Creds(credentials.NewTLS(tlsConfig))) serverOptions = append(serverOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))

View File

@ -111,9 +111,11 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
security := middleware.SecurityHeaders(csp(config.PostHog.URL), nil) security := middleware.SecurityHeaders(csp(config.PostHog.URL), nil)
handler := mux.NewRouter() handler := mux.NewRouter()
handler.Use(security, limitingAccessInterceptor.WithoutLimiting().Handle)
handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle) env := handler.NewRoute().Path(envRequestPath).Subrouter()
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { env.Use(callDurationInterceptor, middleware.TelemetryHandler(), instanceHandler)
env.HandleFunc("", func(w http.ResponseWriter, r *http.Request) {
url := http_util.BuildOrigin(r.Host, externalSecure) url := http_util.BuildOrigin(r.Host, externalSecure)
ctx := r.Context() ctx := r.Context()
instance := authz.GetInstance(ctx) instance := authz.GetInstance(ctx)
@ -130,7 +132,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
} }
_, err = w.Write(environmentJSON) _, err = w.Write(environmentJSON)
logging.OnError(err).Error("error serving environment.json") logging.OnError(err).Error("error serving environment.json")
}))) })
handler.SkipClean(true).PathPrefix("").Handler(cache(http.FileServer(&spaHandler{http.FS(fSys)}))) handler.SkipClean(true).PathPrefix("").Handler(cache(http.FileServer(&spaHandler{http.FS(fSys)})))
return handler, nil return handler, nil
} }

View File

@ -6,7 +6,6 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/notification/channels" "github.com/zitadel/zitadel/internal/notification/channels"
"github.com/zitadel/zitadel/internal/telemetry/metrics" "github.com/zitadel/zitadel/internal/telemetry/metrics"
) )
@ -18,18 +17,14 @@ func countMessages(ctx context.Context, channel channels.NotificationChannel, su
if err != nil { if err != nil {
metricName = errorMetricName metricName = errorMetricName
} }
addCount(ctx, metricName, message, err) addCount(ctx, metricName, message)
return err return err
}) })
} }
func addCount(ctx context.Context, metricName string, message channels.Message, err error) { func addCount(ctx context.Context, metricName string, message channels.Message) {
labels := map[string]attribute.Value{ labels := map[string]attribute.Value{
"triggering_event_typey": attribute.StringValue(string(message.GetTriggeringEvent().Type())), "triggering_event_type": attribute.StringValue(string(message.GetTriggeringEvent().Type())),
"instance": attribute.StringValue(authz.GetInstance(ctx).InstanceID()),
}
if err != nil {
labels["error"] = attribute.StringValue(err.Error())
} }
addCountErr := metrics.AddCount(ctx, metricName, 1, labels) addCountErr := metrics.AddCount(ctx, metricName, 1, labels)
logging.WithFields("name", metricName, "labels", labels).OnError(addCountErr).Error("incrementing counter metric failed") logging.WithFields("name", metricName, "labels", labels).OnError(addCountErr).Error("incrementing counter metric failed")

View File

@ -30,5 +30,5 @@ func TelemetryHandler(handler http.Handler, ignoredEndpoints ...string) http.Han
} }
func spanNameFormatter(_ string, r *http.Request) string { func spanNameFormatter(_ string, r *http.Request) string {
return r.Host + r.URL.EscapedPath() return strings.Split(r.RequestURI, "?")[0]
} }

View File

@ -1,6 +1,7 @@
package metrics package metrics
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
@ -35,7 +36,8 @@ const (
type StatusRecorder struct { type StatusRecorder struct {
http.ResponseWriter http.ResponseWriter
Status int RequestURI *string
Status int
} }
func (r *StatusRecorder) WriteHeader(status int) { func (r *StatusRecorder) WriteHeader(status int) {
@ -56,6 +58,18 @@ func NewMetricsHandler(handler http.Handler, metricMethods []MetricType, ignored
return &h return &h
} }
type key int
const requestURI key = iota
func SetRequestURIPattern(ctx context.Context, pattern string) {
uri, ok := ctx.Value(requestURI).(*string)
if !ok {
return
}
*uri = pattern
}
// ServeHTTP serves HTTP requests (http.Handler) // ServeHTTP serves HTTP requests (http.Handler)
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(h.methods) == 0 { if len(h.methods) == 0 {
@ -69,13 +83,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
uri := strings.Split(r.RequestURI, "?")[0]
recorder := &StatusRecorder{ recorder := &StatusRecorder{
ResponseWriter: w, ResponseWriter: w,
RequestURI: &uri,
Status: 200, Status: 200,
} }
r = r.WithContext(context.WithValue(r.Context(), requestURI, &uri))
h.handler.ServeHTTP(recorder, r) h.handler.ServeHTTP(recorder, r)
if h.containsMetricsMethod(MetricTypeRequestCount) { if h.containsMetricsMethod(MetricTypeRequestCount) {
RegisterRequestCounter(r) RegisterRequestCounter(recorder, r)
} }
if h.containsMetricsMethod(MetricTypeTotalCount) { if h.containsMetricsMethod(MetricTypeTotalCount) {
RegisterTotalRequestCounter(r) RegisterTotalRequestCounter(r)
@ -94,9 +111,9 @@ func (h *Handler) containsMetricsMethod(method MetricType) bool {
return false return false
} }
func RegisterRequestCounter(r *http.Request) { func RegisterRequestCounter(recorder *StatusRecorder, r *http.Request) {
var labels = map[string]attribute.Value{ var labels = map[string]attribute.Value{
URI: attribute.StringValue(strings.Split(r.RequestURI, "?")[0]), URI: attribute.StringValue(*recorder.RequestURI),
Method: attribute.StringValue(r.Method), Method: attribute.StringValue(r.Method),
} }
RegisterCounter(RequestCounter, RequestCountDescription) RegisterCounter(RequestCounter, RequestCountDescription)
@ -110,7 +127,7 @@ func RegisterTotalRequestCounter(r *http.Request) {
func RegisterRequestCodeCounter(recorder *StatusRecorder, r *http.Request) { func RegisterRequestCodeCounter(recorder *StatusRecorder, r *http.Request) {
var labels = map[string]attribute.Value{ var labels = map[string]attribute.Value{
URI: attribute.StringValue(strings.Split(r.RequestURI, "?")[0]), URI: attribute.StringValue(*recorder.RequestURI),
Method: attribute.StringValue(r.Method), Method: attribute.StringValue(r.Method),
ReturnCode: attribute.IntValue(recorder.Status), ReturnCode: attribute.IntValue(recorder.Status),
} }

View File

@ -6,9 +6,11 @@ import (
"sync" "sync"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/prometheus" "go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/sdk/instrumentation"
sdk_metric "go.opentelemetry.io/otel/sdk/metric" sdk_metric "go.opentelemetry.io/otel/sdk/metric"
"github.com/zitadel/zitadel/internal/telemetry/metrics" "github.com/zitadel/zitadel/internal/telemetry/metrics"
@ -33,9 +35,19 @@ func NewMetrics(meterName string) (metrics.Metrics, error) {
if err != nil { if err != nil {
return &Metrics{}, err return &Metrics{}, err
} }
// create a view to filter out unwanted attributes
view := sdk_metric.NewView(
sdk_metric.Instrument{
Scope: instrumentation.Scope{Name: otelhttp.ScopeName},
},
sdk_metric.Stream{
AttributeFilter: attribute.NewAllowKeysFilter("http.method", "http.status_code", "http.target"),
},
)
meterProvider := sdk_metric.NewMeterProvider( meterProvider := sdk_metric.NewMeterProvider(
sdk_metric.WithReader(exporter), sdk_metric.WithReader(exporter),
sdk_metric.WithResource(resource), sdk_metric.WithResource(resource),
sdk_metric.WithView(view),
) )
return &Metrics{ return &Metrics{
Provider: meterProvider, Provider: meterProvider,

View File

@ -28,7 +28,7 @@ type Tracer struct {
} }
func (c *Config) NewTracer() error { func (c *Config) NewTracer() error {
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction)) sampler := otel.NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
exporter, err := texporter.New(texporter.WithProjectID(c.ProjectID)) exporter, err := texporter.New(texporter.WithProjectID(c.ProjectID))
if err != nil { if err != nil {
return err return err

View File

@ -26,7 +26,7 @@ type Tracer struct {
} }
func (c *Config) NewTracer() error { func (c *Config) NewTracer() error {
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction)) sampler := otel.NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
exporter, err := stdout.New(stdout.WithPrettyPrint()) exporter, err := stdout.New(stdout.WithPrettyPrint())
if err != nil { if err != nil {
return err return err

View File

@ -6,6 +6,7 @@ import (
otlpgrpc "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" otlpgrpc "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
sdk_trace "go.opentelemetry.io/otel/sdk/trace" sdk_trace "go.opentelemetry.io/otel/sdk/trace"
api_trace "go.opentelemetry.io/otel/trace"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
@ -47,7 +48,7 @@ func FractionFromConfig(i interface{}) (float64, error) {
} }
func (c *Config) NewTracer() error { func (c *Config) NewTracer() error {
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction)) sampler := NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
exporter, err := otlpgrpc.New(context.Background(), otlpgrpc.WithEndpoint(c.Endpoint), otlpgrpc.WithInsecure()) exporter, err := otlpgrpc.New(context.Background(), otlpgrpc.WithEndpoint(c.Endpoint), otlpgrpc.WithInsecure())
if err != nil { if err != nil {
return err return err
@ -56,3 +57,19 @@ func (c *Config) NewTracer() error {
tracing.T, err = NewTracer(sampler, exporter) tracing.T, err = NewTracer(sampler, exporter)
return err return err
} }
// NewSampler returns a sampler decorator which behaves differently,
// based on the parent of the span. If the span has no parent and is of kind server,
// the decorated sampler is used to make sampling decision.
// If the span has a parent, depending on whether the parent is remote and whether it
// is sampled, one of the following samplers will apply:
// - remote parent sampled -> always sample
// - remote parent not sampled -> sample based on the decorated sampler (fraction based)
// - local parent sampled -> always sample
// - local parent not sampled -> never sample
func NewSampler(sampler sdk_trace.Sampler) sdk_trace.Sampler {
return sdk_trace.ParentBased(
tracing.SpanKindBased(sampler, api_trace.SpanKindServer),
sdk_trace.WithRemoteParentNotSampled(sampler),
)
}

View File

@ -0,0 +1,46 @@
package tracing
import (
"fmt"
"slices"
sdk_trace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)
type spanKindSampler struct {
sampler sdk_trace.Sampler
kinds []trace.SpanKind
}
// ShouldSample implements the [sdk_trace.Sampler] interface.
// It will not sample any spans which do not match the configured span kinds.
// For spans which do match, the decorated sampler is used to make the sampling decision.
func (sk spanKindSampler) ShouldSample(p sdk_trace.SamplingParameters) sdk_trace.SamplingResult {
psc := trace.SpanContextFromContext(p.ParentContext)
if !slices.Contains(sk.kinds, p.Kind) {
return sdk_trace.SamplingResult{
Decision: sdk_trace.Drop,
Tracestate: psc.TraceState(),
}
}
s := sk.sampler.ShouldSample(p)
return s
}
func (sk spanKindSampler) Description() string {
return fmt.Sprintf("SpanKindBased{sampler:%s,kinds:%v}",
sk.sampler.Description(),
sk.kinds,
)
}
// SpanKindBased returns a sampler decorator which behaves differently, based on the kind of the span.
// If the span kind does not match one of the configured kinds, it will not be sampled.
// If the span kind matches, the decorated sampler is used to make sampling decision.
func SpanKindBased(sampler sdk_trace.Sampler, kinds ...trace.SpanKind) sdk_trace.Sampler {
return spanKindSampler{
sampler: sampler,
kinds: kinds,
}
}

View File

@ -0,0 +1,80 @@
package tracing
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
sdk_trace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)
func TestSpanKindBased(t *testing.T) {
type args struct {
sampler sdk_trace.Sampler
kinds []trace.SpanKind
}
type want struct {
description string
sampled int
}
tests := []struct {
name string
args args
want want
}{
{
"never sample, no sample",
args{
sampler: sdk_trace.NeverSample(),
kinds: []trace.SpanKind{trace.SpanKindServer},
},
want{
description: "SpanKindBased{sampler:AlwaysOffSampler,kinds:[server]}",
sampled: 0,
},
},
{
"always sample, no kind, no sample",
args{
sampler: sdk_trace.AlwaysSample(),
kinds: nil,
},
want{
description: "SpanKindBased{sampler:AlwaysOnSampler,kinds:[]}",
sampled: 0,
},
},
{
"always sample, 2 kinds, 2 samples",
args{
sampler: sdk_trace.AlwaysSample(),
kinds: []trace.SpanKind{trace.SpanKindServer, trace.SpanKindClient},
},
want{
description: "SpanKindBased{sampler:AlwaysOnSampler,kinds:[server client]}",
sampled: 2,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sampler := SpanKindBased(tt.args.sampler, tt.args.kinds...)
assert.Equal(t, tt.want.description, sampler.Description())
p := sdk_trace.NewTracerProvider(sdk_trace.WithSampler(sampler))
tr := p.Tracer("test")
var sampled int
for i := trace.SpanKindUnspecified; i <= trace.SpanKindConsumer; i++ {
ctx := context.Background()
_, span := tr.Start(ctx, "test", trace.WithSpanKind(i))
if span.SpanContext().IsSampled() {
sampled++
}
}
assert.Equal(t, tt.want.sampled, sampled)
})
}
}