From c8c5cf3c5f825ce9d619cd7914c74117f5023d11 Mon Sep 17 00:00:00 2001 From: Silvan Date: Fri, 28 Apr 2023 13:55:35 +0200 Subject: [PATCH 1/5] feat(cli): add `setup cleanup` sub command (#5770) * feat(cli): add `setup cleanup` sub command * chore: logging * chore: logging --- cmd/setup/cleanup.go | 51 +++++++++++++++++++++++++++++++ cmd/setup/setup.go | 2 ++ internal/migration/command.go | 10 +++--- internal/migration/migration.go | 54 +++++++++++++++++++++++++++++---- 4 files changed, 106 insertions(+), 11 deletions(-) create mode 100644 cmd/setup/cleanup.go diff --git a/cmd/setup/cleanup.go b/cmd/setup/cleanup.go new file mode 100644 index 0000000000..7139b67d35 --- /dev/null +++ b/cmd/setup/cleanup.go @@ -0,0 +1,51 @@ +package setup + +import ( + "context" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/migration" +) + +func NewCleanup() *cobra.Command { + return &cobra.Command{ + Use: "cleanup", + Short: "cleans up migration if they got stuck", + Long: `cleans up migration if they got stuck`, + Run: func(cmd *cobra.Command, args []string) { + config := MustNewConfig(viper.GetViper()) + Cleanup(config) + }, + } +} + +func Cleanup(config *Config) { + ctx := context.Background() + + logging.Info("cleanup started") + + dbClient, err := database.Connect(config.Database, false) + logging.OnError(err).Fatal("unable to connect to database") + + es, err := eventstore.Start(&eventstore.Config{Client: dbClient}) + logging.OnError(err).Fatal("unable to start eventstore") + migration.RegisterMappers(es) + + step, err := migration.LatestStep(ctx, es) + logging.OnError(err).Fatal("unable to query latest migration") + + if step.BaseEvent.EventType != migration.StartedType { + logging.Info("there is no stuck migration please run `zitadel setup`") + return + } + + logging.WithFields("name", step.Name).Info("cleanup migration") + + err = migration.CancelStep(ctx, es, step) + logging.OnError(err).Fatal("cleanup migration failed please retry") +} diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index e90ad85f94..918f08dc2d 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -45,6 +45,8 @@ Requirements: }, } + cmd.AddCommand(NewCleanup()) + Flags(cmd) return cmd diff --git a/internal/migration/command.go b/internal/migration/command.go index 9d04e3cbe7..6552ef6ed4 100644 --- a/internal/migration/command.go +++ b/internal/migration/command.go @@ -41,14 +41,14 @@ func setupStartedCmd(migration Migration) eventstore.Command { BaseEvent: *eventstore.NewBaseEventForPush( ctx, eventstore.NewAggregate(ctx, aggregateID, aggregateType, "v1"), - startedType), + StartedType), migration: migration, Name: migration.String(), } } -func setupDoneCmd(migration Migration, err error) eventstore.Command { - ctx := authz.SetCtxData(service.WithService(context.Background(), "system"), authz.CtxData{UserID: "system", OrgID: "SYSTEM", ResourceOwner: "SYSTEM"}) +func setupDoneCmd(ctx context.Context, migration Migration, err error) eventstore.Command { + ctx = authz.SetCtxData(service.WithService(ctx, "system"), authz.CtxData{UserID: "system", OrgID: "SYSTEM", ResourceOwner: "SYSTEM"}) typ := doneType var lastRun interface{} if repeatable, ok := migration.(RepeatableMigration); ok { @@ -80,7 +80,7 @@ func (s *SetupStep) Data() interface{} { func (s *SetupStep) UniqueConstraints() []*eventstore.EventUniqueConstraint { switch s.Type() { - case startedType: + case StartedType: return []*eventstore.EventUniqueConstraint{ eventstore.NewAddGlobalEventUniqueConstraint("migration_started", s.migration.String(), "Errors.Step.Started.AlreadyExists"), } @@ -97,7 +97,7 @@ func (s *SetupStep) UniqueConstraints() []*eventstore.EventUniqueConstraint { } func RegisterMappers(es *eventstore.Eventstore) { - es.RegisterFilterEventMapper(aggregateType, startedType, SetupMapper) + es.RegisterFilterEventMapper(aggregateType, StartedType, SetupMapper) es.RegisterFilterEventMapper(aggregateType, doneType, SetupMapper) es.RegisterFilterEventMapper(aggregateType, failedType, SetupMapper) es.RegisterFilterEventMapper(aggregateType, repeatableDoneType, SetupMapper) diff --git a/internal/migration/migration.go b/internal/migration/migration.go index 3608332a8a..63a6cb7b7a 100644 --- a/internal/migration/migration.go +++ b/internal/migration/migration.go @@ -12,7 +12,7 @@ import ( ) const ( - startedType = eventstore.EventType("system.migration.started") + StartedType = eventstore.EventType("system.migration.started") doneType = eventstore.EventType("system.migration.done") failedType = eventstore.EventType("system.migration.failed") repeatableDoneType = eventstore.EventType("system.migration.repeatable.done") @@ -36,7 +36,7 @@ type RepeatableMigration interface { } func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration) (err error) { - logging.Infof("verify migration %s", migration.String()) + logging.WithFields("name", migration.String()).Info("verify migration") if should, err := checkExec(ctx, es, migration); !should || err != nil { return err @@ -46,11 +46,11 @@ func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration return err } - logging.Infof("starting migration %s", migration.String()) + logging.WithFields("name", migration.String()).Info("starting migration") err = migration.Execute(ctx) logging.OnError(err).Error("migration failed") - _, pushErr := es.Push(ctx, setupDoneCmd(migration, err)) + _, pushErr := es.Push(ctx, setupDoneCmd(ctx, migration, err)) logging.OnError(pushErr).Error("migration failed") if err != nil { return err @@ -58,6 +58,48 @@ func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration return pushErr } +func LatestStep(ctx context.Context, es *eventstore.Eventstore) (*SetupStep, error) { + events, err := es.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + OrderDesc(). + Limit(1). + AddQuery(). + AggregateTypes(aggregateType). + AggregateIDs(aggregateID). + EventTypes(StartedType, doneType, repeatableDoneType, failedType). + Builder()) + if err != nil { + return nil, err + } + step, ok := events[0].(*SetupStep) + if !ok { + return nil, errors.ThrowInternal(nil, "MIGRA-hppLM", "setup step is malformed") + } + return step, nil +} + +var _ Migration = (*cancelMigration)(nil) + +type cancelMigration struct { + name string +} + +// Execute implements Migration +func (*cancelMigration) Execute(context.Context) error { + return nil +} + +// String implements Migration +func (m *cancelMigration) String() string { + return m.name +} + +var errCancelStep = errors.ThrowError(nil, "MIGRA-zo86K", "migration canceled manually") + +func CancelStep(ctx context.Context, es *eventstore.Eventstore, step *SetupStep) error { + _, err := es.Push(ctx, setupDoneCmd(ctx, &cancelMigration{name: step.Name}, errCancelStep)) + return err +} + // checkExec ensures that only one setup step is done concurrently // if a setup step is already started, it calls shouldExec after some time again func checkExec(ctx context.Context, es *eventstore.Eventstore, migration Migration) (bool, error) { @@ -88,7 +130,7 @@ func shouldExec(ctx context.Context, es *eventstore.Eventstore, migration Migrat AddQuery(). AggregateTypes(aggregateType). AggregateIDs(aggregateID). - EventTypes(startedType, doneType, repeatableDoneType, failedType). + EventTypes(StartedType, doneType, repeatableDoneType, failedType). Builder()) if err != nil { return false, err @@ -106,7 +148,7 @@ func shouldExec(ctx context.Context, es *eventstore.Eventstore, migration Migrat } switch event.Type() { - case startedType, failedType: + case StartedType, failedType: isStarted = !isStarted case doneType, repeatableDoneType: From 458a383de203cac4bd0013f419e46f0420f30921 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Fri, 28 Apr 2023 16:28:13 +0200 Subject: [PATCH 2/5] fix: use current sequence for refetching of events (#5772) * fix: use current sequence for refetching of events * fix: use client ids --- .../eventsourcing/handler/styling.go | 8 ++-- .../repository/eventsourcing/view/sequence.go | 9 ++-- .../repository/eventsourcing/view/styling.go | 10 ++-- .../repository/eventsourcing/view/view.go | 14 +++++- .../eventsourcing/eventstore/auth_request.go | 29 ++++++++---- .../eventstore/auth_request_test.go | 13 ++++++ .../eventsourcing/eventstore/refresh_token.go | 15 ++++-- .../eventsourcing/eventstore/token.go | 14 +++++- .../eventsourcing/handler/refresh_token.go | 8 ++-- .../repository/eventsourcing/handler/token.go | 20 ++++---- .../repository/eventsourcing/handler/user.go | 22 ++++++--- .../eventsourcing/handler/user_session.go | 12 ++--- .../eventsourcing/view/refresh_token.go | 10 ++-- .../repository/eventsourcing/view/sequence.go | 9 ++-- .../repository/eventsourcing/view/token.go | 10 ++-- .../repository/eventsourcing/view/user.go | 46 +++++++++++-------- .../eventsourcing/view/user_session.go | 10 ++-- .../repository/eventsourcing/view/view.go | 9 ++++ .../eventstore/token_verifier.go | 10 ++-- .../repository/eventsourcing/view/sequence.go | 6 ++- .../repository/eventsourcing/view/token.go | 6 ++- .../repository/eventsourcing/view/view.go | 13 +++++- internal/eventstore/v1/query/handler.go | 8 ++-- internal/eventstore/v1/spooler/spooler.go | 2 +- .../eventstore/v1/spooler/spooler_test.go | 4 +- internal/org/repository/view/query.go | 14 ++++++ .../repository/eventsourcing/model/project.go | 40 ++++++++++++++++ internal/project/repository/view/query.go | 9 ++++ 28 files changed, 273 insertions(+), 107 deletions(-) diff --git a/internal/admin/repository/eventsourcing/handler/styling.go b/internal/admin/repository/eventsourcing/handler/styling.go index 53a7f89812..8d5107e788 100644 --- a/internal/admin/repository/eventsourcing/handler/styling.go +++ b/internal/admin/repository/eventsourcing/handler/styling.go @@ -65,16 +65,16 @@ func (_ *Styling) AggregateTypes() []models.AggregateType { return []models.AggregateType{org.AggregateType, instance.AggregateType} } -func (m *Styling) CurrentSequence(instanceID string) (uint64, error) { - sequence, err := m.view.GetLatestStylingSequence(instanceID) +func (m *Styling) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) { + sequence, err := m.view.GetLatestStylingSequence(ctx, instanceID) if err != nil { return 0, err } return sequence.CurrentSequence, nil } -func (m *Styling) EventQuery(instanceIDs []string) (*models.SearchQuery, error) { - sequences, err := m.view.GetLatestStylingSequences(instanceIDs) +func (m *Styling) EventQuery(ctx context.Context, instanceIDs []string) (*models.SearchQuery, error) { + sequences, err := m.view.GetLatestStylingSequences(ctx, instanceIDs) if err != nil { return nil, err } diff --git a/internal/admin/repository/eventsourcing/view/sequence.go b/internal/admin/repository/eventsourcing/view/sequence.go index 4985fbef48..8de577f89e 100644 --- a/internal/admin/repository/eventsourcing/view/sequence.go +++ b/internal/admin/repository/eventsourcing/view/sequence.go @@ -1,6 +1,7 @@ package view import ( + "context" "time" "github.com/zitadel/zitadel/internal/eventstore/v1/models" @@ -15,12 +16,12 @@ func (v *View) saveCurrentSequence(viewName string, event *models.Event) error { return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate) } -func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) { - return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID) +func (v *View) latestSequence(ctx context.Context, viewName, instanceID string) (*repository.CurrentSequence, error) { + return repository.LatestSequence(v.Db, v.TimeTravel(ctx, sequencesTable), viewName, instanceID) } -func (v *View) latestSequences(viewName string, instanceIDs []string) ([]*repository.CurrentSequence, error) { - return repository.LatestSequences(v.Db, sequencesTable, viewName, instanceIDs) +func (v *View) latestSequences(ctx context.Context, viewName string, instanceIDs []string) ([]*repository.CurrentSequence, error) { + return repository.LatestSequences(v.Db, v.TimeTravel(ctx, sequencesTable), viewName, instanceIDs) } func (v *View) AllCurrentSequences(db, instanceID string) ([]*repository.CurrentSequence, error) { diff --git a/internal/admin/repository/eventsourcing/view/styling.go b/internal/admin/repository/eventsourcing/view/styling.go index 25ac4c882a..ac477b5d96 100644 --- a/internal/admin/repository/eventsourcing/view/styling.go +++ b/internal/admin/repository/eventsourcing/view/styling.go @@ -1,6 +1,8 @@ package view import ( + "context" + "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/iam/repository/view" "github.com/zitadel/zitadel/internal/iam/repository/view/model" @@ -39,12 +41,12 @@ func (v *View) UpdateOrgOwnerRemovedStyling(event *models.Event) error { return v.ProcessedStylingSequence(event) } -func (v *View) GetLatestStylingSequence(instanceID string) (*global_view.CurrentSequence, error) { - return v.latestSequence(stylingTyble, instanceID) +func (v *View) GetLatestStylingSequence(ctx context.Context, instanceID string) (*global_view.CurrentSequence, error) { + return v.latestSequence(ctx, stylingTyble, instanceID) } -func (v *View) GetLatestStylingSequences(instanceIDs []string) ([]*global_view.CurrentSequence, error) { - return v.latestSequences(stylingTyble, instanceIDs) +func (v *View) GetLatestStylingSequences(ctx context.Context, instanceIDs []string) ([]*global_view.CurrentSequence, error) { + return v.latestSequences(ctx, stylingTyble, instanceIDs) } func (v *View) ProcessedStylingSequence(event *models.Event) error { diff --git a/internal/admin/repository/eventsourcing/view/view.go b/internal/admin/repository/eventsourcing/view/view.go index 8f27706985..095e7c1dfa 100644 --- a/internal/admin/repository/eventsourcing/view/view.go +++ b/internal/admin/repository/eventsourcing/view/view.go @@ -1,12 +1,17 @@ package view import ( + "context" + "github.com/jinzhu/gorm" + + "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" ) type View struct { - Db *gorm.DB + Db *gorm.DB + client *database.DB } func StartView(sqlClient *database.DB) (*View, error) { @@ -15,10 +20,15 @@ func StartView(sqlClient *database.DB) (*View, error) { return nil, err } return &View{ - Db: gorm, + Db: gorm, + client: sqlClient, }, nil } func (v *View) Health() (err error) { return v.Db.DB().Ping() } + +func (v *View) TimeTravel(ctx context.Context, tableName string) string { + return tableName + v.client.Timetravel(call.Took(ctx)) +} diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 84f884620b..fcabd8f77f 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -23,6 +23,7 @@ import ( "github.com/zitadel/zitadel/internal/telemetry/tracing" user_model "github.com/zitadel/zitadel/internal/user/model" user_view_model "github.com/zitadel/zitadel/internal/user/repository/view/model" + "github.com/zitadel/zitadel/internal/view/repository" ) const unknownUserID = "UNKNOWN" @@ -64,7 +65,9 @@ type privacyPolicyProvider interface { type userSessionViewProvider interface { UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error) + GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) } + type userViewProvider interface { UserByID(string, string) (*user_view_model.UserView, error) } @@ -654,7 +657,7 @@ func (repo *AuthRequestRepo) checkLoginName(ctx context.Context, request *domain preferredLoginName += "@" + request.RequestedPrimaryDomain } } - user, err = repo.checkLoginNameInputForResourceOwner(request, preferredLoginName) + user, err = repo.checkLoginNameInputForResourceOwner(ctx, request, preferredLoginName) } else { user, err = repo.checkLoginNameInput(ctx, request, preferredLoginName) } @@ -729,12 +732,12 @@ func (repo *AuthRequestRepo) checkDomainDiscovery(ctx context.Context, request * func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) { // always check the loginname first - user, err := repo.View.UserByLoginName(loginNameInput, request.InstanceID) + user, err := repo.View.UserByLoginName(ctx, loginNameInput, request.InstanceID) if err == nil { // and take the user regardless if there would be a user with that email or phone return user, repo.checkLoginPolicyWithResourceOwner(ctx, request, user.ResourceOwner) } - user, emailErr := repo.View.UserByEmail(loginNameInput, request.InstanceID) + user, emailErr := repo.View.UserByEmail(ctx, loginNameInput, request.InstanceID) if emailErr == nil { // if there was a single user with the specified email // load and check the login policy @@ -747,7 +750,7 @@ func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *d return user, nil } } - user, phoneErr := repo.View.UserByPhone(loginNameInput, request.InstanceID) + user, phoneErr := repo.View.UserByPhone(ctx, loginNameInput, request.InstanceID) if phoneErr == nil { // if there was a single user with the specified phone // load and check the login policy @@ -765,9 +768,9 @@ func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *d return nil, err } -func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) { +func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(ctx context.Context, request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) { // always check the loginname first - user, err := repo.View.UserByLoginNameAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID) + user, err := repo.View.UserByLoginNameAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID) if err == nil { // and take the user regardless if there would be a user with that email or phone return user, nil @@ -775,7 +778,7 @@ func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(request *domain if request.LoginPolicy != nil && !request.LoginPolicy.DisableLoginWithEmail { // if login by email is allowed and there was a single user with the specified email // take that user (and ignore possible phone number matches) - user, emailErr := repo.View.UserByEmailAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID) + user, emailErr := repo.View.UserByEmailAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID) if emailErr == nil { return user, nil } @@ -783,7 +786,7 @@ func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(request *domain if request.LoginPolicy != nil && !request.LoginPolicy.DisableLoginWithPhone { // if login by phone is allowed and there was a single user with the specified phone // take that user - user, phoneErr := repo.View.UserByPhoneAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID) + user, phoneErr := repo.View.UserByPhoneAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID) if phoneErr == nil { return user, nil } @@ -1298,12 +1301,20 @@ func userSessionsByUserAgentID(provider userSessionViewProvider, agentID, instan } func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) { - session, err := provider.UserSessionByIDs(agentID, user.ID, authz.GetInstance(ctx).InstanceID()) + instanceID := authz.GetInstance(ctx).InstanceID() + session, err := provider.UserSessionByIDs(agentID, user.ID, instanceID) if err != nil { if !errors.IsNotFound(err) { return nil, err } + sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID) + logging.WithFields("instanceID", instanceID, "userID", user.ID). + OnError(err). + Errorf("could not get current sequence for userSessionByIDs") session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID} + if sequence != nil { + session.Sequence = sequence.CurrentSequence + } } events, err := eventProvider.UserEventsByID(ctx, user.ID, session.Sequence) if err != nil { diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index 04d9e06646..c5dd8f411f 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -19,6 +19,7 @@ import ( user_model "github.com/zitadel/zitadel/internal/user/model" user_es_model "github.com/zitadel/zitadel/internal/user/repository/eventsourcing/model" user_view_model "github.com/zitadel/zitadel/internal/user/repository/view/model" + "github.com/zitadel/zitadel/internal/view/repository" ) var ( @@ -35,6 +36,10 @@ func (m *mockViewNoUserSession) UserSessionsByAgentID(string, string) ([]*user_v return nil, nil } +func (m *mockViewNoUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) { + return &repository.CurrentSequence{}, nil +} + type mockViewErrUserSession struct{} func (m *mockViewErrUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) { @@ -45,6 +50,10 @@ func (m *mockViewErrUserSession) UserSessionsByAgentID(string, string) ([]*user_ return nil, errors.ThrowInternal(nil, "id", "internal error") } +func (m *mockViewErrUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) { + return &repository.CurrentSequence{}, nil +} + type mockViewUserSession struct { ExternalLoginVerification time.Time PasswordlessVerification time.Time @@ -82,6 +91,10 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie return sessions, nil } +func (m *mockViewUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) { + return &repository.CurrentSequence{}, nil +} + type mockViewNoUser struct{} func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, error) { diff --git a/internal/auth/repository/eventsourcing/eventstore/refresh_token.go b/internal/auth/repository/eventsourcing/eventstore/refresh_token.go index 33ea1d868c..9edc5af154 100644 --- a/internal/auth/repository/eventsourcing/eventstore/refresh_token.go +++ b/internal/auth/repository/eventsourcing/eventstore/refresh_token.go @@ -42,15 +42,24 @@ func (r *RefreshTokenRepo) RefreshTokenByToken(ctx context.Context, refreshToken } func (r *RefreshTokenRepo) RefreshTokenByID(ctx context.Context, tokenID, userID string) (*usr_model.RefreshTokenView, error) { - tokenView, viewErr := r.View.RefreshTokenByID(tokenID, authz.GetInstance(ctx).InstanceID()) + instanceID := authz.GetInstance(ctx).InstanceID() + tokenView, viewErr := r.View.RefreshTokenByID(tokenID, instanceID) if viewErr != nil && !errors.IsNotFound(viewErr) { return nil, viewErr } if errors.IsNotFound(viewErr) { + sequence, err := r.View.GetLatestRefreshTokenSequence(ctx, instanceID) + logging.WithFields("instanceID", instanceID, "userID", userID, "tokenID", tokenID). + OnError(err). + Errorf("could not get current sequence for RefreshTokenByID") + tokenView = new(model.RefreshTokenView) tokenView.ID = tokenID tokenView.UserID = userID - tokenView.InstanceID = authz.GetInstance(ctx).InstanceID() + tokenView.InstanceID = instanceID + if sequence != nil { + tokenView.Sequence = sequence.CurrentSequence + } } events, esErr := r.getUserEvents(ctx, userID, tokenView.InstanceID, tokenView.Sequence) @@ -80,7 +89,7 @@ func (r *RefreshTokenRepo) SearchMyRefreshTokens(ctx context.Context, userID str if err != nil { return nil, err } - sequence, err := r.View.GetLatestRefreshTokenSequence(authz.GetInstance(ctx).InstanceID()) + sequence, err := r.View.GetLatestRefreshTokenSequence(ctx, authz.GetInstance(ctx).InstanceID()) logging.Log("EVENT-GBdn4").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Warn("could not read latest refresh token sequence") request.Queries = append(request.Queries, &usr_model.RefreshTokenSearchQuery{Key: usr_model.RefreshTokenSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID}) tokens, count, err := r.View.SearchRefreshTokens(request) diff --git a/internal/auth/repository/eventsourcing/eventstore/token.go b/internal/auth/repository/eventsourcing/eventstore/token.go index ec982ea3b2..7894a6f4cc 100644 --- a/internal/auth/repository/eventsourcing/eventstore/token.go +++ b/internal/auth/repository/eventsourcing/eventstore/token.go @@ -34,15 +34,25 @@ func (repo *TokenRepo) IsTokenValid(ctx context.Context, userID, tokenID string) } func (repo *TokenRepo) TokenByIDs(ctx context.Context, userID, tokenID string) (*usr_model.TokenView, error) { - token, viewErr := repo.View.TokenByIDs(tokenID, userID, authz.GetInstance(ctx).InstanceID()) + instanceID := authz.GetInstance(ctx).InstanceID() + + token, viewErr := repo.View.TokenByIDs(tokenID, userID, instanceID) if viewErr != nil && !errors.IsNotFound(viewErr) { return nil, viewErr } if errors.IsNotFound(viewErr) { + sequence, err := repo.View.GetLatestTokenSequence(ctx, instanceID) + logging.WithFields("instanceID", instanceID, "userID", userID, "tokenID", tokenID). + OnError(err). + Errorf("could not get current sequence for TokenByIDs") + token = new(model.TokenView) token.ID = tokenID token.UserID = userID - token.InstanceID = authz.GetInstance(ctx).InstanceID() + token.InstanceID = instanceID + if sequence != nil { + token.Sequence = sequence.CurrentSequence + } } events, esErr := repo.getUserEvents(ctx, userID, token.InstanceID, token.Sequence) diff --git a/internal/auth/repository/eventsourcing/handler/refresh_token.go b/internal/auth/repository/eventsourcing/handler/refresh_token.go index 29e228caea..51f158b681 100644 --- a/internal/auth/repository/eventsourcing/handler/refresh_token.go +++ b/internal/auth/repository/eventsourcing/handler/refresh_token.go @@ -62,16 +62,16 @@ func (t *RefreshToken) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{user.AggregateType, project.AggregateType, instance.AggregateType} } -func (t *RefreshToken) CurrentSequence(instanceID string) (uint64, error) { - sequence, err := t.view.GetLatestRefreshTokenSequence(instanceID) +func (t *RefreshToken) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) { + sequence, err := t.view.GetLatestRefreshTokenSequence(ctx, instanceID) if err != nil { return 0, err } return sequence.CurrentSequence, nil } -func (t *RefreshToken) EventQuery(instanceIDs []string) (*es_models.SearchQuery, error) { - sequences, err := t.view.GetLatestRefreshTokenSequences(instanceIDs) +func (t *RefreshToken) EventQuery(ctx context.Context, instanceIDs []string) (*es_models.SearchQuery, error) { + sequences, err := t.view.GetLatestRefreshTokenSequences(ctx, instanceIDs) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/handler/token.go b/internal/auth/repository/eventsourcing/handler/token.go index 280c85b6b4..7bb5bd321b 100644 --- a/internal/auth/repository/eventsourcing/handler/token.go +++ b/internal/auth/repository/eventsourcing/handler/token.go @@ -67,16 +67,16 @@ func (_ *Token) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{user.AggregateType, project.AggregateType, instance.AggregateType} } -func (t *Token) CurrentSequence(instanceID string) (uint64, error) { - sequence, err := t.view.GetLatestTokenSequence(instanceID) +func (t *Token) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) { + sequence, err := t.view.GetLatestTokenSequence(ctx, instanceID) if err != nil { return 0, err } return sequence.CurrentSequence, nil } -func (t *Token) EventQuery(instanceIDs []string) (*es_models.SearchQuery, error) { - sequences, err := t.view.GetLatestTokenSequences(instanceIDs) +func (t *Token) EventQuery(ctx context.Context, instanceIDs []string) (*es_models.SearchQuery, error) { + sequences, err := t.view.GetLatestTokenSequences(ctx, instanceIDs) if err != nil { return nil, err } @@ -145,11 +145,13 @@ func (t *Token) Reduce(event *es_models.Event) (err error) { if err != nil { return err } - applicationsIDs := make([]string, 0, len(project.Applications)) + clientIDs := make([]string, 0, len(project.Applications)) for _, app := range project.Applications { - applicationsIDs = append(applicationsIDs, app.AppID) + if app.OIDCConfig != nil { + clientIDs = append(clientIDs, app.OIDCConfig.ClientID) + } } - return t.view.DeleteApplicationTokens(event, applicationsIDs...) + return t.view.DeleteApplicationTokens(event, clientIDs...) case instance.InstanceRemovedEventType: return t.view.DeleteInstanceTokens(event) case org.OrgRemovedEventType: @@ -208,7 +210,7 @@ func (t *Token) OnSuccess(instanceIDs []string) error { } func (t *Token) getProjectByID(ctx context.Context, projID, instanceID string) (*proj_model.Project, error) { - query, err := proj_view.ProjectByIDQuery(projID, instanceID, 0) + projectQuery, err := proj_view.ProjectByIDQuery(projID, instanceID, 0) if err != nil { return nil, err } @@ -217,7 +219,7 @@ func (t *Token) getProjectByID(ctx context.Context, projID, instanceID string) ( AggregateID: projID, }, } - err = es_sdk.Filter(ctx, t.Eventstore().FilterEvents, esProject.AppendEvents, query) + err = es_sdk.Filter(ctx, t.Eventstore().FilterEvents, esProject.AppendEvents, projectQuery) if err != nil && !caos_errs.IsNotFound(err) { return nil, err } diff --git a/internal/auth/repository/eventsourcing/handler/user.go b/internal/auth/repository/eventsourcing/handler/user.go index d19d09281a..b226f54b13 100644 --- a/internal/auth/repository/eventsourcing/handler/user.go +++ b/internal/auth/repository/eventsourcing/handler/user.go @@ -68,16 +68,16 @@ func (_ *User) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{user_repo.AggregateType, org.AggregateType, instance.AggregateType} } -func (u *User) CurrentSequence(instanceID string) (uint64, error) { - sequence, err := u.view.GetLatestUserSequence(instanceID) +func (u *User) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) { + sequence, err := u.view.GetLatestUserSequence(ctx, instanceID) if err != nil { return 0, err } return sequence.CurrentSequence, nil } -func (u *User) EventQuery(instanceIDs []string) (*es_models.SearchQuery, error) { - sequences, err := u.view.GetLatestUserSequences(instanceIDs) +func (u *User) EventQuery(ctx context.Context, instanceIDs []string) (*es_models.SearchQuery, error) { + sequences, err := u.view.GetLatestUserSequences(ctx, instanceIDs) if err != nil { return nil, err } @@ -158,6 +158,11 @@ func (u *User) ProcessUser(event *es_models.Event) (err error) { if !errors.IsNotFound(err) { return err } + logging.WithFields( + "instance", event.InstanceID, + "userID", event.AggregateID, + "eventType", event.Type, + ).Info("user not found in view") query, err := usr_view.UserByIDQuery(event.AggregateID, event.InstanceID, 0) if err != nil { return err @@ -181,6 +186,11 @@ func (u *User) ProcessUser(event *es_models.Event) (err error) { if !errors.IsNotFound(err) { return err } + logging.WithFields( + "instance", event.InstanceID, + "userID", event.AggregateID, + "eventType", event.Type, + ).Info("user not found in view") query, err := usr_view.UserByIDQuery(event.AggregateID, event.InstanceID, 0) if err != nil { return err @@ -291,7 +301,7 @@ func (u *User) OnSuccess(instanceIDs []string) error { } func (u *User) getOrgByID(ctx context.Context, orgID, instanceID string) (*org_model.Org, error) { - query, err := view.OrgByIDQuery(orgID, instanceID, 0) + orgQuery, err := view.OrgByIDQuery(orgID, instanceID, 0) if err != nil { return nil, err } @@ -301,7 +311,7 @@ func (u *User) getOrgByID(ctx context.Context, orgID, instanceID string) (*org_m AggregateID: orgID, }, } - err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, esOrg.AppendEvents, query) + err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, esOrg.AppendEvents, orgQuery) if err != nil && !errors.IsNotFound(err) { return nil, err } diff --git a/internal/auth/repository/eventsourcing/handler/user_session.go b/internal/auth/repository/eventsourcing/handler/user_session.go index a76e478512..a0148edd68 100644 --- a/internal/auth/repository/eventsourcing/handler/user_session.go +++ b/internal/auth/repository/eventsourcing/handler/user_session.go @@ -65,16 +65,16 @@ func (_ *UserSession) AggregateTypes() []models.AggregateType { return []models.AggregateType{user.AggregateType, org.AggregateType, instance.AggregateType} } -func (u *UserSession) CurrentSequence(instanceID string) (uint64, error) { - sequence, err := u.view.GetLatestUserSessionSequence(instanceID) +func (u *UserSession) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) { + sequence, err := u.view.GetLatestUserSessionSequence(ctx, instanceID) if err != nil { return 0, err } return sequence.CurrentSequence, nil } -func (u *UserSession) EventQuery(instanceIDs []string) (*models.SearchQuery, error) { - sequences, err := u.view.GetLatestUserSessionSequences(instanceIDs) +func (u *UserSession) EventQuery(ctx context.Context, instanceIDs []string) (*models.SearchQuery, error) { + sequences, err := u.view.GetLatestUserSessionSequences(ctx, instanceIDs) if err != nil { return nil, err } @@ -231,7 +231,7 @@ func (u *UserSession) loginNameInformation(ctx context.Context, orgID string, in } func (u *UserSession) getOrgByID(ctx context.Context, orgID, instanceID string) (*org_model.Org, error) { - query, err := view.OrgByIDQuery(orgID, instanceID, 0) + orgQuery, err := view.OrgByIDQuery(orgID, instanceID, 0) if err != nil { return nil, err } @@ -241,7 +241,7 @@ func (u *UserSession) getOrgByID(ctx context.Context, orgID, instanceID string) AggregateID: orgID, }, } - err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, esOrg.AppendEvents, query) + err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, esOrg.AppendEvents, orgQuery) if err != nil && !errors.IsNotFound(err) { return nil, err } diff --git a/internal/auth/repository/eventsourcing/view/refresh_token.go b/internal/auth/repository/eventsourcing/view/refresh_token.go index 8c1342f6f8..2740c2ec84 100644 --- a/internal/auth/repository/eventsourcing/view/refresh_token.go +++ b/internal/auth/repository/eventsourcing/view/refresh_token.go @@ -1,6 +1,8 @@ package view import ( + "context" + "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore/v1/models" user_model "github.com/zitadel/zitadel/internal/user/model" @@ -81,12 +83,12 @@ func (v *View) DeleteOrgRefreshTokens(event *models.Event) error { return v.ProcessedRefreshTokenSequence(event) } -func (v *View) GetLatestRefreshTokenSequence(instanceID string) (*repository.CurrentSequence, error) { - return v.latestSequence(refreshTokenTable, instanceID) +func (v *View) GetLatestRefreshTokenSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(ctx, refreshTokenTable, instanceID) } -func (v *View) GetLatestRefreshTokenSequences(instanceIDs []string) ([]*repository.CurrentSequence, error) { - return v.latestSequences(refreshTokenTable, instanceIDs) +func (v *View) GetLatestRefreshTokenSequences(ctx context.Context, instanceIDs []string) ([]*repository.CurrentSequence, error) { + return v.latestSequences(ctx, refreshTokenTable, instanceIDs) } func (v *View) ProcessedRefreshTokenSequence(event *models.Event) error { diff --git a/internal/auth/repository/eventsourcing/view/sequence.go b/internal/auth/repository/eventsourcing/view/sequence.go index c1e3b0b4e2..71c77cb6a4 100644 --- a/internal/auth/repository/eventsourcing/view/sequence.go +++ b/internal/auth/repository/eventsourcing/view/sequence.go @@ -1,6 +1,7 @@ package view import ( + "context" "time" "github.com/zitadel/zitadel/internal/eventstore/v1/models" @@ -15,12 +16,12 @@ func (v *View) saveCurrentSequence(viewName string, event *models.Event) error { return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate) } -func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) { - return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID) +func (v *View) latestSequence(ctx context.Context, viewName, instanceID string) (*repository.CurrentSequence, error) { + return repository.LatestSequence(v.Db, v.TimeTravel(ctx, sequencesTable), viewName, instanceID) } -func (v *View) latestSequences(viewName string, instanceIDs []string) ([]*repository.CurrentSequence, error) { - return repository.LatestSequences(v.Db, sequencesTable, viewName, instanceIDs) +func (v *View) latestSequences(ctx context.Context, viewName string, instanceIDs []string) ([]*repository.CurrentSequence, error) { + return repository.LatestSequences(v.Db, v.TimeTravel(ctx, sequencesTable), viewName, instanceIDs) } func (v *View) updateSpoolerRunSequence(viewName string, instanceIDs []string) error { diff --git a/internal/auth/repository/eventsourcing/view/token.go b/internal/auth/repository/eventsourcing/view/token.go index 132c3ecd94..3f3909cdb3 100644 --- a/internal/auth/repository/eventsourcing/view/token.go +++ b/internal/auth/repository/eventsourcing/view/token.go @@ -1,6 +1,8 @@ package view import ( + "context" + "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore/v1/models" usr_view "github.com/zitadel/zitadel/internal/user/repository/view" @@ -92,12 +94,12 @@ func (v *View) DeleteOrgTokens(event *models.Event) error { return v.ProcessedTokenSequence(event) } -func (v *View) GetLatestTokenSequence(instanceID string) (*repository.CurrentSequence, error) { - return v.latestSequence(tokenTable, instanceID) +func (v *View) GetLatestTokenSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(ctx, tokenTable, instanceID) } -func (v *View) GetLatestTokenSequences(instanceIDs []string) ([]*repository.CurrentSequence, error) { - return v.latestSequences(tokenTable, instanceIDs) +func (v *View) GetLatestTokenSequences(ctx context.Context, instanceIDs []string) ([]*repository.CurrentSequence, error) { + return v.latestSequences(ctx, tokenTable, instanceIDs) } func (v *View) ProcessedTokenSequence(event *models.Event) error { diff --git a/internal/auth/repository/eventsourcing/view/user.go b/internal/auth/repository/eventsourcing/view/user.go index 51cd52f007..6a82f6ed63 100644 --- a/internal/auth/repository/eventsourcing/view/user.go +++ b/internal/auth/repository/eventsourcing/view/user.go @@ -3,7 +3,8 @@ package view import ( "context" - "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/logging" + "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/query" @@ -21,16 +22,16 @@ func (v *View) UserByID(userID, instanceID string) (*model.UserView, error) { return view.UserByID(v.Db, userTable, userID, instanceID) } -func (v *View) UserByLoginName(loginName, instanceID string) (*model.UserView, error) { +func (v *View) UserByLoginName(ctx context.Context, loginName, instanceID string) (*model.UserView, error) { loginNameQuery, err := query.NewUserLoginNamesSearchQuery(loginName) if err != nil { return nil, err } - return v.userByID(instanceID, loginNameQuery) + return v.userByID(ctx, instanceID, loginNameQuery) } -func (v *View) UserByLoginNameAndResourceOwner(loginName, resourceOwner, instanceID string) (*model.UserView, error) { +func (v *View) UserByLoginNameAndResourceOwner(ctx context.Context, loginName, resourceOwner, instanceID string) (*model.UserView, error) { loginNameQuery, err := query.NewUserLoginNamesSearchQuery(loginName) if err != nil { return nil, err @@ -40,18 +41,18 @@ func (v *View) UserByLoginNameAndResourceOwner(loginName, resourceOwner, instanc return nil, err } - return v.userByID(instanceID, loginNameQuery, resourceOwnerQuery) + return v.userByID(ctx, instanceID, loginNameQuery, resourceOwnerQuery) } -func (v *View) UserByEmail(email, instanceID string) (*model.UserView, error) { +func (v *View) UserByEmail(ctx context.Context, email, instanceID string) (*model.UserView, error) { emailQuery, err := query.NewUserVerifiedEmailSearchQuery(email, query.TextEqualsIgnoreCase) if err != nil { return nil, err } - return v.userByID(instanceID, emailQuery) + return v.userByID(ctx, instanceID, emailQuery) } -func (v *View) UserByEmailAndResourceOwner(email, resourceOwner, instanceID string) (*model.UserView, error) { +func (v *View) UserByEmailAndResourceOwner(ctx context.Context, email, resourceOwner, instanceID string) (*model.UserView, error) { emailQuery, err := query.NewUserVerifiedEmailSearchQuery(email, query.TextEquals) if err != nil { return nil, err @@ -61,18 +62,18 @@ func (v *View) UserByEmailAndResourceOwner(email, resourceOwner, instanceID stri return nil, err } - return v.userByID(instanceID, emailQuery, resourceOwnerQuery) + return v.userByID(ctx, instanceID, emailQuery, resourceOwnerQuery) } -func (v *View) UserByPhone(phone, instanceID string) (*model.UserView, error) { +func (v *View) UserByPhone(ctx context.Context, phone, instanceID string) (*model.UserView, error) { phoneQuery, err := query.NewUserVerifiedPhoneSearchQuery(phone, query.TextEquals) if err != nil { return nil, err } - return v.userByID(instanceID, phoneQuery) + return v.userByID(ctx, instanceID, phoneQuery) } -func (v *View) UserByPhoneAndResourceOwner(phone, resourceOwner, instanceID string) (*model.UserView, error) { +func (v *View) UserByPhoneAndResourceOwner(ctx context.Context, phone, resourceOwner, instanceID string) (*model.UserView, error) { phoneQuery, err := query.NewUserVerifiedPhoneSearchQuery(phone, query.TextEquals) if err != nil { return nil, err @@ -82,12 +83,10 @@ func (v *View) UserByPhoneAndResourceOwner(phone, resourceOwner, instanceID stri return nil, err } - return v.userByID(instanceID, phoneQuery, resourceOwnerQuery) + return v.userByID(ctx, instanceID, phoneQuery, resourceOwnerQuery) } -func (v *View) userByID(instanceID string, queries ...query.SearchQuery) (*model.UserView, error) { - ctx := authz.WithInstanceID(context.Background(), instanceID) - +func (v *View) userByID(ctx context.Context, instanceID string, queries ...query.SearchQuery) (*model.UserView, error) { queriedUser, err := v.query.GetNotifyUser(ctx, true, false, queries...) if err != nil { return nil, err @@ -99,7 +98,14 @@ func (v *View) userByID(instanceID string, queries ...query.SearchQuery) (*model } if err != nil { + sequence, err := v.GetLatestUserSequence(ctx, instanceID) + logging.WithFields("instanceID", instanceID). + OnError(err). + Errorf("could not get current sequence for userByID") user = new(model.UserView) + if sequence != nil { + user.Sequence = sequence.CurrentSequence + } } query, err := view.UserByIDQuery(queriedUser.ID, instanceID, user.Sequence) @@ -188,12 +194,12 @@ func (v *View) UpdateOrgOwnerRemovedUsers(event *models.Event) error { return v.ProcessedUserSequence(event) } -func (v *View) GetLatestUserSequence(instanceID string) (*repository.CurrentSequence, error) { - return v.latestSequence(userTable, instanceID) +func (v *View) GetLatestUserSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(ctx, userTable, instanceID) } -func (v *View) GetLatestUserSequences(instanceIDs []string) ([]*repository.CurrentSequence, error) { - return v.latestSequences(userTable, instanceIDs) +func (v *View) GetLatestUserSequences(ctx context.Context, instanceIDs []string) ([]*repository.CurrentSequence, error) { + return v.latestSequences(ctx, userTable, instanceIDs) } func (v *View) ProcessedUserSequence(event *models.Event) error { diff --git a/internal/auth/repository/eventsourcing/view/user_session.go b/internal/auth/repository/eventsourcing/view/user_session.go index 52ac559384..4e3803e77b 100644 --- a/internal/auth/repository/eventsourcing/view/user_session.go +++ b/internal/auth/repository/eventsourcing/view/user_session.go @@ -1,6 +1,8 @@ package view import ( + "context" + "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/user/repository/view" @@ -72,12 +74,12 @@ func (v *View) DeleteOrgUserSessions(event *models.Event) error { return v.ProcessedUserSessionSequence(event) } -func (v *View) GetLatestUserSessionSequence(instanceID string) (*repository.CurrentSequence, error) { - return v.latestSequence(userSessionTable, instanceID) +func (v *View) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(ctx, userSessionTable, instanceID) } -func (v *View) GetLatestUserSessionSequences(instanceIDs []string) ([]*repository.CurrentSequence, error) { - return v.latestSequences(userSessionTable, instanceIDs) +func (v *View) GetLatestUserSessionSequences(ctx context.Context, instanceIDs []string) ([]*repository.CurrentSequence, error) { + return v.latestSequences(ctx, userSessionTable, instanceIDs) } func (v *View) ProcessedUserSessionSequence(event *models.Event) error { diff --git a/internal/auth/repository/eventsourcing/view/view.go b/internal/auth/repository/eventsourcing/view/view.go index 80f683bc80..b65badf1e5 100644 --- a/internal/auth/repository/eventsourcing/view/view.go +++ b/internal/auth/repository/eventsourcing/view/view.go @@ -1,8 +1,11 @@ package view import ( + "context" + "github.com/jinzhu/gorm" + "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/database" eventstore "github.com/zitadel/zitadel/internal/eventstore/v1" @@ -16,6 +19,7 @@ type View struct { idGenerator id.Generator query *query.Queries es eventstore.Eventstore + client *database.DB } func StartView(sqlClient *database.DB, keyAlgorithm crypto.EncryptionAlgorithm, queries *query.Queries, idGenerator id.Generator, es eventstore.Eventstore) (*View, error) { @@ -29,9 +33,14 @@ func StartView(sqlClient *database.DB, keyAlgorithm crypto.EncryptionAlgorithm, idGenerator: idGenerator, query: queries, es: es, + client: sqlClient, }, nil } func (v *View) Health() (err error) { return v.Db.DB().Ping() } + +func (v *View) TimeTravel(ctx context.Context, tableName string) string { + return tableName + v.client.Timetravel(call.Took(ctx)) +} diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index 4ef653d207..b2bff61557 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -44,16 +44,16 @@ func (repo *TokenVerifierRepo) tokenByID(ctx context.Context, tokenID, userID st defer func() { span.EndWithError(err) }() instanceID := authz.GetInstance(ctx).InstanceID() - sequence, err := repo.View.GetLatestTokenSequence(instanceID) - logging.WithFields("instanceID", instanceID, "userID", userID, "tokenID"). - OnError(err). - Errorf("could not get current sequence for token check") - token, viewErr := repo.View.TokenByIDs(tokenID, userID, instanceID) if viewErr != nil && !caos_errs.IsNotFound(viewErr) { return nil, viewErr } if caos_errs.IsNotFound(viewErr) { + sequence, err := repo.View.GetLatestTokenSequence(ctx, instanceID) + logging.WithFields("instanceID", instanceID, "userID", userID, "tokenID", tokenID). + OnError(err). + Errorf("could not get current sequence for token check") + token = new(model.TokenView) token.ID = tokenID token.UserID = userID diff --git a/internal/authz/repository/eventsourcing/view/sequence.go b/internal/authz/repository/eventsourcing/view/sequence.go index 6810e420d6..18b577c0cd 100644 --- a/internal/authz/repository/eventsourcing/view/sequence.go +++ b/internal/authz/repository/eventsourcing/view/sequence.go @@ -1,6 +1,8 @@ package view import ( + "context" + "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/view/repository" ) @@ -13,6 +15,6 @@ func (v *View) saveCurrentSequence(viewName string, event *models.Event) error { return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate) } -func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) { - return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID) +func (v *View) latestSequence(ctx context.Context, viewName, instanceID string) (*repository.CurrentSequence, error) { + return repository.LatestSequence(v.Db, v.TimeTravel(ctx, sequencesTable), viewName, instanceID) } diff --git a/internal/authz/repository/eventsourcing/view/token.go b/internal/authz/repository/eventsourcing/view/token.go index 486d72008d..cfd0c8b9cf 100644 --- a/internal/authz/repository/eventsourcing/view/token.go +++ b/internal/authz/repository/eventsourcing/view/token.go @@ -1,6 +1,8 @@ package view import ( + "context" + "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore/v1/models" usr_view "github.com/zitadel/zitadel/internal/user/repository/view" @@ -40,8 +42,8 @@ func (v *View) DeleteSessionTokens(agentID, userID, instanceID string, event *mo return v.ProcessedTokenSequence(event) } -func (v *View) GetLatestTokenSequence(instanceID string) (*repository.CurrentSequence, error) { - return v.latestSequence(tokenTable, instanceID) +func (v *View) GetLatestTokenSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(ctx, tokenTable, instanceID) } func (v *View) ProcessedTokenSequence(event *models.Event) error { diff --git a/internal/authz/repository/eventsourcing/view/view.go b/internal/authz/repository/eventsourcing/view/view.go index a55dec7699..c3ee5b79ac 100644 --- a/internal/authz/repository/eventsourcing/view/view.go +++ b/internal/authz/repository/eventsourcing/view/view.go @@ -1,17 +1,21 @@ package view import ( + "context" + + "github.com/jinzhu/gorm" + + "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/query" - - "github.com/jinzhu/gorm" ) type View struct { Db *gorm.DB Query *query.Queries idGenerator id.Generator + client *database.DB } func StartView(sqlClient *database.DB, idGenerator id.Generator, queries *query.Queries) (*View, error) { @@ -23,9 +27,14 @@ func StartView(sqlClient *database.DB, idGenerator id.Generator, queries *query. Db: gorm, idGenerator: idGenerator, Query: queries, + client: sqlClient, }, nil } func (v *View) Health() (err error) { return v.Db.DB().Ping() } + +func (v *View) TimeTravel(ctx context.Context, tableName string) string { + return tableName + v.client.Timetravel(call.Took(ctx)) +} diff --git a/internal/eventstore/v1/query/handler.go b/internal/eventstore/v1/query/handler.go index 7e126a8f2a..384d5206f6 100755 --- a/internal/eventstore/v1/query/handler.go +++ b/internal/eventstore/v1/query/handler.go @@ -17,7 +17,7 @@ const ( type Handler interface { ViewModel() string - EventQuery(instanceIDs []string) (*models.SearchQuery, error) + EventQuery(ctx context.Context, instanceIDs []string) (*models.SearchQuery, error) Reduce(*models.Event) error OnError(event *models.Event, err error) error OnSuccess(instanceIDs []string) error @@ -26,7 +26,7 @@ type Handler interface { QueryLimit() uint64 AggregateTypes() []models.AggregateType - CurrentSequence(instanceID string) (uint64, error) + CurrentSequence(ctx context.Context, instanceID string) (uint64, error) Eventstore() v1.Eventstore Subscription() *v1.Subscription @@ -46,7 +46,7 @@ func ReduceEvent(ctx context.Context, handler Handler, event *models.Event) { ).Error("reduce panicked") } }() - currentSequence, err := handler.CurrentSequence(event.InstanceID) + currentSequence, err := handler.CurrentSequence(ctx, event.InstanceID) if err != nil { logging.WithError(err).Warn("unable to get current sequence") return @@ -67,7 +67,7 @@ func ReduceEvent(ctx context.Context, handler Handler, event *models.Event) { } for _, unprocessedEvent := range unprocessedEvents { - currentSequence, err := handler.CurrentSequence(unprocessedEvent.InstanceID) + currentSequence, err := handler.CurrentSequence(ctx, unprocessedEvent.InstanceID) if err != nil { logging.WithError(err).Warn("unable to get current sequence") return diff --git a/internal/eventstore/v1/spooler/spooler.go b/internal/eventstore/v1/spooler/spooler.go index c6cc15d3a4..6291e448d9 100644 --- a/internal/eventstore/v1/spooler/spooler.go +++ b/internal/eventstore/v1/spooler/spooler.go @@ -222,7 +222,7 @@ func (s *spooledHandler) process(ctx context.Context, events []*models.Event, wo } func (s *spooledHandler) query(ctx context.Context, instanceIDs []string) ([]*models.Event, error) { - query, err := s.EventQuery(instanceIDs) + query, err := s.EventQuery(ctx, instanceIDs) if err != nil { return nil, err } diff --git a/internal/eventstore/v1/spooler/spooler_test.go b/internal/eventstore/v1/spooler/spooler_test.go index 9aa0c75431..862d2aed35 100644 --- a/internal/eventstore/v1/spooler/spooler_test.go +++ b/internal/eventstore/v1/spooler/spooler_test.go @@ -35,7 +35,7 @@ func (h *testHandler) AggregateTypes() []models.AggregateType { return nil } -func (h *testHandler) CurrentSequence(instanceID string) (uint64, error) { +func (h *testHandler) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) { return 0, nil } @@ -51,7 +51,7 @@ func (h *testHandler) Subscription() *v1.Subscription { return nil } -func (h *testHandler) EventQuery(instanceIDs []string) (*models.SearchQuery, error) { +func (h *testHandler) EventQuery(ctx context.Context, instanceIDs []string) (*models.SearchQuery, error) { if h.queryError != nil { return nil, h.queryError } diff --git a/internal/org/repository/view/query.go b/internal/org/repository/view/query.go index c6352bff4a..765d6abbc8 100644 --- a/internal/org/repository/view/query.go +++ b/internal/org/repository/view/query.go @@ -16,5 +16,19 @@ func OrgByIDQuery(id, instanceID string, latestSequence uint64) (*es_models.Sear LatestSequenceFilter(latestSequence). InstanceIDFilter(instanceID). AggregateIDFilter(id). + EventTypesFilter( + es_models.EventType(org.OrgAddedEventType), + es_models.EventType(org.OrgChangedEventType), + es_models.EventType(org.OrgDeactivatedEventType), + es_models.EventType(org.OrgReactivatedEventType), + es_models.EventType(org.OrgDomainAddedEventType), + es_models.EventType(org.OrgDomainVerificationAddedEventType), + es_models.EventType(org.OrgDomainVerifiedEventType), + es_models.EventType(org.OrgDomainPrimarySetEventType), + es_models.EventType(org.OrgDomainRemovedEventType), + es_models.EventType(org.DomainPolicyAddedEventType), + es_models.EventType(org.DomainPolicyChangedEventType), + es_models.EventType(org.DomainPolicyRemovedEventType), + ). SearchQuery(), nil } diff --git a/internal/project/repository/eventsourcing/model/project.go b/internal/project/repository/eventsourcing/model/project.go index 8aba9a08ab..dfba580f7f 100644 --- a/internal/project/repository/eventsourcing/model/project.go +++ b/internal/project/repository/eventsourcing/model/project.go @@ -18,15 +18,26 @@ type Project struct { ProjectRoleCheck bool `json:"projectRoleCheck,omitempty"` HasProjectCheck bool `json:"hasProjectCheck,omitempty"` State int32 `json:"-"` + OIDCApplications []*oidcApp +} + +type oidcApp struct { + AppID string `json:"appId"` + ClientID string `json:"clientId,omitempty"` } func ProjectToModel(project *Project) *model.Project { + apps := make([]*model.Application, len(project.OIDCApplications)) + for i, application := range project.OIDCApplications { + apps[i] = &model.Application{OIDCConfig: &model.OIDCConfig{ClientID: application.ClientID}} + } return &model.Project{ ObjectRoot: project.ObjectRoot, Name: project.Name, ProjectRoleAssertion: project.ProjectRoleAssertion, ProjectRoleCheck: project.ProjectRoleCheck, State: model.ProjectState(project.State), + Applications: apps, } } @@ -59,6 +70,10 @@ func (p *Project) AppendEvent(event *es_models.Event) error { return p.appendReactivatedEvent() case project.ProjectRemovedType: return p.appendRemovedEvent() + case project.OIDCConfigAddedType: + return p.appendOIDCConfig(event) + case project.ApplicationRemovedType: + return p.appendApplicationRemoved(event) } return nil } @@ -84,6 +99,31 @@ func (p *Project) appendRemovedEvent() error { return nil } +func (p *Project) appendOIDCConfig(event *es_models.Event) error { + appEvent := new(oidcApp) + if err := json.Unmarshal(event.Data, appEvent); err != nil { + return err + } + p.OIDCApplications = append(p.OIDCApplications, appEvent) + return nil +} + +func (p *Project) appendApplicationRemoved(event *es_models.Event) error { + appEvent := new(oidcApp) + if err := json.Unmarshal(event.Data, appEvent); err != nil { + return err + } + for i := len(p.OIDCApplications) - 1; i >= 0; i-- { + if p.OIDCApplications[i].AppID == appEvent.AppID { + p.OIDCApplications[i] = p.OIDCApplications[len(p.OIDCApplications)-1] + p.OIDCApplications[len(p.OIDCApplications)-1] = nil + p.OIDCApplications = p.OIDCApplications[:len(p.OIDCApplications)-1] + return nil + } + } + return nil +} + func (p *Project) SetData(event *es_models.Event) error { if err := json.Unmarshal(event.Data, p); err != nil { logging.Log("EVEN-lo9sr").WithError(err).Error("could not unmarshal event data") diff --git a/internal/project/repository/view/query.go b/internal/project/repository/view/query.go index 963e1cd4a9..56c8ae007b 100644 --- a/internal/project/repository/view/query.go +++ b/internal/project/repository/view/query.go @@ -16,5 +16,14 @@ func ProjectByIDQuery(id, instanceID string, latestSequence uint64) (*es_models. AggregateTypeFilter(project.AggregateType). LatestSequenceFilter(latestSequence). InstanceIDFilter(instanceID). + EventTypesFilter( + es_models.EventType(project.ProjectAddedType), + es_models.EventType(project.ProjectChangedType), + es_models.EventType(project.ProjectDeactivatedType), + es_models.EventType(project.ProjectReactivatedType), + es_models.EventType(project.ProjectRemovedType), + es_models.EventType(project.OIDCConfigAddedType), + es_models.EventType(project.ApplicationRemovedType), + ). SearchQuery(), nil } From 39bdef35e76438669d525da8c37c2f3a42596ab7 Mon Sep 17 00:00:00 2001 From: Silvan Date: Fri, 28 Apr 2023 16:56:51 +0200 Subject: [PATCH 3/5] chore: merge (#5773) * feat: allow skip of success page for native apps (#5627) add possibility to return to callback directly after login without rendering the successful login page * build next * fix(console): disallow inline fonts, critical styles (#5714) fix: disallow inline * fix(setup): step 10 for postgres (#5717) * fix(setup): smaller transactions (#5743) * fix: order by sequence by default * test: add allowCreationDateFilter * fix(step10): separate executions (#5754) * feat: allow skip of success page for native apps (#5627) add possibility to return to callback directly after login without rendering the successful login page * build next * fix(console): disallow inline fonts, critical styles (#5714) fix: disallow inline * fix(setup): step 10 for postgres (#5717) * fix(setup): smaller transactions (#5743) * fix(step10): split statements * fix(step10): split into separate execs * chore: prerelease * add truncate before insert * fix: add truncate * Merge branch 'main' into optimise-step-10 * chore: reset release definition --------- Co-authored-by: Livio Spring Co-authored-by: Max Peintner --------- Co-authored-by: Livio Spring Co-authored-by: Max Peintner --- .releaserc.js | 4 ++-- cmd/defaults.yaml | 1 + cmd/start/start.go | 6 +++--- .../repository/eventsourcing/repository.go | 4 ++-- .../repository/eventsourcing/repository.go | 4 ++-- internal/authz/authz.go | 4 ++-- .../repository/eventsourcing/repository.go | 4 ++-- internal/eventstore/config.go | 7 ++++--- internal/eventstore/repository/sql/crdb.go | 19 ++++++++++++++----- .../eventstore/repository/sql/query_test.go | 10 +++++++--- internal/eventstore/v1/eventstore.go | 4 ++-- .../v1/internal/repository/sql/config.go | 5 +++-- .../v1/internal/repository/sql/filter.go | 10 +++++----- .../v1/internal/repository/sql/filter_test.go | 3 ++- .../v1/internal/repository/sql/query.go | 16 ++++++++++++---- .../v1/internal/repository/sql/query_test.go | 2 +- .../v1/internal/repository/sql/sql.go | 3 ++- 17 files changed, 66 insertions(+), 40 deletions(-) diff --git a/.releaserc.js b/.releaserc.js index b8d5a8551a..f24249cada 100644 --- a/.releaserc.js +++ b/.releaserc.js @@ -1,7 +1,7 @@ module.exports = { branches: [ - {name: 'main', channel: 'next'}, - {name: 'next', prerelease: true} + {name: 'main'}, + {name: 'next'}, ], plugins: [ "@semantic-release/commit-analyzer" diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 85d6c0ab15..1acbb55d44 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -398,6 +398,7 @@ Quotas: Eventstore: PushTimeout: 15s + AllowOrderByCreationDate: false DefaultInstance: InstanceName: diff --git a/cmd/start/start.go b/cmd/start/start.go index 90362d61e9..e0829b7eb3 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -115,7 +115,7 @@ func startZitadel(config *Config, masterKey string) error { return fmt.Errorf("cannot start queries: %w", err) } - authZRepo, err := authz.Start(queries, dbClient, keys.OIDC, config.ExternalSecure) + authZRepo, err := authz.Start(queries, dbClient, keys.OIDC, config.ExternalSecure, config.Eventstore.AllowOrderByCreationDate) if err != nil { return fmt.Errorf("error starting authz repo: %w", err) } @@ -229,11 +229,11 @@ func startAPIs( if err != nil { return fmt.Errorf("error creating api %w", err) } - authRepo, err := auth_es.Start(ctx, config.Auth, config.SystemDefaults, commands, queries, dbClient, eventstore, keys.OIDC, keys.User) + authRepo, err := auth_es.Start(ctx, config.Auth, config.SystemDefaults, commands, queries, dbClient, eventstore, keys.OIDC, keys.User, config.Eventstore.AllowOrderByCreationDate) if err != nil { return fmt.Errorf("error starting auth repo: %w", err) } - adminRepo, err := admin_es.Start(ctx, config.Admin, store, dbClient, eventstore) + adminRepo, err := admin_es.Start(ctx, config.Admin, store, dbClient, eventstore, config.Eventstore.AllowOrderByCreationDate) if err != nil { return fmt.Errorf("error starting admin repo: %w", err) } diff --git a/internal/admin/repository/eventsourcing/repository.go b/internal/admin/repository/eventsourcing/repository.go index d7c03c1fc9..febc2629de 100644 --- a/internal/admin/repository/eventsourcing/repository.go +++ b/internal/admin/repository/eventsourcing/repository.go @@ -23,8 +23,8 @@ type EsRepository struct { eventstore.AdministratorRepo } -func Start(ctx context.Context, conf Config, static static.Storage, dbClient *database.DB, esV2 *eventstore2.Eventstore) (*EsRepository, error) { - es, err := v1.Start(dbClient) +func Start(ctx context.Context, conf Config, static static.Storage, dbClient *database.DB, esV2 *eventstore2.Eventstore, allowOrderByCreationDate bool) (*EsRepository, error) { + es, err := v1.Start(dbClient, allowOrderByCreationDate) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/repository.go b/internal/auth/repository/eventsourcing/repository.go index 852ebcca91..454daffd11 100644 --- a/internal/auth/repository/eventsourcing/repository.go +++ b/internal/auth/repository/eventsourcing/repository.go @@ -34,8 +34,8 @@ type EsRepository struct { eventstore.OrgRepository } -func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, command *command.Commands, queries *query.Queries, dbClient *database.DB, esV2 *eventstore2.Eventstore, oidcEncryption crypto.EncryptionAlgorithm, userEncryption crypto.EncryptionAlgorithm) (*EsRepository, error) { - es, err := v1.Start(dbClient) +func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, command *command.Commands, queries *query.Queries, dbClient *database.DB, esV2 *eventstore2.Eventstore, oidcEncryption crypto.EncryptionAlgorithm, userEncryption crypto.EncryptionAlgorithm, allowOrderByCreationDate bool) (*EsRepository, error) { + es, err := v1.Start(dbClient, allowOrderByCreationDate) if err != nil { return nil, err } diff --git a/internal/authz/authz.go b/internal/authz/authz.go index f8f21e4125..6106d8a4e5 100644 --- a/internal/authz/authz.go +++ b/internal/authz/authz.go @@ -8,6 +8,6 @@ import ( "github.com/zitadel/zitadel/internal/query" ) -func Start(queries *query.Queries, dbClient *database.DB, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, externalSecure bool) (repository.Repository, error) { - return eventsourcing.Start(queries, dbClient, keyEncryptionAlgorithm, externalSecure) +func Start(queries *query.Queries, dbClient *database.DB, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, externalSecure, allowOrderByCreationDate bool) (repository.Repository, error) { + return eventsourcing.Start(queries, dbClient, keyEncryptionAlgorithm, externalSecure, allowOrderByCreationDate) } diff --git a/internal/authz/repository/eventsourcing/repository.go b/internal/authz/repository/eventsourcing/repository.go index c3f212d96c..2df593f11d 100644 --- a/internal/authz/repository/eventsourcing/repository.go +++ b/internal/authz/repository/eventsourcing/repository.go @@ -18,8 +18,8 @@ type EsRepository struct { eventstore.TokenVerifierRepo } -func Start(queries *query.Queries, dbClient *database.DB, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, externalSecure bool) (repository.Repository, error) { - es, err := v1.Start(dbClient) +func Start(queries *query.Queries, dbClient *database.DB, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, externalSecure, allowOrderByCreationDate bool) (repository.Repository, error) { + es, err := v1.Start(dbClient, allowOrderByCreationDate) if err != nil { return nil, err } diff --git a/internal/eventstore/config.go b/internal/eventstore/config.go index c9d00da267..21011dea67 100644 --- a/internal/eventstore/config.go +++ b/internal/eventstore/config.go @@ -9,8 +9,9 @@ import ( ) type Config struct { - PushTimeout time.Duration - Client *database.DB + PushTimeout time.Duration + Client *database.DB + AllowOrderByCreationDate bool repo repository.Repository } @@ -20,6 +21,6 @@ func TestConfig(repo repository.Repository) *Config { } func Start(config *Config) (*Eventstore, error) { - config.repo = z_sql.NewCRDB(config.Client) + config.repo = z_sql.NewCRDB(config.Client, config.AllowOrderByCreationDate) return NewEventstore(config), nil } diff --git a/internal/eventstore/repository/sql/crdb.go b/internal/eventstore/repository/sql/crdb.go index f78f210a6a..f27936de4c 100644 --- a/internal/eventstore/repository/sql/crdb.go +++ b/internal/eventstore/repository/sql/crdb.go @@ -99,10 +99,11 @@ const ( type CRDB struct { *database.DB + AllowOrderByCreationDate bool } -func NewCRDB(client *database.DB) *CRDB { - return &CRDB{client} +func NewCRDB(client *database.DB, allowOrderByCreationDate bool) *CRDB { + return &CRDB{client, allowOrderByCreationDate} } func (db *CRDB) Health(ctx context.Context) error { return db.Ping() } @@ -254,11 +255,19 @@ func (db *CRDB) db() *sql.DB { } func (db *CRDB) orderByEventSequence(desc bool) string { - if desc { - return " ORDER BY creation_date DESC, event_sequence DESC" + if db.AllowOrderByCreationDate { + if desc { + return " ORDER BY creation_date DESC, event_sequence DESC" + } + + return " ORDER BY creation_date, event_sequence" } - return " ORDER BY creation_date, event_sequence" + if desc { + return " ORDER BY event_sequence DESC" + } + + return " ORDER BY event_sequence" } func (db *CRDB) eventQuery() string { diff --git a/internal/eventstore/repository/sql/query_test.go b/internal/eventstore/repository/sql/query_test.go index 9560ff0918..af9aa9860c 100644 --- a/internal/eventstore/repository/sql/query_test.go +++ b/internal/eventstore/repository/sql/query_test.go @@ -542,6 +542,7 @@ func Test_query_events_with_crdb(t *testing.T) { DB: tt.fields.client, Database: new(testDB), }, + AllowOrderByCreationDate: true, } // setup initial data for query @@ -820,9 +821,12 @@ func Test_query_events_mocked(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - crdb := &CRDB{DB: &database.DB{ - Database: new(testDB), - }} + crdb := &CRDB{ + DB: &database.DB{ + Database: new(testDB), + }, + AllowOrderByCreationDate: true, + } if tt.fields.mock != nil { crdb.DB.DB = tt.fields.mock.client } diff --git a/internal/eventstore/v1/eventstore.go b/internal/eventstore/v1/eventstore.go index 0fc311f2e4..49e0910c73 100644 --- a/internal/eventstore/v1/eventstore.go +++ b/internal/eventstore/v1/eventstore.go @@ -22,9 +22,9 @@ type eventstore struct { repo repository.Repository } -func Start(db *database.DB) (Eventstore, error) { +func Start(db *database.DB, allowOrderByCreationDate bool) (Eventstore, error) { return &eventstore{ - repo: z_sql.Start(db), + repo: z_sql.Start(db, allowOrderByCreationDate), }, nil } diff --git a/internal/eventstore/v1/internal/repository/sql/config.go b/internal/eventstore/v1/internal/repository/sql/config.go index bef2df0de2..5ea34e6e40 100644 --- a/internal/eventstore/v1/internal/repository/sql/config.go +++ b/internal/eventstore/v1/internal/repository/sql/config.go @@ -4,8 +4,9 @@ import ( "github.com/zitadel/zitadel/internal/database" ) -func Start(client *database.DB) *SQL { +func Start(client *database.DB, allowOrderByCreationDate bool) *SQL { return &SQL{ - client: client, + client: client, + allowOrderByCreationDate: allowOrderByCreationDate, } } diff --git a/internal/eventstore/v1/internal/repository/sql/filter.go b/internal/eventstore/v1/internal/repository/sql/filter.go index 5674972de6..ab67730b59 100644 --- a/internal/eventstore/v1/internal/repository/sql/filter.go +++ b/internal/eventstore/v1/internal/repository/sql/filter.go @@ -21,11 +21,11 @@ func (db *SQL) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFac if !searchQuery.InstanceFiltered { logging.WithFields("stack", string(debug.Stack())).Warn("instanceid not filtered") } - return filter(ctx, db.client, searchQuery) + return db.filter(ctx, db.client, searchQuery) } -func filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) { - query, limit, values, rowScanner := buildQuery(ctx, db, searchQuery) +func (sql *SQL) filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) { + query, limit, values, rowScanner := sql.buildQuery(ctx, db, searchQuery) if query == "" { return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") } @@ -53,7 +53,7 @@ func filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQ } func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) { - query, _, values, rowScanner := buildQuery(ctx, db.client, queryFactory) + query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory) if query == "" { return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") } @@ -68,7 +68,7 @@ func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.Searc } func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) ([]string, error) { - query, _, values, rowScanner := buildQuery(ctx, db.client, queryFactory) + query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory) if query == "" { return nil, errors.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory") } diff --git a/internal/eventstore/v1/internal/repository/sql/filter_test.go b/internal/eventstore/v1/internal/repository/sql/filter_test.go index df352a236a..75863fed8a 100644 --- a/internal/eventstore/v1/internal/repository/sql/filter_test.go +++ b/internal/eventstore/v1/internal/repository/sql/filter_test.go @@ -123,7 +123,8 @@ func TestSQL_Filter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sql := &SQL{ - client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)}, + client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)}, + allowOrderByCreationDate: true, } events, err := sql.Filter(context.Background(), tt.args.searchQuery) if (err != nil) != tt.res.wantErr { diff --git a/internal/eventstore/v1/internal/repository/sql/query.go b/internal/eventstore/v1/internal/repository/sql/query.go index ed4510e7f7..8b2ea1dc98 100644 --- a/internal/eventstore/v1/internal/repository/sql/query.go +++ b/internal/eventstore/v1/internal/repository/sql/query.go @@ -33,7 +33,7 @@ const ( " FROM eventstore.events" ) -func buildQuery(ctx context.Context, db dialect.Database, queryFactory *es_models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scan, dest interface{}) error) { +func (sql *SQL) buildQuery(ctx context.Context, db dialect.Database, queryFactory *es_models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scan, dest interface{}) error) { searchQuery, err := queryFactory.Build() if err != nil { logging.New().WithError(err).Warn("search query factory invalid") @@ -51,9 +51,17 @@ func buildQuery(ctx context.Context, db dialect.Database, queryFactory *es_model query += where if searchQuery.Columns == es_models.Columns_Event { - order := " ORDER BY creation_date, event_sequence" - if searchQuery.Desc { - order = " ORDER BY creation_date DESC, event_sequence DESC" + var order string + if sql.allowOrderByCreationDate { + order = " ORDER BY creation_date, event_sequence" + if searchQuery.Desc { + order = " ORDER BY creation_date DESC, event_sequence DESC" + } + } else { + order = " ORDER BY event_sequence" + if searchQuery.Desc { + order = " ORDER BY event_sequence DESC" + } } query += order } diff --git a/internal/eventstore/v1/internal/repository/sql/query_test.go b/internal/eventstore/v1/internal/repository/sql/query_test.go index f0b13bc701..0c4b4d9db9 100644 --- a/internal/eventstore/v1/internal/repository/sql/query_test.go +++ b/internal/eventstore/v1/internal/repository/sql/query_test.go @@ -470,7 +470,7 @@ func Test_buildQuery(t *testing.T) { ctx := context.Background() db := new(testDB) t.Run(tt.name, func(t *testing.T) { - gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(ctx, db, tt.args.queryFactory) + gotQuery, gotLimit, gotValues, gotRowScanner := (&SQL{allowOrderByCreationDate: true}).buildQuery(ctx, db, tt.args.queryFactory) if gotQuery != tt.res.query { t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query) } diff --git a/internal/eventstore/v1/internal/repository/sql/sql.go b/internal/eventstore/v1/internal/repository/sql/sql.go index 3f5ecc431e..ab4c9ca26c 100644 --- a/internal/eventstore/v1/internal/repository/sql/sql.go +++ b/internal/eventstore/v1/internal/repository/sql/sql.go @@ -7,7 +7,8 @@ import ( ) type SQL struct { - client *database.DB + client *database.DB + allowOrderByCreationDate bool } func (db *SQL) Health(ctx context.Context) error { From 40bf7e49cc46a152edce3a48b0160a5131d4d346 Mon Sep 17 00:00:00 2001 From: Silvan Date: Tue, 2 May 2023 10:46:44 +0200 Subject: [PATCH 4/5] fix: correct tracing in access interceptor (#5766) --- internal/api/http/middleware/access_interceptor.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/internal/api/http/middleware/access_interceptor.go b/internal/api/http/middleware/access_interceptor.go index 0de0e8c19b..4e95a83f6f 100644 --- a/internal/api/http/middleware/access_interceptor.go +++ b/internal/api/http/middleware/access_interceptor.go @@ -43,12 +43,10 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler { return next } return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - ctx := request.Context() var err error - tracingCtx, span := tracing.NewServerInterceptorSpan(ctx) - defer func() { span.EndWithError(err) }() + tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess") wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0} @@ -63,8 +61,13 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler { wrappedWriter.ignoreWrites = true } + checkSpan.End() + next.ServeHTTP(wrappedWriter, request) + tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess") + defer writeSpan.End() + requestURL := request.RequestURI unescapedURL, err := url.QueryUnescape(requestURL) if err != nil { From e0505b2defee141935b05e5d998b94bc8877ac9e Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Tue, 2 May 2023 18:03:33 +0200 Subject: [PATCH 5/5] fix: use correct org id for external authentication actions (#5793) --- internal/api/ui/login/custom_action.go | 38 ++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/internal/api/ui/login/custom_action.go b/internal/api/ui/login/custom_action.go index 772017277e..d6dbe6cdd4 100644 --- a/internal/api/ui/login/custom_action.go +++ b/internal/api/ui/login/custom_action.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/dop251/goja" + "github.com/zitadel/logging" "github.com/zitadel/oidc/v2/pkg/oidc" "golang.org/x/text/language" @@ -14,6 +15,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/idp" + "github.com/zitadel/zitadel/internal/query" ) func (l *Login) runPostExternalAuthenticationActions( @@ -26,7 +28,21 @@ func (l *Login) runPostExternalAuthenticationActions( ) (_ *domain.ExternalUser, userChanged bool, err error) { ctx := httpRequest.Context() + // use the request org (scopes or domain discovery) as default resourceOwner := authRequest.RequestedOrgID + // if the user was already linked to an IDP and redirected to that, the requested org might be empty + if resourceOwner == "" { + resourceOwner = authRequest.UserOrgID + } + // if we will have no org (e.g. user clicked directly on the IDP on the login page) + if resourceOwner == "" { + // in this case the user might nevertheless already be linked to an IDP, + // so let's do a workaround and resourceOwnerOfUserIDPLink if there would be a IDP link + resourceOwner, err = l.resourceOwnerOfUserIDPLink(ctx, authRequest.SelectedIDPConfigID, user.ExternalUserID) + logging.WithFields("authReq", authRequest.ID, "idpID", authRequest.SelectedIDPConfigID).OnError(err). + Warn("could not determine resource owner for runPostExternalAuthenticationActions, fall back to default org id") + } + // fallback to default org id if resourceOwner == "" { resourceOwner = authz.GetInstance(ctx).DefaultOrganisationID() } @@ -394,3 +410,25 @@ func tokenCtxFields(tokens *oidc.Tokens[*oidc.IDTokenClaims]) []actions.FieldOpt actions.SetFields("claimsJSON", claimsJSON), } } + +func (l *Login) resourceOwnerOfUserIDPLink(ctx context.Context, idpConfigID string, externalUserID string) (string, error) { + idQuery, err := query.NewIDPUserLinkIDPIDSearchQuery(idpConfigID) + if err != nil { + return "", err + } + externalIDQuery, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID) + if err != nil { + return "", err + } + queries := []query.SearchQuery{ + idQuery, externalIDQuery, + } + links, err := l.query.IDPUserLinks(ctx, &query.IDPUserLinksSearchQuery{Queries: queries}, false) + if err != nil { + return "", err + } + if len(links.Links) != 1 { + return "", nil + } + return links.Links[0].ResourceOwner, nil +}