logging and otel

This commit is contained in:
Tim Möhlmann
2023-11-05 13:58:22 +02:00
parent 66f91cdc4e
commit 96a53aa130
6 changed files with 116 additions and 65 deletions

View File

@@ -10,14 +10,18 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
errz "github.com/zitadel/zitadel/internal/errors" errz "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/user/model"
) )
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) { func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (resp *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@@ -62,33 +66,40 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
return nil, err return nil, err
} }
// all other errors should result in a response with active: false. // remaining errors shoudn't be returned to the client,
response := new(oidc.IntrospectionResponse) // so we catch errors here, log them and return the response
// with active: false
defer func() {
if err != nil {
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err)
}
resp, err = op.NewResponse(new(oidc.IntrospectionResponse)), nil
}()
if err != nil { if err != nil {
// TODO: log error return nil, err
return op.NewResponse(response), nil
} }
if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil { if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil {
// TODO: log error return nil, err
return op.NewResponse(response), nil
} }
userInfo, err := s.storage.query.GetOIDCUserinfo(ctx, token.userID, token.scope, []string{client.projectID}) userInfo, err := s.storage.query.GetOIDCUserinfo(ctx, token.userID, token.scope, []string{client.projectID})
if err != nil { if err != nil {
// TODO: log error return nil, err
return op.NewResponse(response), nil
} }
response.SetUserInfo(userinfoToOIDC(userInfo, token.scope)) introspectionResp := &oidc.IntrospectionResponse{
response.Scope = token.scope Active: true,
response.ClientID = token.clientID Scope: token.scope,
response.TokenType = oidc.BearerToken ClientID: token.clientID,
response.Expiration = oidc.FromTime(token.tokenExpiration) TokenType: oidc.BearerToken,
response.IssuedAt = oidc.FromTime(token.tokenCreation) Expiration: oidc.FromTime(token.tokenExpiration),
response.NotBefore = oidc.FromTime(token.tokenCreation) IssuedAt: oidc.FromTime(token.tokenCreation),
response.Audience = token.audience NotBefore: oidc.FromTime(token.tokenCreation),
response.Issuer = op.IssuerFromContext(ctx) Audience: token.audience,
response.JWTID = token.tokenID Issuer: op.IssuerFromContext(ctx),
response.Active = true JWTID: token.tokenID,
return op.NewResponse(response), nil }
introspectionResp.SetUserInfo(userinfoToOIDC(userInfo, token.scope))
return op.NewResponse(introspectionResp), nil
} }
type instrospectionClientResult struct { type instrospectionClientResult struct {
@@ -163,52 +174,53 @@ type introspectionTokenResult struct {
} }
func (s *Server) introspectionToken(ctx context.Context, accessToken string, rc chan<- *introspectionTokenResult) { func (s *Server) introspectionToken(ctx context.Context, accessToken string, rc chan<- *introspectionTokenResult) {
var tokenID, subject string ctx, span := tracing.NewSpan(ctx)
if tokenIDSubject, err := s.Provider().Crypto().Decrypt(accessToken); err == nil { result, err := func() (_ *introspectionTokenResult, err error) {
split := strings.Split(tokenIDSubject, ":") var tokenID, subject string
if len(split) != 2 {
rc <- &introspectionTokenResult{err: errors.New("invalid token format")} if tokenIDSubject, err := s.Provider().Crypto().Decrypt(accessToken); err == nil {
return split := strings.Split(tokenIDSubject, ":")
if len(split) != 2 {
return nil, errors.New("invalid token format")
}
tokenID, subject = split[0], split[1]
} else {
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet)
claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier)
if err != nil {
return nil, err
}
tokenID, subject = claims.JWTID, claims.Subject
} }
tokenID, subject = split[0], split[1]
} else { if strings.HasPrefix(tokenID, command.IDPrefixV2) {
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet) token, err := s.storage.query.ActiveAccessTokenByToken(ctx, tokenID)
claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier) if err != nil {
rc <- &introspectionTokenResult{err: err}
return nil, err
}
return introspectionTokenResultV2(tokenID, subject, token), nil
}
token, err := s.storage.repo.TokenByIDs(ctx, subject, tokenID)
if err != nil { if err != nil {
rc <- &introspectionTokenResult{err: err} return nil, errz.ThrowPermissionDenied(err, "OIDC-Dsfb2", "token is not valid or has expired")
return
} }
tokenID, subject = claims.JWTID, claims.Subject return introspectionTokenResultV1(tokenID, subject, token), nil
} }()
if strings.HasPrefix(tokenID, command.IDPrefixV2) { span.EndWithError(err)
token, err := s.storage.query.ActiveAccessTokenByToken(ctx, tokenID)
if err != nil {
rc <- &introspectionTokenResult{err: err}
return
}
rc <- &introspectionTokenResult{
tokenID: tokenID,
userID: token.UserID,
subject: subject,
clientID: token.ClientID,
audience: token.Audience,
scope: token.Scope,
tokenCreation: token.AccessTokenCreation,
tokenExpiration: token.AccessTokenExpiration,
}
return
}
token, err := s.storage.repo.TokenByIDs(ctx, subject, tokenID)
if err != nil { if err != nil {
rc <- &introspectionTokenResult{ rc <- &introspectionTokenResult{err: err}
err: errz.ThrowPermissionDenied(err, "OIDC-Dsfb2", "token is not valid or has expired"),
}
return return
} }
rc <- &introspectionTokenResult{ rc <- result
}
func introspectionTokenResultV1(tokenID, subject string, token *model.TokenView) *introspectionTokenResult {
return &introspectionTokenResult{
tokenID: tokenID, tokenID: tokenID,
userID: token.UserID, userID: token.UserID,
subject: subject, subject: subject,
@@ -221,6 +233,19 @@ func (s *Server) introspectionToken(ctx context.Context, accessToken string, rc
} }
} }
func introspectionTokenResultV2(tokenID, subject string, token *query.OIDCSessionAccessTokenReadModel) *introspectionTokenResult {
return &introspectionTokenResult{
tokenID: tokenID,
userID: token.UserID,
subject: subject,
clientID: token.ClientID,
audience: token.Audience,
scope: token.Scope,
tokenCreation: token.AccessTokenCreation,
tokenExpiration: token.AccessTokenExpiration,
}
}
func validateIntrospectionAudience(audience []string, clientID, projectID string) error { func validateIntrospectionAudience(audience []string, clientID, projectID string) error {
if slices.ContainsFunc(audience, func(entry string) bool { if slices.ContainsFunc(audience, func(entry string) bool {
return entry == clientID || entry == projectID return entry == clientID || entry == projectID

View File

@@ -87,7 +87,10 @@ func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key} k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
} }
func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) { func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Time) (_ *jose.JSONWebKey, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
instanceID := authz.GetInstance(ctx).InstanceID() instanceID := authz.GetInstance(ctx).InstanceID()
k.mtx.RLock() k.mtx.RLock()
@@ -101,7 +104,7 @@ func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Tim
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow") return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
} }
key, err := k.queryKey(ctx, keyID, current) key, err = k.queryKey(ctx, keyID, current)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -110,7 +113,10 @@ func (k *keySetCache) getKey(ctx context.Context, keyID string, current time.Tim
} }
// VerifySignature implements the oidc.KeySet interface. // VerifySignature implements the oidc.KeySet interface.
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if len(jws.Signatures) != 1 { if len(jws.Signatures) != 1 {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid") return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
} }

