diff --git a/cmd/zitadel/main.go b/cmd/zitadel/main.go index d35a707ee9..83244266d3 100644 --- a/cmd/zitadel/main.go +++ b/cmd/zitadel/main.go @@ -157,7 +157,8 @@ func startZitadel(configPaths []string) { logging.Log("MAIN-Ddv21").OnError(err).Fatal("cannot start eventstore for queries") } - queries, err := query.StartQueries(ctx, esQueries, conf.Projections, conf.SystemDefaults) + keyChan := make(chan interface{}) + queries, err := query.StartQueries(ctx, esQueries, conf.Projections, conf.SystemDefaults, keyChan) logging.Log("MAIN-WpeJY").OnError(err).Fatal("cannot start queries") authZRepo, err := authz.Start(ctx, conf.AuthZ, conf.InternalAuthZ, conf.SystemDefaults, queries) @@ -189,7 +190,7 @@ func startZitadel(configPaths []string) { } verifier := internal_authz.Start(&repo) - startAPI(ctx, conf, verifier, authZRepo, authRepo, commands, queries, store) + startAPI(ctx, conf, verifier, authZRepo, authRepo, commands, queries, store, esQueries, conf.Projections.CRDB, keyChan) startUI(ctx, conf, authRepo, commands, queries, store) if *notificationEnabled { @@ -214,7 +215,7 @@ func startUI(ctx context.Context, conf *Config, authRepo *auth_es.EsRepository, uis.Start(ctx) } -func startAPI(ctx context.Context, conf *Config, verifier *internal_authz.TokenVerifier, authZRepo *authz_repo.EsRepository, authRepo *auth_es.EsRepository, command *command.Commands, query *query.Queries, static static.Storage) { +func startAPI(ctx context.Context, conf *Config, verifier *internal_authz.TokenVerifier, authZRepo *authz_repo.EsRepository, authRepo *auth_es.EsRepository, command *command.Commands, query *query.Queries, static static.Storage, es *eventstore.Eventstore, projections types.SQL, keyChan <-chan interface{}) { roles := make([]string, len(conf.InternalAuthZ.RolePermissionMappings)) for i, role := range conf.InternalAuthZ.RolePermissionMappings { roles[i] = role.Role @@ -236,7 +237,7 @@ func startAPI(ctx context.Context, conf *Config, verifier *internal_authz.TokenV apis.RegisterServer(ctx, auth.CreateServer(command, query, authRepo, conf.SystemDefaults)) } if *oidcEnabled { - op := oidc.NewProvider(ctx, conf.API.OIDC, command, query, authRepo, conf.SystemDefaults.KeyConfig.EncryptionConfig, *localDevMode) + op := oidc.NewProvider(ctx, conf.API.OIDC, command, query, authRepo, conf.SystemDefaults.KeyConfig, *localDevMode, es, projections, keyChan) apis.RegisterHandler("/oauth/v2", op.HttpHandler()) } if *assetsEnabled { diff --git a/go.mod b/go.mod index a5f6d0ef7e..ce211b459e 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,6 @@ require ( github.com/minio/minio-go/v7 v7.0.20 github.com/muesli/gamut v0.2.0 github.com/nicksnyder/go-i18n/v2 v2.1.2 - github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.3.0 github.com/rakyll/statik v0.1.7 diff --git a/go.sum b/go.sum index 08d7be3069..5508ea8703 100644 --- a/go.sum +++ b/go.sum @@ -790,8 +790,6 @@ github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OS github.com/nicksnyder/go-i18n/v2 v2.1.2 h1:QHYxcUJnGHBaq7XbvgunmZ2Pn0focXFqTD61CkH146c= github.com/nicksnyder/go-i18n/v2 v2.1.2/go.mod h1:d++QJC9ZVf7pa48qrsRWhMJ5pSHIPmS3OLqK1niyLxs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249 h1:NHrXEjTNQY7P0Zfx1aMrNhpgxHmow66XQtm0aQLY0AE= -github.com/nsf/jsondiff v0.0.0-20210926074059-1e845ec5d249/go.mod h1:mpRZBD8SJ55OIICQ3iWH0Yz3cjzA61JdqMLoWXeB2+8= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index 639b20244d..d1e392c40b 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -9,7 +9,6 @@ import ( "github.com/caos/logging" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" - "gopkg.in/square/go-jose.v2" "github.com/caos/zitadel/internal/api/http/middleware" "github.com/caos/zitadel/internal/errors" @@ -198,16 +197,6 @@ func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID str return oidc.ErrServerError().WithParent(err) } -func (o *OPStorage) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey) { - o.repo.GetSigningKey(ctx, keyCh, o.signingKeyAlgorithm) -} - -func (o *OPStorage) GetKeySet(ctx context.Context) (_ *jose.JSONWebKeySet, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - return o.repo.GetKeySet(ctx) -} - func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string, scopes []string) ([]string, error) { for _, scope := range scopes { if strings.HasPrefix(scope, ScopeProjectRolePrefix) { diff --git a/internal/api/oidc/key.go b/internal/api/oidc/key.go new file mode 100644 index 0000000000..254319612a --- /dev/null +++ b/internal/api/oidc/key.go @@ -0,0 +1,179 @@ +package oidc + +import ( + "context" + "fmt" + "time" + + "github.com/caos/logging" + "gopkg.in/square/go-jose.v2" + + "github.com/caos/zitadel/internal/telemetry/tracing" + + "github.com/caos/zitadel/internal/crypto" + "github.com/caos/zitadel/internal/domain" + "github.com/caos/zitadel/internal/errors" + "github.com/caos/zitadel/internal/eventstore" + "github.com/caos/zitadel/internal/query" + "github.com/caos/zitadel/internal/repository/keypair" +) + +const ( + locksTable = "projections.locks" + signingKey = "signing_key" +) + +func (o *OPStorage) GetKeySet(ctx context.Context) (_ *jose.JSONWebKeySet, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + keys, err := o.query.ActivePublicKeys(ctx, time.Now()) + if err != nil { + return nil, err + } + webKeys := make([]jose.JSONWebKey, len(keys.Keys)) + for i, key := range keys.Keys { + webKeys[i] = jose.JSONWebKey{ + KeyID: key.ID(), + Algorithm: key.Algorithm(), + Use: key.Use().String(), + Key: key.Key(), + } + } + return &jose.JSONWebKeySet{Keys: webKeys}, nil +} + +func (o *OPStorage) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey) { + renewTimer := time.NewTimer(0) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-o.keyChan: + if !renewTimer.Stop() { + <-renewTimer.C + } + checkAfter := o.resetTimer(renewTimer, true) + logging.Log("OIDC-dK432").Infof("requested next signing key check in %s", checkAfter) + case <-renewTimer.C: + o.getSigningKey(ctx, renewTimer, keyCh) + } + } + }() +} + +func (o *OPStorage) getSigningKey(ctx context.Context, renewTimer *time.Timer, keyCh chan<- jose.SigningKey) { + keys, err := o.query.ActivePrivateSigningKey(ctx, time.Now().Add(o.signingKeyGracefulPeriod)) + if err != nil { + checkAfter := o.resetTimer(renewTimer, true) + logging.Log("OIDC-ASff").Infof("next signing key check in %s", checkAfter) + return + } + if len(keys.Keys) == 0 { + o.refreshSigningKey(ctx, keyCh, o.signingKeyAlgorithm, keys.LatestSequence) + checkAfter := o.resetTimer(renewTimer, true) + logging.Log("OIDC-ASDf3").Infof("next signing key check in %s", checkAfter) + return + } + err = o.exchangeSigningKey(selectSigningKey(keys.Keys), keyCh) + logging.Log("OIDC-aDfg3").OnError(err).Error("could not exchange signing key") + checkAfter := o.resetTimer(renewTimer, err != nil) + logging.Log("OIDC-dK432").Infof("next signing key check in %s", checkAfter) +} + +func (o *OPStorage) resetTimer(timer *time.Timer, shortRefresh bool) (nextCheck time.Duration) { + nextCheck = o.signingKeyRotationCheck + defer func() { timer.Reset(nextCheck) }() + if shortRefresh || o.currentKey == nil { + return nextCheck + } + maxLifetime := time.Until(o.currentKey.Expiry()) + if maxLifetime < o.signingKeyGracefulPeriod+2*o.signingKeyRotationCheck { + return nextCheck + } + return maxLifetime - o.signingKeyGracefulPeriod - o.signingKeyRotationCheck +} + +func (o *OPStorage) refreshSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey, algorithm string, sequence *query.LatestSequence) { + if o.currentKey != nil && o.currentKey.Expiry().Before(time.Now().UTC()) { + logging.Log("OIDC-ADg26").Info("unset current signing key") + keyCh <- jose.SigningKey{} + } + ok, err := o.ensureIsLatestKey(ctx, sequence.Sequence) + if err != nil { + logging.Log("OIDC-sdz53").WithError(err).Error("could not ensure latest key") + return + } + if !ok { + logging.Log("EVENT-GBD23").Warn("view not up to date, retrying later") + return + } + err = o.lockAndGenerateSigningKeyPair(ctx, algorithm) + logging.Log("EVENT-B4d21").OnError(err).Warn("could not create signing key") +} + +func (o *OPStorage) ensureIsLatestKey(ctx context.Context, sequence uint64) (bool, error) { + maxSequence, err := o.getMaxKeySequence(ctx) + if err != nil { + return false, fmt.Errorf("error retrieving new events: %w", err) + } + return sequence == maxSequence, nil +} + +func (o *OPStorage) exchangeSigningKey(key query.PrivateKey, keyCh chan<- jose.SigningKey) (err error) { + if o.currentKey != nil && o.currentKey.ID() == key.ID() { + logging.Log("OIDC-Abb3e").Info("no new signing key") + return nil + } + keyData, err := crypto.Decrypt(key.Key(), o.encAlg) + if err != nil { + return err + } + privateKey, err := crypto.BytesToPrivateKey(keyData) + if err != nil { + return err + } + keyCh <- jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(key.Algorithm()), + Key: jose.JSONWebKey{ + KeyID: key.ID(), + Key: privateKey, + }, + } + o.currentKey = key + logging.LogWithFields("OIDC-dsg54", "keyID", key.ID()).Info("exchanged signing key") + return nil +} + +func (o *OPStorage) lockAndGenerateSigningKeyPair(ctx context.Context, algorithm string) error { + logging.Log("OIDC-sdz53").Info("lock and generate signing key pair") + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + errs := o.locker.Lock(ctx, o.signingKeyRotationCheck*2) + err, ok := <-errs + if err != nil || !ok { + if errors.IsErrorAlreadyExists(err) { + return nil + } + logging.Log("OIDC-Dfg32").OnError(err).Warn("initial lock failed") + return err + } + + return o.command.GenerateSigningKeyPair(ctx, algorithm) +} + +func (o *OPStorage) getMaxKeySequence(ctx context.Context) (uint64, error) { + return o.eventstore.LatestSequence(ctx, + eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence). + ResourceOwner(domain.IAMID). + AddQuery(). + AggregateTypes(keypair.AggregateType). + Builder(), + ) +} + +func selectSigningKey(keys []query.PrivateKey) query.PrivateKey { + return keys[len(keys)-1] +} diff --git a/internal/api/oidc/op.go b/internal/api/oidc/op.go index 436e95519a..5b8c117da0 100644 --- a/internal/api/oidc/op.go +++ b/internal/api/oidc/op.go @@ -13,8 +13,11 @@ import ( "github.com/caos/zitadel/internal/api/http/middleware" "github.com/caos/zitadel/internal/auth/repository" "github.com/caos/zitadel/internal/command" + "github.com/caos/zitadel/internal/config/systemdefaults" "github.com/caos/zitadel/internal/config/types" "github.com/caos/zitadel/internal/crypto" + "github.com/caos/zitadel/internal/eventstore" + "github.com/caos/zitadel/internal/eventstore/handler/crdb" "github.com/caos/zitadel/internal/i18n" "github.com/caos/zitadel/internal/id" "github.com/caos/zitadel/internal/query" @@ -58,18 +61,25 @@ type OPStorage struct { repo repository.Repository command *command.Commands query *query.Queries + eventstore *eventstore.Eventstore defaultLoginURL string defaultAccessTokenLifetime time.Duration defaultIdTokenLifetime time.Duration signingKeyAlgorithm string defaultRefreshTokenIdleExpiration time.Duration defaultRefreshTokenExpiration time.Duration + encAlg crypto.EncryptionAlgorithm + keyChan <-chan interface{} + currentKey query.PrivateKey + signingKeyRotationCheck time.Duration + signingKeyGracefulPeriod time.Duration + locker crdb.Locker } -func NewProvider(ctx context.Context, config OPHandlerConfig, command *command.Commands, query *query.Queries, repo repository.Repository, keyConfig *crypto.KeyConfig, localDevMode bool) op.OpenIDProvider { +func NewProvider(ctx context.Context, config OPHandlerConfig, command *command.Commands, query *query.Queries, repo repository.Repository, keyConfig systemdefaults.KeyConfig, localDevMode bool, es *eventstore.Eventstore, projections types.SQL, keyChan <-chan interface{}) op.OpenIDProvider { cookieHandler, err := middleware.NewUserAgentHandler(config.UserAgentCookieConfig, id.SonyFlakeGenerator, localDevMode) logging.Log("OIDC-sd4fd").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Panic("cannot user agent handler") - tokenKey, err := crypto.LoadKey(keyConfig, keyConfig.EncryptionKeyID) + tokenKey, err := crypto.LoadKey(keyConfig.EncryptionConfig, keyConfig.EncryptionConfig.EncryptionKeyID) logging.Log("OIDC-ADvbv").OnError(err).Panic("cannot load OP crypto key") cryptoKey := []byte(tokenKey) if len(cryptoKey) != 32 { @@ -84,10 +94,12 @@ func NewProvider(ctx context.Context, config OPHandlerConfig, command *command.C logging.Log("OIDC-GBd3t").OnError(err).Panic("cannot get supported languages") config.OPConfig.SupportedUILocales = supportedLanguages metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} + storage, err := newStorage(config.StorageConfig, command, query, repo, keyConfig, es, projections, keyChan) + logging.Log("OIDC-Jdg2k").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Panic("cannot create storage") provider, err := op.NewOpenIDProvider( ctx, config.OPConfig, - newStorage(config.StorageConfig, command, query, repo), + storage, op.WithHttpInterceptors( middleware.MetricsHandler(metricTypes), middleware.TelemetryHandler(), @@ -107,18 +119,32 @@ func NewProvider(ctx context.Context, config OPHandlerConfig, command *command.C return provider } -func newStorage(config StorageConfig, command *command.Commands, query *query.Queries, repo repository.Repository) *OPStorage { +func newStorage(config StorageConfig, command *command.Commands, query *query.Queries, repo repository.Repository, keyConfig systemdefaults.KeyConfig, es *eventstore.Eventstore, projections types.SQL, keyChan <-chan interface{}) (*OPStorage, error) { + encAlg, err := crypto.NewAESCrypto(keyConfig.EncryptionConfig) + if err != nil { + return nil, err + } + sqlClient, err := projections.Start() + if err != nil { + return nil, err + } return &OPStorage{ repo: repo, command: command, query: query, + eventstore: es, defaultLoginURL: config.DefaultLoginURL, signingKeyAlgorithm: config.SigningKeyAlgorithm, defaultAccessTokenLifetime: config.DefaultAccessTokenLifetime.Duration, defaultIdTokenLifetime: config.DefaultIdTokenLifetime.Duration, defaultRefreshTokenIdleExpiration: config.DefaultRefreshTokenIdleExpiration.Duration, defaultRefreshTokenExpiration: config.DefaultRefreshTokenExpiration.Duration, - } + encAlg: encAlg, + signingKeyGracefulPeriod: keyConfig.SigningKeyGracefulPeriod.Duration, + signingKeyRotationCheck: keyConfig.SigningKeyRotationCheck.Duration, + locker: crdb.NewLocker(sqlClient, locksTable, signingKey), + keyChan: keyChan, + }, nil } func (o *OPStorage) Health(ctx context.Context) error { diff --git a/internal/auth/repository/eventsourcing/eventstore/key.go b/internal/auth/repository/eventsourcing/eventstore/key.go deleted file mode 100644 index dd34354030..0000000000 --- a/internal/auth/repository/eventsourcing/eventstore/key.go +++ /dev/null @@ -1,175 +0,0 @@ -package eventstore - -import ( - "context" - "os" - "time" - - "github.com/caos/logging" - "gopkg.in/square/go-jose.v2" - - "github.com/caos/zitadel/internal/auth/repository/eventsourcing/view" - "github.com/caos/zitadel/internal/command" - "github.com/caos/zitadel/internal/crypto" - "github.com/caos/zitadel/internal/errors" - "github.com/caos/zitadel/internal/eventstore" - "github.com/caos/zitadel/internal/eventstore/v1/spooler" - "github.com/caos/zitadel/internal/id" - "github.com/caos/zitadel/internal/key/model" - key_view "github.com/caos/zitadel/internal/key/repository/view" -) - -type KeyRepository struct { - Commands *command.Commands - Eventstore *eventstore.Eventstore - View *view.View - SigningKeyRotationCheck time.Duration - SigningKeyGracefulPeriod time.Duration - KeyAlgorithm crypto.EncryptionAlgorithm - KeyChan <-chan *model.KeyView - Locker spooler.Locker - lockID string - currentKeyID string - currentKeyExpiration time.Time -} - -const ( - signingKey = "signing_key" -) - -func (k *KeyRepository) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey, algorithm string) { - renewTimer := time.After(0) - go func() { - for { - select { - case <-ctx.Done(): - return - case key := <-k.KeyChan: - refreshed, err := k.refreshSigningKey(ctx, key, keyCh, algorithm) - logging.Log("KEY-asd5g").OnError(err).Error("could not refresh signing key on key channel push") - renewTimer = time.After(k.getRenewTimer(refreshed)) - case <-renewTimer: - key, err := k.latestSigningKey() - logging.Log("KEY-DAfh4-1").OnError(err).Error("could not check for latest signing key") - refreshed, err := k.refreshSigningKey(ctx, key, keyCh, algorithm) - logging.Log("KEY-DAfh4-2").OnError(err).Error("could not refresh signing key when ensuring key") - renewTimer = time.After(k.getRenewTimer(refreshed)) - } - } - }() -} - -func (k *KeyRepository) GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) { - keys, err := k.View.GetActiveKeySet() - if err != nil { - return nil, err - } - webKeys := make([]jose.JSONWebKey, len(keys)) - for i, key := range keys { - webKeys[i] = jose.JSONWebKey{KeyID: key.ID, Algorithm: key.Algorithm, Use: key.Usage.String(), Key: key.Key} - } - return &jose.JSONWebKeySet{Keys: webKeys}, nil -} - -func (k *KeyRepository) getRenewTimer(refreshed bool) time.Duration { - duration := k.SigningKeyRotationCheck - if refreshed { - duration = k.currentKeyExpiration.Sub(time.Now().Add(k.SigningKeyGracefulPeriod + k.SigningKeyRotationCheck*2)) - } - logging.LogWithFields("EVENT-dK432", "in", duration).Info("next signing key check") - return duration -} - -func (k *KeyRepository) latestSigningKey() (shortRefresh *model.KeyView, err error) { - key, errView := k.View.GetActivePrivateKeyForSigning(time.Now().UTC().Add(k.SigningKeyGracefulPeriod)) - if errView != nil && !errors.IsNotFound(errView) { - logging.Log("EVENT-GEd4h").WithError(errView).Warn("could not get signing key") - return nil, errView - } - return key, nil -} - -func (k *KeyRepository) ensureIsLatestKey(ctx context.Context) (bool, error) { - sequence, err := k.View.GetLatestKeySequence() - if err != nil { - return false, err - } - events, err := k.getKeyEvents(ctx, sequence.CurrentSequence) - if err != nil { - logging.Log("EVENT-der5g").WithError(err).Warn("error retrieving new events") - return false, err - } - if len(events) > 0 { - logging.Log("EVENT-GBD23").Warn("view not up to date, retrying later") - return false, nil - } - return true, nil -} - -func (k *KeyRepository) refreshSigningKey(ctx context.Context, key *model.KeyView, keyCh chan<- jose.SigningKey, algorithm string) (refreshed bool, err error) { - if key == nil { - if k.currentKeyExpiration.Before(time.Now().UTC()) { - logging.Log("EVENT-ADg26").Info("unset current signing key") - keyCh <- jose.SigningKey{} - } - if ok, err := k.ensureIsLatestKey(ctx); !ok && err == nil { - return false, err - } - logging.Log("EVENT-sdz53").Info("lock and generate signing key pair") - err = k.lockAndGenerateSigningKeyPair(ctx, algorithm) - logging.Log("EVENT-B4d21").OnError(err).Warn("could not create signing key") - return false, err - } - - if k.currentKeyID == key.ID { - logging.Log("EVENT-Abb3e").Info("no new signing key") - return false, nil - } - if ok, err := k.ensureIsLatestKey(ctx); !ok && err == nil { - logging.Log("EVENT-HJd92").Info("signing key in view is not latest key") - return false, err - } - signingKey, err := model.SigningKeyFromKeyView(key, k.KeyAlgorithm) - if err != nil { - logging.Log("EVENT-HJd92").WithError(err).Error("signing key cannot be decrypted -> immediate refresh") - return k.refreshSigningKey(ctx, nil, keyCh, algorithm) - } - k.currentKeyID = signingKey.ID - k.currentKeyExpiration = key.Expiry - keyCh <- jose.SigningKey{ - Algorithm: jose.SignatureAlgorithm(signingKey.Algorithm), - Key: jose.JSONWebKey{ - KeyID: signingKey.ID, - Key: signingKey.Key, - }, - } - logging.LogWithFields("EVENT-dsg54", "keyID", signingKey.ID).Info("refreshed signing key") - return true, nil -} - -func (k *KeyRepository) lockAndGenerateSigningKeyPair(ctx context.Context, algorithm string) error { - err := k.Locker.Renew(k.lockerID(), signingKey, k.SigningKeyRotationCheck*2) - if err != nil { - if errors.IsErrorAlreadyExists(err) { - return nil - } - return err - } - return k.Commands.GenerateSigningKeyPair(ctx, algorithm) -} - -func (k *KeyRepository) lockerID() string { - if k.lockID == "" { - var err error - k.lockID, err = os.Hostname() - if err != nil || k.lockID == "" { - k.lockID, err = id.SonyFlakeGenerator.Next() - logging.Log("EVENT-bsdf6").OnError(err).Panic("unable to generate lockID") - } - } - return k.lockID -} - -func (k *KeyRepository) getKeyEvents(ctx context.Context, sequence uint64) ([]eventstore.Event, error) { - return k.Eventstore.Filter(ctx, key_view.KeyPairQuery(sequence)) -} diff --git a/internal/auth/repository/eventsourcing/handler/handler.go b/internal/auth/repository/eventsourcing/handler/handler.go index fba698954e..b273680a8b 100644 --- a/internal/auth/repository/eventsourcing/handler/handler.go +++ b/internal/auth/repository/eventsourcing/handler/handler.go @@ -8,7 +8,6 @@ import ( "github.com/caos/zitadel/internal/config/types" v1 "github.com/caos/zitadel/internal/eventstore/v1" "github.com/caos/zitadel/internal/eventstore/v1/query" - key_model "github.com/caos/zitadel/internal/key/model" ) type Configs map[string]*Config @@ -30,7 +29,7 @@ func (h *handler) Eventstore() v1.Eventstore { return h.es } -func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es v1.Eventstore, systemDefaults sd.SystemDefaults, keyChan chan<- *key_model.KeyView) []query.Handler { +func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es v1.Eventstore, systemDefaults sd.SystemDefaults) []query.Handler { return []query.Handler{ newUser( handler{view, bulkLimit, configs.cycleDuration("User"), errorCount, es}, @@ -41,9 +40,6 @@ func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es handler{view, bulkLimit, configs.cycleDuration("UserMembership"), errorCount, es}), newToken( handler{view, bulkLimit, configs.cycleDuration("Token"), errorCount, es}), - newKey( - handler{view, bulkLimit, configs.cycleDuration("Key"), errorCount, es}, - keyChan), newUserGrant( handler{view, bulkLimit, configs.cycleDuration("UserGrant"), errorCount, es}, systemDefaults.IamID), diff --git a/internal/auth/repository/eventsourcing/handler/key.go b/internal/auth/repository/eventsourcing/handler/key.go deleted file mode 100644 index 372533c0cb..0000000000 --- a/internal/auth/repository/eventsourcing/handler/key.go +++ /dev/null @@ -1,106 +0,0 @@ -package handler - -import ( - "github.com/caos/zitadel/internal/eventstore/v1" - "time" - - "github.com/caos/logging" - - "github.com/caos/zitadel/internal/eventstore/v1/models" - "github.com/caos/zitadel/internal/eventstore/v1/query" - "github.com/caos/zitadel/internal/eventstore/v1/spooler" - "github.com/caos/zitadel/internal/key/model" - "github.com/caos/zitadel/internal/key/repository/eventsourcing" - es_model "github.com/caos/zitadel/internal/key/repository/eventsourcing/model" - view_model "github.com/caos/zitadel/internal/key/repository/view/model" -) - -const ( - keyTable = "auth.keys" -) - -type Key struct { - handler - subscription *v1.Subscription - keyChan chan<- *model.KeyView -} - -func newKey(handler handler, keyChan chan<- *model.KeyView) *Key { - h := &Key{ - handler: handler, - keyChan: keyChan, - } - - h.subscribe() - - return h -} - -func (k *Key) subscribe() { - k.subscription = k.es.Subscribe(k.AggregateTypes()...) - go func() { - for event := range k.subscription.Events { - query.ReduceEvent(k, event) - } - }() -} - -func (k *Key) ViewModel() string { - return keyTable -} - -func (k *Key) Subscription() *v1.Subscription { - return k.subscription -} - -func (_ *Key) AggregateTypes() []models.AggregateType { - return []models.AggregateType{es_model.KeyPairAggregate} -} - -func (k *Key) CurrentSequence() (uint64, error) { - sequence, err := k.view.GetLatestKeySequence() - if err != nil { - return 0, err - } - return sequence.CurrentSequence, nil -} - -func (k *Key) EventQuery() (*models.SearchQuery, error) { - sequence, err := k.view.GetLatestKeySequence() - if err != nil { - return nil, err - } - return eventsourcing.KeyPairQuery(sequence.CurrentSequence), nil -} - -func (k *Key) Reduce(event *models.Event) error { - switch event.Type { - case es_model.KeyPairAdded: - privateKey, publicKey, err := view_model.KeysFromPairEvent(event) - if err != nil { - return err - } - if privateKey.Expiry.Before(time.Now()) && publicKey.Expiry.Before(time.Now()) { - return k.view.ProcessedKeySequence(event) - } - err = k.view.PutKeys(privateKey, publicKey, event) - if err != nil { - return err - } - k.keyChan <- view_model.KeyViewToModel(privateKey) - return nil - default: - return k.view.ProcessedKeySequence(event) - } -} - -func (k *Key) OnError(event *models.Event, err error) error { - logging.LogWithFields("SPOOL-GHa3a", "id", event.AggregateID).WithError(err).Warn("something went wrong in key handler") - return spooler.HandleError(event, err, k.view.GetLatestKeyFailedEvent, k.view.ProcessedKeyFailedEvent, k.view.ProcessedKeySequence, k.errorCountUntilSkip) -} - -func (k *Key) OnSuccess() error { - err := spooler.HandleSuccess(k.view.UpdateKeySpoolerRunTimestamp) - logging.LogWithFields("SPOOL-vM9sd", "table", keyTable).OnError(err).Warn("could not process on success func") - return err -} diff --git a/internal/auth/repository/eventsourcing/repository.go b/internal/auth/repository/eventsourcing/repository.go index 01f8946084..d7c93dd8a2 100644 --- a/internal/auth/repository/eventsourcing/repository.go +++ b/internal/auth/repository/eventsourcing/repository.go @@ -20,7 +20,6 @@ import ( v1 "github.com/caos/zitadel/internal/eventstore/v1" es_spol "github.com/caos/zitadel/internal/eventstore/v1/spooler" "github.com/caos/zitadel/internal/id" - key_model "github.com/caos/zitadel/internal/key/model" "github.com/caos/zitadel/internal/query" ) @@ -41,7 +40,6 @@ type EsRepository struct { eventstore.AuthRequestRepo eventstore.TokenRepo eventstore.RefreshTokenRepo - eventstore.KeyRepository eventstore.ApplicationRepo eventstore.UserSessionRepo eventstore.UserGrantRepo @@ -81,9 +79,7 @@ func Start(conf Config, authZ authz.Config, systemDefaults sd.SystemDefaults, co statikLoginFS, err := fs.NewWithNamespace("login") logging.Log("CONFI-20opp").OnError(err).Panic("unable to start login statik dir") - keyChan := make(chan *key_model.KeyView) - spool := spooler.StartSpooler(conf.Spooler, es, view, sqlClient, systemDefaults, keyChan) - locker := spooler.NewLocker(sqlClient) + spool := spooler.StartSpooler(conf.Spooler, es, view, sqlClient, systemDefaults) userRepo := eventstore.UserRepo{ SearchLimit: conf.SearchLimit, @@ -141,16 +137,6 @@ func Start(conf Config, authZ authz.Config, systemDefaults sd.SystemDefaults, co SearchLimit: conf.SearchLimit, KeyAlgorithm: keyAlgorithm, }, - eventstore.KeyRepository{ - View: view, - Commands: command, - Eventstore: esV2, - SigningKeyRotationCheck: systemDefaults.KeyConfig.SigningKeyRotationCheck.Duration, - SigningKeyGracefulPeriod: systemDefaults.KeyConfig.SigningKeyGracefulPeriod.Duration, - KeyAlgorithm: keyAlgorithm, - Locker: locker, - KeyChan: keyChan, - }, eventstore.ApplicationRepo{ Commands: command, Query: queries, diff --git a/internal/auth/repository/eventsourcing/spooler/spooler.go b/internal/auth/repository/eventsourcing/spooler/spooler.go index e2b842364c..b36ef4a5e7 100644 --- a/internal/auth/repository/eventsourcing/spooler/spooler.go +++ b/internal/auth/repository/eventsourcing/spooler/spooler.go @@ -2,13 +2,13 @@ package spooler import ( "database/sql" - "github.com/caos/zitadel/internal/eventstore/v1" + + v1 "github.com/caos/zitadel/internal/eventstore/v1" "github.com/caos/zitadel/internal/auth/repository/eventsourcing/handler" "github.com/caos/zitadel/internal/auth/repository/eventsourcing/view" sd "github.com/caos/zitadel/internal/config/systemdefaults" "github.com/caos/zitadel/internal/eventstore/v1/spooler" - key_model "github.com/caos/zitadel/internal/key/model" ) type SpoolerConfig struct { @@ -18,12 +18,12 @@ type SpoolerConfig struct { Handlers handler.Configs } -func StartSpooler(c SpoolerConfig, es v1.Eventstore, view *view.View, client *sql.DB, systemDefaults sd.SystemDefaults, keyChan chan<- *key_model.KeyView) *spooler.Spooler { +func StartSpooler(c SpoolerConfig, es v1.Eventstore, view *view.View, client *sql.DB, systemDefaults sd.SystemDefaults) *spooler.Spooler { spoolerConfig := spooler.Config{ Eventstore: es, Locker: &locker{dbClient: client}, ConcurrentWorkers: c.ConcurrentWorkers, - ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, systemDefaults, keyChan), + ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, systemDefaults), } spool := spoolerConfig.New() spool.Start() diff --git a/internal/auth/repository/key.go b/internal/auth/repository/key.go deleted file mode 100644 index ed8cef3a86..0000000000 --- a/internal/auth/repository/key.go +++ /dev/null @@ -1,12 +0,0 @@ -package repository - -import ( - "context" - - "gopkg.in/square/go-jose.v2" -) - -type KeyRepository interface { - GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey, algorithm string) - GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) -} diff --git a/internal/auth/repository/repository.go b/internal/auth/repository/repository.go index d8c3edd824..fee6594207 100644 --- a/internal/auth/repository/repository.go +++ b/internal/auth/repository/repository.go @@ -10,7 +10,6 @@ type Repository interface { AuthRequestRepository TokenRepository ApplicationRepository - KeyRepository UserSessionRepository UserGrantRepository OrgRepository diff --git a/internal/crypto/rsa.go b/internal/crypto/rsa.go index b8295fbb29..86b968a8f5 100644 --- a/internal/crypto/rsa.go +++ b/internal/crypto/rsa.go @@ -5,6 +5,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" ) func GenerateKeyPair(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) { @@ -64,18 +65,17 @@ func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) { return key, nil } +var ErrEmpty = errors.New("cannot decode, empty data") + func BytesToPublicKey(pub []byte) (*rsa.PublicKey, error) { - block, _ := pem.Decode(pub) - enc := x509.IsEncryptedPEMBlock(block) - b := block.Bytes - var err error - if enc { - b, err = x509.DecryptPEMBlock(block, nil) - if err != nil { - return nil, err - } + if pub == nil { + return nil, ErrEmpty } - ifc, err := x509.ParsePKIXPublicKey(b) + block, _ := pem.Decode(pub) + if block == nil { + return nil, ErrEmpty + } + ifc, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return nil, err } diff --git a/internal/eventstore/handler/crdb/handler_stmt.go b/internal/eventstore/handler/crdb/handler_stmt.go index 8aeb04ad0e..d73baed817 100644 --- a/internal/eventstore/handler/crdb/handler_stmt.go +++ b/internal/eventstore/handler/crdb/handler_stmt.go @@ -4,13 +4,12 @@ import ( "context" "database/sql" "fmt" - "os" "github.com/caos/logging" + "github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/eventstore" "github.com/caos/zitadel/internal/eventstore/handler" - "github.com/caos/zitadel/internal/id" ) var ( @@ -32,6 +31,7 @@ type StatementHandlerConfig struct { type StatementHandler struct { *handler.ProjectionHandler + Locker client *sql.DB sequenceTable string @@ -40,25 +40,17 @@ type StatementHandler struct { maxFailureCount uint failureCountStmt string setFailureCountStmt string - lockStmt string aggregates []eventstore.AggregateType reduces map[eventstore.EventType]handler.Reduce - workerName string - bulkLimit uint64 + bulkLimit uint64 } func NewStatementHandler( ctx context.Context, config StatementHandlerConfig, ) StatementHandler { - workerName, err := os.Hostname() - if err != nil || workerName == "" { - workerName, err = id.SonyFlakeGenerator.Next() - logging.Log("SPOOL-bdO56").OnError(err).Panic("unable to generate lockID") - } - aggregateTypes := make([]eventstore.AggregateType, 0, len(config.Reducers)) reduces := make(map[eventstore.EventType]handler.Reduce, len(config.Reducers)) for _, aggReducer := range config.Reducers { @@ -77,11 +69,10 @@ func NewStatementHandler( updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable), failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable), setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable), - lockStmt: fmt.Sprintf(lockStmtFormat, config.LockTable), aggregates: aggregateTypes, reduces: reduces, - workerName: workerName, bulkLimit: config.BulkLimit, + Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName), } go h.ProjectionHandler.Process( diff --git a/internal/eventstore/handler/crdb/projection_lock.go b/internal/eventstore/handler/crdb/lock.go similarity index 56% rename from internal/eventstore/handler/crdb/projection_lock.go rename to internal/eventstore/handler/crdb/lock.go index c95d8e8200..668b764989 100644 --- a/internal/eventstore/handler/crdb/projection_lock.go +++ b/internal/eventstore/handler/crdb/lock.go @@ -2,9 +2,15 @@ package crdb import ( "context" + "database/sql" + "fmt" + "os" "time" + "github.com/caos/logging" + "github.com/caos/zitadel/internal/errors" + "github.com/caos/zitadel/internal/id" ) const ( @@ -15,13 +21,39 @@ const ( " WHERE %[1]s.projection_name = $3 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())" ) -func (h *StatementHandler) Lock(ctx context.Context, lockDuration time.Duration) <-chan error { +type Locker interface { + Lock(ctx context.Context, lockDuration time.Duration) <-chan error + Unlock() error +} + +type locker struct { + client *sql.DB + lockStmt string + workerName string + projectionName string +} + +func NewLocker(client *sql.DB, lockTable, projectionName string) Locker { + workerName, err := os.Hostname() + if err != nil || workerName == "" { + workerName, err = id.SonyFlakeGenerator.Next() + logging.Log("CRDB-bdO56").OnError(err).Panic("unable to generate lockID") + } + return &locker{ + client: client, + lockStmt: fmt.Sprintf(lockStmtFormat, lockTable), + workerName: workerName, + projectionName: projectionName, + } +} + +func (h *locker) Lock(ctx context.Context, lockDuration time.Duration) <-chan error { errs := make(chan error) go h.handleLock(ctx, errs, lockDuration) return errs } -func (h *StatementHandler) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration) { +func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration) { renewLock := time.NewTimer(0) for { select { @@ -37,9 +69,9 @@ func (h *StatementHandler) handleLock(ctx context.Context, errs chan error, lock } } -func (h *StatementHandler) renewLock(ctx context.Context, lockDuration time.Duration) error { +func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) error { //the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html). - res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.ProjectionName) + res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName) if err != nil { return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock") } @@ -51,8 +83,8 @@ func (h *StatementHandler) renewLock(ctx context.Context, lockDuration time.Dura return nil } -func (h *StatementHandler) Unlock() error { - _, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.ProjectionName) +func (h *locker) Unlock() error { + _, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName) if err != nil { return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed") } diff --git a/internal/eventstore/handler/crdb/projection_lock_test.go b/internal/eventstore/handler/crdb/lock_test.go similarity index 87% rename from internal/eventstore/handler/crdb/projection_lock_test.go rename to internal/eventstore/handler/crdb/lock_test.go index 781acd22f4..ac7991d921 100644 --- a/internal/eventstore/handler/crdb/projection_lock_test.go +++ b/internal/eventstore/handler/crdb/lock_test.go @@ -8,10 +8,9 @@ import ( "testing" "time" - z_errs "github.com/caos/zitadel/internal/errors" - "github.com/caos/zitadel/internal/eventstore/handler" - "github.com/DATA-DOG/go-sqlmock" + + z_errs "github.com/caos/zitadel/internal/errors" ) const ( @@ -82,13 +81,11 @@ func TestStatementHandler_handleLock(t *testing.T) { if err != nil { t.Fatal(err) } - h := &StatementHandler{ - ProjectionHandler: &handler.ProjectionHandler{ - ProjectionName: projectionName, - }, - client: client, - workerName: workerName, - lockStmt: fmt.Sprintf(lockStmtFormat, lockTable), + h := &locker{ + projectionName: projectionName, + client: client, + workerName: workerName, + lockStmt: fmt.Sprintf(lockStmtFormat, lockTable), } for _, expectation := range tt.want.expectations { @@ -173,13 +170,11 @@ func TestStatementHandler_renewLock(t *testing.T) { if err != nil { t.Fatal(err) } - h := &StatementHandler{ - ProjectionHandler: &handler.ProjectionHandler{ - ProjectionName: projectionName, - }, - client: client, - workerName: workerName, - lockStmt: fmt.Sprintf(lockStmtFormat, lockTable), + h := &locker{ + projectionName: projectionName, + client: client, + workerName: workerName, + lockStmt: fmt.Sprintf(lockStmtFormat, lockTable), } for _, expectation := range tt.want.expectations { @@ -237,13 +232,11 @@ func TestStatementHandler_Unlock(t *testing.T) { if err != nil { t.Fatal(err) } - h := &StatementHandler{ - ProjectionHandler: &handler.ProjectionHandler{ - ProjectionName: projectionName, - }, - client: client, - workerName: workerName, - lockStmt: fmt.Sprintf(lockStmtFormat, lockTable), + h := &locker{ + projectionName: projectionName, + client: client, + workerName: workerName, + lockStmt: fmt.Sprintf(lockStmtFormat, lockTable), } for _, expectation := range tt.want.expectations { diff --git a/internal/query/key.go b/internal/query/key.go new file mode 100644 index 0000000000..ee0555ef16 --- /dev/null +++ b/internal/query/key.go @@ -0,0 +1,330 @@ +package query + +import ( + "context" + "crypto/rsa" + "database/sql" + "time" + + sq "github.com/Masterminds/squirrel" + + "github.com/caos/zitadel/internal/crypto" + "github.com/caos/zitadel/internal/domain" + "github.com/caos/zitadel/internal/errors" + "github.com/caos/zitadel/internal/query/projection" +) + +type Key interface { + ID() string + Algorithm() string + Use() domain.KeyUsage + Sequence() uint64 +} + +type PrivateKey interface { + Key + Expiry() time.Time + Key() *crypto.CryptoValue +} + +type PublicKey interface { + Key + Expiry() time.Time + Key() interface{} +} + +type PrivateKeys struct { + SearchResponse + Keys []PrivateKey +} + +type PublicKeys struct { + SearchResponse + Keys []PublicKey +} + +type key struct { + id string + creationDate time.Time + changeDate time.Time + sequence uint64 + resourceOwner string + algorithm string + use domain.KeyUsage +} + +func (k *key) ID() string { + return k.id +} + +func (k *key) Algorithm() string { + return k.algorithm +} + +func (k *key) Use() domain.KeyUsage { + return k.use +} + +func (k *key) Sequence() uint64 { + return k.sequence +} + +type privateKey struct { + key + expiry time.Time + privateKey *crypto.CryptoValue +} + +func (k *privateKey) Expiry() time.Time { + return k.expiry +} + +func (k *privateKey) Key() *crypto.CryptoValue { + return k.privateKey +} + +type rsaPublicKey struct { + key + expiry time.Time + publicKey *rsa.PublicKey +} + +func (r *rsaPublicKey) Expiry() time.Time { + return r.expiry +} + +func (r *rsaPublicKey) Key() interface{} { + return r.publicKey +} + +var ( + keyTable = table{ + name: projection.KeyProjectionTable, + } + KeyColID = Column{ + name: projection.KeyColumnID, + table: keyTable, + } + KeyColCreationDate = Column{ + name: projection.KeyColumnCreationDate, + table: keyTable, + } + KeyColChangeDate = Column{ + name: projection.KeyColumnChangeDate, + table: keyTable, + } + KeyColResourceOwner = Column{ + name: projection.KeyColumnResourceOwner, + table: keyTable, + } + KeyColSequence = Column{ + name: projection.KeyColumnSequence, + table: keyTable, + } + KeyColAlgorithm = Column{ + name: projection.KeyColumnAlgorithm, + table: keyTable, + } + KeyColUse = Column{ + name: projection.KeyColumnUse, + table: keyTable, + } +) + +var ( + keyPrivateTable = table{ + name: projection.KeyPrivateTable, + } + KeyPrivateColID = Column{ + name: projection.KeyPrivateColumnID, + table: keyPrivateTable, + } + KeyPrivateColExpiry = Column{ + name: projection.KeyPrivateColumnExpiry, + table: keyPrivateTable, + } + KeyPrivateColKey = Column{ + name: projection.KeyPrivateColumnKey, + table: keyPrivateTable, + } +) + +var ( + keyPublicTable = table{ + name: projection.KeyPublicTable, + } + KeyPublicColID = Column{ + name: projection.KeyPublicColumnID, + table: keyPublicTable, + } + KeyPublicColExpiry = Column{ + name: projection.KeyPublicColumnExpiry, + table: keyPublicTable, + } + KeyPublicColKey = Column{ + name: projection.KeyPublicColumnKey, + table: keyPublicTable, + } +) + +func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (*PublicKeys, error) { + query, scan := preparePublicKeysQuery() + if t.IsZero() { + t = time.Now() + } + stmt, args, err := query.Where( + sq.Gt{ + KeyPublicColExpiry.identifier(): t, + }).ToSql() + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-SDFfg", "Errors.Query.SQLStatement") + } + + rows, err := q.client.QueryContext(ctx, stmt, args...) + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-Sghn4", "Errors.Internal") + } + keys, err := scan(rows) + if err != nil { + return nil, err + } + keys.LatestSequence, err = q.latestSequence(ctx, keyTable) + return keys, err +} + +func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (*PrivateKeys, error) { + stmt, scan := preparePrivateKeysQuery() + if t.IsZero() { + t = time.Now() + } + query, args, err := stmt.Where( + sq.And{ + sq.Eq{ + KeyColUse.identifier(): domain.KeyUsageSigning, + }, + sq.Gt{ + KeyPrivateColExpiry.identifier(): t, + }, + }).OrderBy(KeyPrivateColExpiry.identifier()).ToSql() + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-SDff2", "Errors.Query.SQLStatement") + } + + rows, err := q.client.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-WRFG4", "Errors.Internal") + } + keys, err := scan(rows) + if err != nil { + return nil, err + } + keys.LatestSequence, err = q.latestSequence(ctx, keyTable) + return keys, err +} + +func preparePublicKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PublicKeys, error)) { + return sq.Select( + KeyColID.identifier(), + KeyColCreationDate.identifier(), + KeyColChangeDate.identifier(), + KeyColSequence.identifier(), + KeyColResourceOwner.identifier(), + KeyColAlgorithm.identifier(), + KeyColUse.identifier(), + KeyPublicColExpiry.identifier(), + KeyPublicColKey.identifier(), + countColumn.identifier(), + ).From(keyTable.identifier()). + LeftJoin(join(KeyPublicColID, KeyColID)). + PlaceholderFormat(sq.Dollar), + func(rows *sql.Rows) (*PublicKeys, error) { + keys := make([]PublicKey, 0) + var count uint64 + for rows.Next() { + k := new(rsaPublicKey) + var keyValue []byte + err := rows.Scan( + &k.id, + &k.creationDate, + &k.changeDate, + &k.sequence, + &k.resourceOwner, + &k.algorithm, + &k.use, + &k.expiry, + &keyValue, + &count, + ) + if err != nil { + return nil, err + } + k.publicKey, err = crypto.BytesToPublicKey(keyValue) + if err != nil { + return nil, err + } + keys = append(keys, k) + } + + if err := rows.Close(); err != nil { + return nil, errors.ThrowInternal(err, "QUERY-rKd6k", "Errors.Query.CloseRows") + } + + return &PublicKeys{ + Keys: keys, + SearchResponse: SearchResponse{ + Count: count, + }, + }, nil + } +} + +func preparePrivateKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PrivateKeys, error)) { + return sq.Select( + KeyColID.identifier(), + KeyColCreationDate.identifier(), + KeyColChangeDate.identifier(), + KeyColSequence.identifier(), + KeyColResourceOwner.identifier(), + KeyColAlgorithm.identifier(), + KeyColUse.identifier(), + KeyPrivateColExpiry.identifier(), + KeyPrivateColKey.identifier(), + countColumn.identifier(), + ).From(keyTable.identifier()). + LeftJoin(join(KeyPrivateColID, KeyColID)). + PlaceholderFormat(sq.Dollar), + func(rows *sql.Rows) (*PrivateKeys, error) { + keys := make([]PrivateKey, 0) + var count uint64 + for rows.Next() { + k := new(privateKey) + err := rows.Scan( + &k.id, + &k.creationDate, + &k.changeDate, + &k.sequence, + &k.resourceOwner, + &k.algorithm, + &k.use, + &k.expiry, + &k.privateKey, + &count, + ) + if err != nil { + return nil, err + } + keys = append(keys, k) + } + + if err := rows.Close(); err != nil { + return nil, errors.ThrowInternal(err, "QUERY-rKd6k", "Errors.Query.CloseRows") + } + + return &PrivateKeys{ + Keys: keys, + SearchResponse: SearchResponse{ + Count: count, + }, + }, nil + } +} diff --git a/internal/query/key_test.go b/internal/query/key_test.go new file mode 100644 index 0000000000..6d3c142bb3 --- /dev/null +++ b/internal/query/key_test.go @@ -0,0 +1,295 @@ +package query + +import ( + "crypto/rsa" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "math/big" + "regexp" + "testing" + + "github.com/caos/zitadel/internal/crypto" + "github.com/caos/zitadel/internal/domain" + errs "github.com/caos/zitadel/internal/errors" +) + +func Test_KeyPrepares(t *testing.T) { + type want struct { + sqlExpectations sqlExpectation + err checkErr + } + tests := []struct { + name string + prepare interface{} + want want + object interface{} + }{ + { + name: "preparePublicKeysQuery no result", + prepare: preparePublicKeysQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.keys.id,`+ + ` zitadel.projections.keys.creation_date,`+ + ` zitadel.projections.keys.change_date,`+ + ` zitadel.projections.keys.sequence,`+ + ` zitadel.projections.keys.resource_owner,`+ + ` zitadel.projections.keys.algorithm,`+ + ` zitadel.projections.keys.use,`+ + ` zitadel.projections.keys_public.expiry,`+ + ` zitadel.projections.keys_public.key,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.keys`+ + ` LEFT JOIN zitadel.projections.keys_public ON zitadel.projections.keys.id = zitadel.projections.keys_public.id`), + nil, + nil, + ), + err: func(err error) (error, bool) { + if !errs.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, + }, + object: &PublicKeys{Keys: []PublicKey{}}, + }, + { + name: "preparePublicKeysQuery found", + prepare: preparePublicKeysQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.keys.id,`+ + ` zitadel.projections.keys.creation_date,`+ + ` zitadel.projections.keys.change_date,`+ + ` zitadel.projections.keys.sequence,`+ + ` zitadel.projections.keys.resource_owner,`+ + ` zitadel.projections.keys.algorithm,`+ + ` zitadel.projections.keys.use,`+ + ` zitadel.projections.keys_public.expiry,`+ + ` zitadel.projections.keys_public.key,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.keys`+ + ` LEFT JOIN zitadel.projections.keys_public ON zitadel.projections.keys.id = zitadel.projections.keys_public.id`), + []string{ + "id", + "creation_date", + "change_date", + "sequence", + "resource_owner", + "algorithm", + "use", + "expiry", + "key", + "count", + }, + [][]driver.Value{ + { + "key-id", + testNow, + testNow, + uint64(20211109), + "ro", + "RS256", + 0, + testNow, + []byte("-----BEGIN RSA PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsvX9P58JFxEs5C+L+H7W\nduFSWL5EPzber7C2m94klrSV6q0bAcrYQnGwFOlveThsY200hRbadKaKjHD7qIKH\nDEe0IY2PSRht33Jye52AwhkRw+M3xuQH/7R8LydnsNFk2KHpr5X2SBv42e37LjkE\nslKSaMRgJW+v0KZ30piY8QsdFRKKaVg5/Ajt1YToM1YVsdHXJ3vmXFMtypLdxwUD\ndIaLEX6pFUkU75KSuEQ/E2luT61Q3ta9kOWm9+0zvi7OMcbdekJT7mzcVnh93R1c\n13ZhQCLbh9A7si8jKFtaMWevjayrvqQABEcTN9N4Hoxcyg6l4neZtRDk75OMYcqm\nDQIDAQAB\n-----END RSA PUBLIC KEY-----\n"), + }, + }, + ), + }, + object: &PublicKeys{ + SearchResponse: SearchResponse{ + Count: 1, + }, + Keys: []PublicKey{ + &rsaPublicKey{ + key: key{ + id: "key-id", + creationDate: testNow, + changeDate: testNow, + sequence: 20211109, + resourceOwner: "ro", + algorithm: "RS256", + use: domain.KeyUsageSigning, + }, + expiry: testNow, + publicKey: &rsa.PublicKey{ + E: 65537, + N: fromBase16("b2f5fd3f9f0917112ce42f8bf87ed676e15258be443f36deafb0b69bde2496b495eaad1b01cad84271b014e96f79386c636d348516da74a68a8c70fba882870c47b4218d8f49186ddf72727b9d80c21911c3e337c6e407ffb47c2f2767b0d164d8a1e9af95f6481bf8d9edfb2e3904b2529268c460256fafd0a677d29898f10b1d15128a695839fc08edd584e8335615b1d1d7277be65c532dca92ddc7050374868b117ea9154914ef9292b8443f13696e4fad50ded6bd90e5a6f7ed33be2ece31c6dd7a4253ee6cdc56787ddd1d5cd776614022db87d03bb22f23285b5a3167af8dacabbea40004471337d3781e8c5cca0ea5e27799b510e4ef938c61caa60d"), + }, + }, + }, + }, + }, + { + name: "preparePublicKeysQuery sql err", + prepare: preparePublicKeysQuery, + want: want{ + sqlExpectations: mockQueryErr( + regexp.QuoteMeta(`SELECT zitadel.projections.keys.id,`+ + ` zitadel.projections.keys.creation_date,`+ + ` zitadel.projections.keys.change_date,`+ + ` zitadel.projections.keys.sequence,`+ + ` zitadel.projections.keys.resource_owner,`+ + ` zitadel.projections.keys.algorithm,`+ + ` zitadel.projections.keys.use,`+ + ` zitadel.projections.keys_public.expiry,`+ + ` zitadel.projections.keys_public.key,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.keys`+ + ` LEFT JOIN zitadel.projections.keys_public ON zitadel.projections.keys.id = zitadel.projections.keys_public.id`), + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: nil, + }, + { + name: "preparePrivateKeysQuery no result", + prepare: preparePrivateKeysQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.keys.id,`+ + ` zitadel.projections.keys.creation_date,`+ + ` zitadel.projections.keys.change_date,`+ + ` zitadel.projections.keys.sequence,`+ + ` zitadel.projections.keys.resource_owner,`+ + ` zitadel.projections.keys.algorithm,`+ + ` zitadel.projections.keys.use,`+ + ` zitadel.projections.keys_private.expiry,`+ + ` zitadel.projections.keys_private.key,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.keys`+ + ` LEFT JOIN zitadel.projections.keys_private ON zitadel.projections.keys.id = zitadel.projections.keys_private.id`), + nil, + nil, + ), + err: func(err error) (error, bool) { + if !errs.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, + }, + object: &PrivateKeys{Keys: []PrivateKey{}}, + }, + { + name: "preparePrivateKeysQuery found", + prepare: preparePrivateKeysQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.keys.id,`+ + ` zitadel.projections.keys.creation_date,`+ + ` zitadel.projections.keys.change_date,`+ + ` zitadel.projections.keys.sequence,`+ + ` zitadel.projections.keys.resource_owner,`+ + ` zitadel.projections.keys.algorithm,`+ + ` zitadel.projections.keys.use,`+ + ` zitadel.projections.keys_private.expiry,`+ + ` zitadel.projections.keys_private.key,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.keys`+ + ` LEFT JOIN zitadel.projections.keys_private ON zitadel.projections.keys.id = zitadel.projections.keys_private.id`), + []string{ + "id", + "creation_date", + "change_date", + "sequence", + "resource_owner", + "algorithm", + "use", + "expiry", + "key", + "count", + }, + [][]driver.Value{ + { + "key-id", + testNow, + testNow, + uint64(20211109), + "ro", + "RS256", + 0, + testNow, + []byte(`{"Algorithm": "enc", "Crypted": "cHJpdmF0ZUtleQ==", "CryptoType": 0, "KeyID": "id"}`), + }, + }, + ), + }, + object: &PrivateKeys{ + SearchResponse: SearchResponse{ + Count: 1, + }, + Keys: []PrivateKey{ + &privateKey{ + key: key{ + id: "key-id", + creationDate: testNow, + changeDate: testNow, + sequence: 20211109, + resourceOwner: "ro", + algorithm: "RS256", + use: domain.KeyUsageSigning, + }, + expiry: testNow, + privateKey: &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("privateKey"), + }, + }, + }, + }, + }, + { + name: "preparePrivateKeysQuery sql err", + prepare: preparePrivateKeysQuery, + want: want{ + sqlExpectations: mockQueryErr( + regexp.QuoteMeta(`SELECT zitadel.projections.keys.id,`+ + ` zitadel.projections.keys.creation_date,`+ + ` zitadel.projections.keys.change_date,`+ + ` zitadel.projections.keys.sequence,`+ + ` zitadel.projections.keys.resource_owner,`+ + ` zitadel.projections.keys.algorithm,`+ + ` zitadel.projections.keys.use,`+ + ` zitadel.projections.keys_private.expiry,`+ + ` zitadel.projections.keys_private.key,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.keys`+ + ` LEFT JOIN zitadel.projections.keys_private ON zitadel.projections.keys.id = zitadel.projections.keys_private.id`), + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) + }) + } +} + +func fromBase16(base16 string) *big.Int { + i, ok := new(big.Int).SetString(base16, 16) + if !ok { + panic("bad number: " + base16) + } + return i +} diff --git a/internal/query/prepare_test.go b/internal/query/prepare_test.go index 7b820089e2..439712662c 100644 --- a/internal/query/prepare_test.go +++ b/internal/query/prepare_test.go @@ -3,7 +3,6 @@ package query import ( "database/sql" "database/sql/driver" - "encoding/json" "errors" "fmt" "log" @@ -13,7 +12,7 @@ import ( "github.com/DATA-DOG/go-sqlmock" sq "github.com/Masterminds/squirrel" - "github.com/nsf/jsondiff" + "github.com/stretchr/testify/assert" ) var ( @@ -58,8 +57,7 @@ func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExp return false } - if !reflect.DeepEqual(object, expectedObject) { - prettyPrintDiff(t, expectedObject, object) + if !assert.Equal(t, expectedObject, object) { return false } @@ -315,19 +313,3 @@ func TestValidatePrepare(t *testing.T) { }) } } - -func prettyPrintDiff(t *testing.T, expected, gotten interface{}) { - t.Helper() - - expectedMarshalled, _ := json.Marshal(expected) - objectMarshalled, _ := json.Marshal(gotten) - _, diff := jsondiff.Compare( - expectedMarshalled, - objectMarshalled, - &jsondiff.Options{ - SkipMatches: true, - Indent: " ", - ChangedSeparator: " is expected, got ", - }) - t.Errorf("unexpected object: want %T, got %T, difference:\n%s", expected, gotten, diff) -} diff --git a/internal/query/projection/key.go b/internal/query/projection/key.go index 308811231a..190f96dba6 100644 --- a/internal/query/projection/key.go +++ b/internal/query/projection/key.go @@ -18,6 +18,7 @@ import ( type KeyProjection struct { crdb.StatementHandler encryptionAlgorithm crypto.EncryptionAlgorithm + keyChan chan<- interface{} } const ( @@ -26,11 +27,12 @@ const ( KeyPublicTable = KeyProjectionTable + "_" + publicKeyTableSuffix ) -func NewKeyProjection(ctx context.Context, config crdb.StatementHandlerConfig, keyConfig systemdefaults.KeyConfig) (_ *KeyProjection, err error) { +func NewKeyProjection(ctx context.Context, config crdb.StatementHandlerConfig, keyConfig systemdefaults.KeyConfig, keyChan chan<- interface{}) (_ *KeyProjection, err error) { p := &KeyProjection{} config.ProjectionName = KeyProjectionTable config.Reducers = p.reducers() p.StatementHandler = crdb.NewStatementHandler(ctx, config) + p.keyChan = keyChan p.encryptionAlgorithm, err = crypto.NewAESCrypto(keyConfig.EncryptionConfig) if err != nil { return nil, err @@ -103,6 +105,9 @@ func (p *KeyProjection) reduceKeyPairAdded(event eventstore.Event) (*handler.Sta }, crdb.WithTableSuffix(privateKeyTableSuffix), )) + if p.keyChan != nil { + p.keyChan <- true + } } if e.PublicKey.Expiry.After(time.Now()) { publicKey, err := crypto.Decrypt(e.PublicKey.Key, p.encryptionAlgorithm) diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index bcbc878f1d..abb9600ce5 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -17,7 +17,7 @@ const ( FailedEventsTable = "projections.failed_events" ) -func Start(ctx context.Context, sqlClient *sql.DB, es *eventstore.Eventstore, config Config, defaults systemdefaults.SystemDefaults) error { +func Start(ctx context.Context, sqlClient *sql.DB, es *eventstore.Eventstore, config Config, defaults systemdefaults.SystemDefaults, keyChan chan<- interface{}) error { projectionConfig := crdb.StatementHandlerConfig{ ProjectionHandlerConfig: handler.ProjectionHandlerConfig{ HandlerConfig: handler.HandlerConfig{ @@ -63,8 +63,8 @@ func Start(ctx context.Context, sqlClient *sql.DB, es *eventstore.Eventstore, co NewIAMMemberProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["iam_members"])) NewProjectMemberProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_members"])) NewProjectGrantMemberProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_grant_members"])) - _, err := NewKeyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["keys"]), defaults.KeyConfig) NewAuthNKeyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["authn_keys"])) + _, err := NewKeyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["keys"]), defaults.KeyConfig, keyChan) return err } diff --git a/internal/query/query.go b/internal/query/query.go index e5c5d9dd7e..2f5685a213 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -40,7 +40,7 @@ type Config struct { Eventstore types.SQLUser } -func StartQueries(ctx context.Context, es *eventstore.Eventstore, projections projection.Config, defaults sd.SystemDefaults) (repo *Queries, err error) { +func StartQueries(ctx context.Context, es *eventstore.Eventstore, projections projection.Config, defaults sd.SystemDefaults, keyChan chan<- interface{}) (repo *Queries, err error) { sqlClient, err := projections.CRDB.Start() if err != nil { return nil, err @@ -69,7 +69,7 @@ func StartQueries(ctx context.Context, es *eventstore.Eventstore, projections pr action.RegisterEventMappers(repo.eventstore) keypair.RegisterEventMappers(repo.eventstore) - err = projection.Start(ctx, sqlClient, es, projections, defaults) + err = projection.Start(ctx, sqlClient, es, projections, defaults, keyChan) if err != nil { return nil, err }