diff --git a/internal/api/grpc/client/middleware/tracing.go b/internal/api/grpc/client/middleware/tracing.go index 32a144c2f1..05e4a5f40e 100644 --- a/internal/api/grpc/client/middleware/tracing.go +++ b/internal/api/grpc/client/middleware/tracing.go @@ -1,36 +1,29 @@ package middleware import ( - "context" "strings" grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "google.golang.org/grpc" + "google.golang.org/grpc/stats" grpc_utils "github.com/zitadel/zitadel/internal/api/grpc" ) type GRPCMethod string -func DefaultTracingClient() grpc.UnaryClientInterceptor { - return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation) +func DefaultTracingClient() stats.Handler { + return TracingClient(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation) } -func TracingServer(ignoredMethods ...GRPCMethod) grpc.UnaryClientInterceptor { - return func( - ctx context.Context, - method string, - req, reply interface{}, - cc *grpc.ClientConn, - invoker grpc.UnaryInvoker, - opts ...grpc.CallOption, - ) error { - - for _, ignoredMethod := range ignoredMethods { - if strings.HasSuffix(method, string(ignoredMethod)) { - return invoker(ctx, method, req, reply, cc, opts...) +func TracingClient(ignoredMethods ...GRPCMethod) stats.Handler { + return grpc_trace.NewClientHandler(grpc_trace.WithFilter( + func(info *stats.RPCTagInfo) bool { + for _, ignoredMethod := range ignoredMethods { + if strings.HasSuffix(info.FullMethodName, string(ignoredMethod)) { + return false + } } - } - return grpc_trace.UnaryClientInterceptor()(ctx, method, req, reply, cc, invoker, opts...) 
- } + return true + }, + )) } diff --git a/internal/api/grpc/server/gateway.go b/internal/api/grpc/server/gateway.go index 3e41be94ae..43947917a2 100644 --- a/internal/api/grpc/server/gateway.go +++ b/internal/api/grpc/server/gateway.go @@ -10,6 +10,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/zitadel/logging" + "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -56,6 +57,13 @@ var ( }, ) + // we need the errorHandler to set the request URI pattern in case of an error + errorHandler = runtime.ErrorHandlerFunc( + func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) { + setRequestURIPattern(ctx) + runtime.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, err) + }) + serveMuxOptions = func(hostHeaders []string) []runtime.ServeMuxOption { return []runtime.ServeMuxOption{ runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler), @@ -65,6 +73,7 @@ var ( runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher), runtime.WithForwardResponseOption(responseForwarder), runtime.WithRoutingErrorHandler(httpErrorHandler), + runtime.WithErrorHandler(errorHandler), } } @@ -81,6 +90,7 @@ var ( } responseForwarder = func(ctx context.Context, w http.ResponseWriter, resp proto.Message) error { + setRequestURIPattern(ctx) t, ok := resp.(CustomHTTPResponse) if ok { // TODO: find a way to return a location header if needed w.Header().Set("location", t.Location()) @@ -118,9 +128,9 @@ func CreateGatewayWithPrefix( opts := []grpc.DialOption{ grpc.WithTransportCredentials(grpcCredentials(tlsConfig)), grpc.WithChainUnaryInterceptor( - client_middleware.DefaultTracingClient(), client_middleware.UnaryActivityClientInterceptor(), ), + grpc.WithStatsHandler(client_middleware.DefaultTracingClient()), } connection, err := dial(ctx, port, opts) if err != nil { @@ -145,9 +155,9 @@ func 
CreateGateway( []grpc.DialOption{ grpc.WithTransportCredentials(grpcCredentials(tlsConfig)), grpc.WithChainUnaryInterceptor( - client_middleware.DefaultTracingClient(), client_middleware.UnaryActivityClientInterceptor(), ), + grpc.WithStatsHandler(client_middleware.DefaultTracingClient()), }) if err != nil { return nil, err @@ -260,3 +270,13 @@ func grpcCredentials(tlsConfig *tls.Config) credentials.TransportCredentials { } return creds } + +func setRequestURIPattern(ctx context.Context) { + pattern, ok := runtime.HTTPPathPattern(ctx) + if !ok { + return + } + span := trace.SpanFromContext(ctx) + span.SetName(pattern) + metrics.SetRequestURIPattern(ctx, pattern) +} diff --git a/internal/api/grpc/server/middleware/tracing.go b/internal/api/grpc/server/middleware/tracing.go index 748257551c..6781e67cf2 100644 --- a/internal/api/grpc/server/middleware/tracing.go +++ b/internal/api/grpc/server/middleware/tracing.go @@ -1,34 +1,29 @@ package middleware import ( - "context" "strings" grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "google.golang.org/grpc" + "google.golang.org/grpc/stats" grpc_utils "github.com/zitadel/zitadel/internal/api/grpc" ) type GRPCMethod string -func DefaultTracingServer() grpc.UnaryServerInterceptor { +func DefaultTracingServer() stats.Handler { return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation) } -func TracingServer(ignoredMethods ...GRPCMethod) grpc.UnaryServerInterceptor { - return func( - ctx context.Context, - req interface{}, - info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler, - ) (interface{}, error) { - - for _, ignoredMethod := range ignoredMethods { - if strings.HasSuffix(info.FullMethod, string(ignoredMethod)) { - return handler(ctx, req) +func TracingServer(ignoredMethods ...GRPCMethod) stats.Handler { + return grpc_trace.NewServerHandler(grpc_trace.WithFilter( + func(info *stats.RPCTagInfo) bool { + for _, ignoredMethod := range ignoredMethods { + 
if strings.HasSuffix(info.FullMethodName, string(ignoredMethod)) { + return false + } } - } - return grpc_trace.UnaryServerInterceptor()(ctx, req, info, handler) - } + return true + }, + )) } diff --git a/internal/api/grpc/server/server.go b/internal/api/grpc/server/server.go index 5408ae257f..27b921b7d5 100644 --- a/internal/api/grpc/server/server.go +++ b/internal/api/grpc/server/server.go @@ -47,7 +47,6 @@ func CreateServer( grpc.UnaryInterceptor( grpc_middleware.ChainUnaryServer( middleware.CallDurationHandler(), - middleware.DefaultTracingServer(), middleware.MetricsHandler(metricTypes, grpc_api.Probes...), middleware.NoCacheInterceptor(), middleware.InstanceInterceptor(queries, externalDomain, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName), @@ -63,6 +62,7 @@ func CreateServer( middleware.ActivityInterceptor(), ), ), + grpc.StatsHandler(middleware.DefaultTracingServer()), } if tlsConfig != nil { serverOptions = append(serverOptions, grpc.Creds(credentials.NewTLS(tlsConfig))) diff --git a/internal/api/ui/console/console.go b/internal/api/ui/console/console.go index e844ea20a8..1fcd3450a3 100644 --- a/internal/api/ui/console/console.go +++ b/internal/api/ui/console/console.go @@ -111,9 +111,11 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call security := middleware.SecurityHeaders(csp(config.PostHog.URL), nil) handler := mux.NewRouter() + handler.Use(security, limitingAccessInterceptor.WithoutLimiting().Handle) - handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle) - handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + env := handler.NewRoute().Path(envRequestPath).Subrouter() + env.Use(callDurationInterceptor, middleware.TelemetryHandler(), instanceHandler) + env.HandleFunc("", func(w http.ResponseWriter, r *http.Request) { url := 
http_util.BuildOrigin(r.Host, externalSecure) ctx := r.Context() instance := authz.GetInstance(ctx) @@ -130,7 +132,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call } _, err = w.Write(environmentJSON) logging.OnError(err).Error("error serving environment.json") - }))) + }) handler.SkipClean(true).PathPrefix("").Handler(cache(http.FileServer(&spaHandler{http.FS(fSys)}))) return handler, nil } diff --git a/internal/notification/channels/instrumenting/metrics.go b/internal/notification/channels/instrumenting/metrics.go index 6b76c8f788..09a402e63a 100644 --- a/internal/notification/channels/instrumenting/metrics.go +++ b/internal/notification/channels/instrumenting/metrics.go @@ -6,7 +6,6 @@ import ( "github.com/zitadel/logging" "go.opentelemetry.io/otel/attribute" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/notification/channels" "github.com/zitadel/zitadel/internal/telemetry/metrics" ) @@ -18,18 +17,14 @@ func countMessages(ctx context.Context, channel channels.NotificationChannel, su if err != nil { metricName = errorMetricName } - addCount(ctx, metricName, message, err) + addCount(ctx, metricName, message) return err }) } -func addCount(ctx context.Context, metricName string, message channels.Message, err error) { +func addCount(ctx context.Context, metricName string, message channels.Message) { labels := map[string]attribute.Value{ - "triggering_event_typey": attribute.StringValue(string(message.GetTriggeringEvent().Type())), - "instance": attribute.StringValue(authz.GetInstance(ctx).InstanceID()), - } - if err != nil { - labels["error"] = attribute.StringValue(err.Error()) + "triggering_event_type": attribute.StringValue(string(message.GetTriggeringEvent().Type())), } addCountErr := metrics.AddCount(ctx, metricName, 1, labels) logging.WithFields("name", metricName, "labels", labels).OnError(addCountErr).Error("incrementing counter metric failed") diff --git 
a/internal/telemetry/http_handler.go b/internal/telemetry/http_handler.go index 32e614b34c..ff70978a5a 100644 --- a/internal/telemetry/http_handler.go +++ b/internal/telemetry/http_handler.go @@ -30,5 +30,5 @@ func TelemetryHandler(handler http.Handler, ignoredEndpoints ...string) http.Han } func spanNameFormatter(_ string, r *http.Request) string { - return r.Host + r.URL.EscapedPath() + return strings.Split(r.RequestURI, "?")[0] } diff --git a/internal/telemetry/metrics/http_handler.go b/internal/telemetry/metrics/http_handler.go index 8a2ed64fc2..cdfe21ef4f 100644 --- a/internal/telemetry/metrics/http_handler.go +++ b/internal/telemetry/metrics/http_handler.go @@ -1,6 +1,7 @@ package metrics import ( + "context" "net/http" "strings" @@ -35,7 +36,8 @@ const ( type StatusRecorder struct { http.ResponseWriter - Status int + RequestURI *string + Status int } func (r *StatusRecorder) WriteHeader(status int) { @@ -56,6 +58,18 @@ func NewMetricsHandler(handler http.Handler, metricMethods []MetricType, ignored return &h } +type key int + +const requestURI key = iota + +func SetRequestURIPattern(ctx context.Context, pattern string) { + uri, ok := ctx.Value(requestURI).(*string) + if !ok { + return + } + *uri = pattern +} + // ServeHTTP serves HTTP requests (http.Handler) func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(h.methods) == 0 { @@ -69,13 +83,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } + uri := strings.Split(r.RequestURI, "?")[0] recorder := &StatusRecorder{ ResponseWriter: w, + RequestURI: &uri, Status: 200, } + r = r.WithContext(context.WithValue(r.Context(), requestURI, &uri)) h.handler.ServeHTTP(recorder, r) if h.containsMetricsMethod(MetricTypeRequestCount) { - RegisterRequestCounter(r) + RegisterRequestCounter(recorder, r) } if h.containsMetricsMethod(MetricTypeTotalCount) { RegisterTotalRequestCounter(r) @@ -94,9 +111,9 @@ func (h *Handler) containsMetricsMethod(method MetricType) bool 
{ return false } -func RegisterRequestCounter(r *http.Request) { +func RegisterRequestCounter(recorder *StatusRecorder, r *http.Request) { var labels = map[string]attribute.Value{ - URI: attribute.StringValue(strings.Split(r.RequestURI, "?")[0]), + URI: attribute.StringValue(*recorder.RequestURI), Method: attribute.StringValue(r.Method), } RegisterCounter(RequestCounter, RequestCountDescription) @@ -110,7 +127,7 @@ func RegisterTotalRequestCounter(r *http.Request) { func RegisterRequestCodeCounter(recorder *StatusRecorder, r *http.Request) { var labels = map[string]attribute.Value{ - URI: attribute.StringValue(strings.Split(r.RequestURI, "?")[0]), + URI: attribute.StringValue(*recorder.RequestURI), Method: attribute.StringValue(r.Method), ReturnCode: attribute.IntValue(recorder.Status), } diff --git a/internal/telemetry/metrics/otel/open_telemetry.go b/internal/telemetry/metrics/otel/open_telemetry.go index ed1a94f840..21a45699f1 100644 --- a/internal/telemetry/metrics/otel/open_telemetry.go +++ b/internal/telemetry/metrics/otel/open_telemetry.go @@ -6,9 +6,11 @@ import ( "sync" "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/exporters/prometheus" "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/sdk/instrumentation" sdk_metric "go.opentelemetry.io/otel/sdk/metric" "github.com/zitadel/zitadel/internal/telemetry/metrics" @@ -33,9 +35,19 @@ func NewMetrics(meterName string) (metrics.Metrics, error) { if err != nil { return &Metrics{}, err } + // create a view to filter out unwanted attributes + view := sdk_metric.NewView( + sdk_metric.Instrument{ + Scope: instrumentation.Scope{Name: otelhttp.ScopeName}, + }, + sdk_metric.Stream{ + AttributeFilter: attribute.NewAllowKeysFilter("http.method", "http.status_code", "http.target"), + }, + ) meterProvider := sdk_metric.NewMeterProvider( sdk_metric.WithReader(exporter), 
sdk_metric.WithResource(resource), + sdk_metric.WithView(view), ) return &Metrics{ Provider: meterProvider, diff --git a/internal/telemetry/tracing/google/google_tracer.go b/internal/telemetry/tracing/google/google_tracer.go index 4787a9fd09..8d15bb18fb 100644 --- a/internal/telemetry/tracing/google/google_tracer.go +++ b/internal/telemetry/tracing/google/google_tracer.go @@ -28,7 +28,7 @@ type Tracer struct { } func (c *Config) NewTracer() error { - sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction)) + sampler := otel.NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction)) exporter, err := texporter.New(texporter.WithProjectID(c.ProjectID)) if err != nil { return err diff --git a/internal/telemetry/tracing/log/config.go b/internal/telemetry/tracing/log/config.go index e6a8c13f0b..862e14624c 100644 --- a/internal/telemetry/tracing/log/config.go +++ b/internal/telemetry/tracing/log/config.go @@ -26,7 +26,7 @@ type Tracer struct { } func (c *Config) NewTracer() error { - sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction)) + sampler := otel.NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction)) exporter, err := stdout.New(stdout.WithPrettyPrint()) if err != nil { return err diff --git a/internal/telemetry/tracing/otel/config.go b/internal/telemetry/tracing/otel/config.go index a9f9168c1b..5b417359b9 100644 --- a/internal/telemetry/tracing/otel/config.go +++ b/internal/telemetry/tracing/otel/config.go @@ -6,6 +6,7 @@ import ( otlpgrpc "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" sdk_trace "go.opentelemetry.io/otel/sdk/trace" + api_trace "go.opentelemetry.io/otel/trace" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" @@ -47,7 +48,7 @@ func FractionFromConfig(i interface{}) (float64, error) { } func (c *Config) NewTracer() error { - sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction)) + sampler := 
NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction)) exporter, err := otlpgrpc.New(context.Background(), otlpgrpc.WithEndpoint(c.Endpoint), otlpgrpc.WithInsecure()) if err != nil { return err @@ -56,3 +57,19 @@ func (c *Config) NewTracer() error { tracing.T, err = NewTracer(sampler, exporter) return err } + +// NewSampler returns a sampler decorator which behaves differently, +// based on the parent of the span. If the span has no parent and is of kind server, +// the decorated sampler is used to make the sampling decision. +// If the span has a parent, depending on whether the parent is remote and whether it +// is sampled, one of the following samplers will apply: +// - remote parent sampled -> always sample +// - remote parent not sampled -> sample based on the decorated sampler (fraction based) +// - local parent sampled -> always sample +// - local parent not sampled -> never sample +func NewSampler(sampler sdk_trace.Sampler) sdk_trace.Sampler { + return sdk_trace.ParentBased( + tracing.SpanKindBased(sampler, api_trace.SpanKindServer), + sdk_trace.WithRemoteParentNotSampled(sampler), + ) +} diff --git a/internal/telemetry/tracing/sampler.go b/internal/telemetry/tracing/sampler.go new file mode 100644 index 0000000000..4ea53980b8 --- /dev/null +++ b/internal/telemetry/tracing/sampler.go @@ -0,0 +1,46 @@ +package tracing + +import ( + "fmt" + "slices" + + sdk_trace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +type spanKindSampler struct { + sampler sdk_trace.Sampler + kinds []trace.SpanKind +} + +// ShouldSample implements the [sdk_trace.Sampler] interface. +// It will not sample any spans which do not match the configured span kinds. +// For spans which do match, the decorated sampler is used to make the sampling decision. 
+func (sk spanKindSampler) ShouldSample(p sdk_trace.SamplingParameters) sdk_trace.SamplingResult { + psc := trace.SpanContextFromContext(p.ParentContext) + if !slices.Contains(sk.kinds, p.Kind) { + return sdk_trace.SamplingResult{ + Decision: sdk_trace.Drop, + Tracestate: psc.TraceState(), + } + } + s := sk.sampler.ShouldSample(p) + return s +} + +func (sk spanKindSampler) Description() string { + return fmt.Sprintf("SpanKindBased{sampler:%s,kinds:%v}", + sk.sampler.Description(), + sk.kinds, + ) +} + +// SpanKindBased returns a sampler decorator which behaves differently, based on the kind of the span. +// If the span kind does not match one of the configured kinds, it will not be sampled. +// If the span kind matches, the decorated sampler is used to make the sampling decision. +func SpanKindBased(sampler sdk_trace.Sampler, kinds ...trace.SpanKind) sdk_trace.Sampler { + return spanKindSampler{ + sampler: sampler, + kinds: kinds, + } +} diff --git a/internal/telemetry/tracing/sampler_test.go b/internal/telemetry/tracing/sampler_test.go new file mode 100644 index 0000000000..1f2070a0fb --- /dev/null +++ b/internal/telemetry/tracing/sampler_test.go @@ -0,0 +1,80 @@ +package tracing + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + sdk_trace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +func TestSpanKindBased(t *testing.T) { + type args struct { + sampler sdk_trace.Sampler + kinds []trace.SpanKind + } + type want struct { + description string + sampled int + } + tests := []struct { + name string + args args + want want + }{ + { + "never sample, no sample", + args{ + sampler: sdk_trace.NeverSample(), + kinds: []trace.SpanKind{trace.SpanKindServer}, + }, + want{ + description: "SpanKindBased{sampler:AlwaysOffSampler,kinds:[server]}", + sampled: 0, + }, + }, + { + "always sample, no kind, no sample", + args{ + sampler: sdk_trace.AlwaysSample(), + kinds: nil, + }, + want{ + description: 
"SpanKindBased{sampler:AlwaysOnSampler,kinds:[]}", + sampled: 0, + }, + }, + { + "always sample, 2 kinds, 2 samples", + args{ + sampler: sdk_trace.AlwaysSample(), + kinds: []trace.SpanKind{trace.SpanKindServer, trace.SpanKindClient}, + }, + want{ + description: "SpanKindBased{sampler:AlwaysOnSampler,kinds:[server client]}", + sampled: 2, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sampler := SpanKindBased(tt.args.sampler, tt.args.kinds...) + assert.Equal(t, tt.want.description, sampler.Description()) + + p := sdk_trace.NewTracerProvider(sdk_trace.WithSampler(sampler)) + tr := p.Tracer("test") + + var sampled int + for i := trace.SpanKindUnspecified; i <= trace.SpanKindConsumer; i++ { + ctx := context.Background() + _, span := tr.Start(ctx, "test", trace.WithSpanKind(i)) + if span.SpanContext().IsSampled() { + sampled++ + } + } + + assert.Equal(t, tt.want.sampled, sampled) + }) + } +}