View File

@@ -120,9 +120,10 @@ func NewServer(
} }
server := &Server{ server := &Server{
storage: storage, storage: storage,
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)), LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant. hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant.
fallbackLogger: fallbackLogger,
} }
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware( server.Handler = op.RegisterLegacyServer(server, op.WithHTTPMiddleware(

View File

@@ -4,6 +4,9 @@ import (
"context" "context"
"net/http" "net/http"
"golang.org/x/exp/slog"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
@@ -15,7 +18,8 @@ type Server struct {
storage *OPStorage storage *OPStorage
*op.LegacyServer *op.LegacyServer
hashAlg crypto.HashAlgorithm fallbackLogger *slog.Logger
hashAlg crypto.HashAlgorithm
} }
func endpoints(endpointConfig *EndpointConfig) op.Endpoints { func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
@@ -61,6 +65,13 @@ func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
return endpoints return endpoints
} }
func (s *Server) getLogger(ctx context.Context) *slog.Logger {
if logger, ok := logging.FromContext(ctx); ok {
return logger
}
return s.fallbackLogger
}
func (s *Server) IssuerFromRequest(r *http.Request) string { func (s *Server) IssuerFromRequest(r *http.Request) string {
return s.Provider().IssuerFromRequest(r) return s.Provider().IssuerFromRequest(r)
} }

View File

@@ -6,9 +6,11 @@ import (
_ "embed" _ "embed"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
type IntrospectionClient struct { type IntrospectionClient struct {
@@ -22,6 +24,9 @@ type IntrospectionClient struct {
var introspectionClientByIDQuery string var introspectionClientByIDQuery string
func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) { func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
var ( var (
instanceID = authz.GetInstance(ctx).InstanceID() instanceID = authz.GetInstance(ctx).InstanceID()
client = new(IntrospectionClient) client = new(IntrospectionClient)

View File

@@ -396,6 +396,9 @@ func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder {
} }
func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, current time.Time) (_ PublicKey, err error) { func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, current time.Time) (_ PublicKey, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID()) model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID())
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil { if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err return nil, err