From 85e22c15217f4719c44846d78fec2825dacf5ed8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Wed, 1 Nov 2023 15:59:23 +0200 Subject: [PATCH] get key by id and cache them --- internal/api/oidc/key.go | 56 +++++++++++++++++++++++++ internal/api/oidc/op.go | 16 +++++--- internal/api/oidc/server.go | 61 +++++++++++++++++++++++++-- internal/query/key.go | 80 ++++++++++++++++++++++++++++++++++++ internal/query/query.go | 25 +++++------ internal/static/i18n/en.yaml | 1 + 6 files changed, 218 insertions(+), 21 deletions(-) diff --git a/internal/api/oidc/key.go b/internal/api/oidc/key.go index 3a2a6ae32c..f1d3a21dc0 100644 --- a/internal/api/oidc/key.go +++ b/internal/api/oidc/key.go @@ -3,6 +3,7 @@ package oidc import ( "context" "fmt" + "sync" "time" "github.com/go-jose/go-jose/v3" @@ -19,6 +20,61 @@ import ( "github.com/zitadel/zitadel/internal/telemetry/tracing" ) +type keySet struct { + mtx sync.RWMutex + keys map[string]query.PublicKey + queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error) +} + +func (v *keySet) getKey(ctx context.Context, keyID string, current time.Time) (*jose.JSONWebKey, error) { + v.mtx.RLock() + key, ok := v.keys[keyID] + v.mtx.RUnlock() + + if ok { + if key.Expiry().After(current) { + return jsonWebkey(key), nil + } + v.mtx.Lock() + delete(v.keys, keyID) // cleanup expired keys + v.mtx.Unlock() + + return nil, errors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow") + } + + key, err := v.queryKey(ctx, keyID, current) + if err != nil { + return nil, err + } + + v.mtx.Lock() + v.keys[key.ID()] = key + v.mtx.Unlock() + + return jsonWebkey(key), nil +} + +// VerifySignature implements the oidc.KeySet interface. +func (v *keySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { + if len(jws.Signatures) != 1 { + return nil, errors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid") + } + key, err := v.getKey(ctx, jws.Signatures[0].Header.KeyID, time.Now()) + if err != nil { + return nil, err + } + return jws.Verify(&key) +} + +func jsonWebkey(key query.PublicKey) *jose.JSONWebKey { + return &jose.JSONWebKey{ + KeyID: key.ID(), + Algorithm: key.Algorithm(), + Use: key.Use().String(), + Key: key.Key(), + } +} + const ( locksTable = "projections.locks" signingKey = "signing_key" diff --git a/internal/api/oidc/op.go b/internal/api/oidc/op.go index 117172a440..ad79759f30 100644 --- a/internal/api/oidc/op.go +++ b/internal/api/oidc/op.go @@ -68,6 +68,7 @@ type OPStorage struct { command *command.Commands query *query.Queries eventstore *eventstore.Eventstore + keySet *keySet defaultLoginURL string defaultLoginURLV2 string defaultLogoutURLV2 string @@ -119,6 +120,7 @@ func NewServer( } server := &Server{ + storage: storage, LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)), } metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} @@ -172,12 +174,16 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey [] return opConfig, nil } -func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage { +func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage { return &OPStorage{ - repo: repo, - command: command, - query: query, - eventstore: es, + repo: repo, + command: command, + query: queries, + eventstore: es, + keySet: &keySet{ + keys: make(map[string]query.PublicKey), + queryKey: queries.GetActivePublicKeyByID, + }, defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID), defaultLoginURLV2: config.DefaultLoginURLV2, defaultLogoutURLV2: config.DefaultLogoutURLV2, diff --git a/internal/api/oidc/server.go b/internal/api/oidc/server.go index f9a38a2613..18b2638bcd 100644 --- a/internal/api/oidc/server.go +++ b/internal/api/oidc/server.go @@ -2,7 +2,10 @@ package oidc import ( "context" + "errors" "net/http" + "strings" + "time" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" @@ -11,6 +14,7 @@ import ( type Server struct { http.Handler + storage *OPStorage *op.LegacyServer } @@ -159,11 +163,60 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic return s.LegacyServer.DeviceToken(ctx, r) } -func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() +func (s *Server) authenticateResourceClient(ctx context.Context, cc *op.ClientCredentials) (clientID string, err error) { + if cc.ClientAssertion != "" { + verifier := op.NewJWTProfileVerifier(s.storage, op.IssuerFromContext(ctx), 1*time.Hour, time.Second) + profile, err := op.VerifyJWTAssertion(ctx, cc.ClientAssertion, verifier) + if err != nil { + return "", err + } + return profile.Issuer, nil + } - return s.LegacyServer.Introspect(ctx, r) + if err = s.storage.AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil { + if err != nil { + return "", err + } + } + return cc.ClientID, nil +} + +func (s *Server) getTokenIDAndSubject(ctx context.Context, accessToken string) (idToken, subject string, err error) { + provider := s.Provider() + tokenIDSubject, err := provider.Crypto().Decrypt(accessToken) + if err == nil { + splitToken := strings.Split(tokenIDSubject, ":") + if len(splitToken) != 2 { + return "", "", errors.New("invalid token format") + } + return splitToken[0], splitToken[1], nil + } + + verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet) + accessTokenClaims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier) + if err != nil { + return "", "", err + } + return accessTokenClaims.JWTID, accessTokenClaims.Subject, nil +} + +func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (_ *op.Response, err error) { + clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials) + if err != nil { + return nil, err + } + response := new(oidc.IntrospectionResponse) + tokenID, subject, err := s.getTokenIDAndSubject(ctx, r.Data.Token) + if err != nil { + // TODO: log error + return op.NewResponse(response), nil + } + err = s.storage.SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID) + if err != nil { + return op.NewResponse(response), nil + } + response.Active = true + return op.NewResponse(response), nil } func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) { diff --git a/internal/query/key.go b/internal/query/key.go index ac00d0624c..ec656b5792 100644 --- a/internal/query/key.go +++ b/internal/query/key.go @@ -13,7 +13,9 @@ import ( "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/repository/keypair" "github.com/zitadel/zitadel/internal/telemetry/tracing" ) @@ -349,3 +351,81 @@ func preparePrivateKeysQuery(ctx context.Context, db prepareDatabase) (sq.Select }, nil } } + +type PublicKeyReadModel struct { + eventstore.ReadModel + + Algorithm string + Key *crypto.CryptoValue + Expiry time.Time +} + +func NewPublicKeyReadModel(keyID, resourceOwner string) *PublicKeyReadModel { + return &PublicKeyReadModel{ + ReadModel: eventstore.ReadModel{ + AggregateID: keyID, + ResourceOwner: resourceOwner, + }, + } +} + +func (wm *PublicKeyReadModel) AppendEvents(events ...eventstore.Event) { + wm.ReadModel.AppendEvents(events...) +} + +func (wm *PublicKeyReadModel) Reduce() error { + for _, event := range wm.Events { + switch e := event.(type) { + case *keypair.AddedEvent: + wm.Algorithm = e.Algorithm + wm.Key = e.PublicKey.Key + wm.Expiry = e.PublicKey.Expiry + } + } + return wm.ReadModel.Reduce() +} + +func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder { + return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + ResourceOwner(wm.ResourceOwner). + AddQuery(). + AggregateTypes(keypair.AggregateType). + AggregateIDs(wm.AggregateID). + EventTypes(keypair.AddedEventType). + Builder() +} + +func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, after time.Time) (_ PublicKey, err error) { + model := NewPublicKeyReadModel(keyID, authz.GetInstance(ctx).InstanceID()) + if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil { + return nil, err + } + if model.Algorithm == "" || model.Key == nil { + return nil, errors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound") + } + if model.Expiry.After(after) { + return nil, errors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow") + } + keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm) + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-Ie4oh", "Errors.Internal") + } + publicKey, err := crypto.BytesToPublicKey(keyValue) + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-Kai2Z", "Errors.Internal") + } + + return &rsaPublicKey{ + key: key{ + id: model.AggregateID, + creationDate: model.CreationDate, + changeDate: model.ChangeDate, + sequence: model.ProcessedSequence, + resourceOwner: model.ResourceOwner, + algorithm: model.Algorithm, + // use: , TBD, what events update this and do we need it? + }, + expiry: model.Expiry, + publicKey: publicKey, + }, nil +} diff --git a/internal/query/query.go b/internal/query/query.go index cb50ce1fcd..58f4b3009b 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -38,9 +38,10 @@ type Queries struct { eventstore *eventstore.Eventstore client *database.DB - idpConfigEncryption crypto.EncryptionAlgorithm - sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error) - checkPermission domain.PermissionCheck + keyEncryptionAlgorithm crypto.EncryptionAlgorithm + idpConfigEncryption crypto.EncryptionAlgorithm + sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error) + checkPermission domain.PermissionCheck DefaultLanguage language.Tag LoginDir http.FileSystem @@ -86,8 +87,16 @@ func StartQueries( LoginTranslationFileContents: make(map[string][]byte), NotificationTranslationFileContents: make(map[string][]byte), zitadelRoles: zitadelRoles, + keyEncryptionAlgorithm: keyEncryptionAlgorithm, + idpConfigEncryption: idpConfigEncryption, sessionTokenVerifier: sessionTokenVerifier, - defaultAuditLogRetention: defaultAuditLogRetention, + multifactors: domain.MultifactorConfigs{ + OTP: domain.OTPConfig{ + CryptoMFA: otpEncryption, + Issuer: defaults.Multifactors.OTP.Issuer, + }, + }, + defaultAuditLogRetention: defaultAuditLogRetention, } iam_repo.RegisterEventMappers(repo.eventstore) usr_repo.RegisterEventMappers(repo.eventstore) @@ -103,14 +112,6 @@ func StartQueries( quota.RegisterEventMappers(repo.eventstore) limits.RegisterEventMappers(repo.eventstore) - repo.idpConfigEncryption = idpConfigEncryption - repo.multifactors = domain.MultifactorConfigs{ - OTP: domain.OTPConfig{ - CryptoMFA: otpEncryption, - Issuer: defaults.Multifactors.OTP.Issuer, - }, - } - repo.checkPermission = permissionCheck(repo) err = projection.Create(ctx, sqlClient, es, projections, keyEncryptionAlgorithm, certEncryptionAlgorithm, systemAPIUsers) diff --git a/internal/static/i18n/en.yaml b/internal/static/i18n/en.yaml index cc9e539446..90b43e182b 100644 --- a/internal/static/i18n/en.yaml +++ b/internal/static/i18n/en.yaml @@ -428,6 +428,7 @@ Errors: UserSession: NotFound: UserSession not found Key: + NotFound: Key not found ExpireBeforeNow: The expiration date is in the past Login: LoginPolicy: