From 1305c14e4910942380176bab89eb18cff580c619 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 19 Apr 2022 08:26:12 +0200 Subject: [PATCH] feat: handle instanceID in projections (#3442) * feat: handle instanceID in projections * rename functions * fix key lock * fix import --- cmd/admin/initialise/sql/09_events_table.sql | 2 +- cmd/admin/setup/01_sql/adminapi.sql | 11 +- cmd/admin/setup/01_sql/auth.sql | 47 +-- cmd/admin/setup/01_sql/authz.sql | 9 +- cmd/admin/setup/01_sql/notification.sql | 9 +- cmd/admin/setup/01_sql/projections.sql | 8 +- cmd/admin/setup/setup.go | 2 +- internal/admin/repository/administrator.go | 2 +- .../eventsourcing/eventstore/administrator.go | 13 +- .../eventsourcing/handler/styling.go | 30 +- .../repository/eventsourcing/spooler/lock.go | 7 +- .../eventsourcing/view/error_event.go | 4 +- .../repository/eventsourcing/view/sequence.go | 26 +- .../repository/eventsourcing/view/styling.go | 16 +- internal/api/oidc/key.go | 6 +- .../eventsourcing/eventstore/auth_request.go | 36 +- .../eventstore/auth_request_test.go | 18 +- .../eventsourcing/eventstore/org.go | 2 +- .../eventsourcing/eventstore/refresh_token.go | 18 +- .../eventsourcing/eventstore/token.go | 16 +- .../eventsourcing/eventstore/user.go | 5 +- .../eventsourcing/eventstore/user_session.go | 2 +- .../eventsourcing/handler/idp_config.go | 31 +- .../eventsourcing/handler/idp_providers.go | 33 +- .../handler/org_project_mapping.go | 31 +- .../eventsourcing/handler/refresh_token.go | 34 +- .../repository/eventsourcing/handler/token.go | 44 ++- .../repository/eventsourcing/handler/user.go | 36 +- .../handler/user_external_idps.go | 32 +- .../eventsourcing/handler/user_session.go | 35 +- .../repository/eventsourcing/spooler/lock.go | 4 +- .../eventsourcing/view/error_event.go | 4 +- .../eventsourcing/view/external_idps.go | 32 +- .../eventsourcing/view/idp_configs.go | 22 +- .../eventsourcing/view/idp_providers.go | 32 +- .../eventsourcing/view/org_project_mapping.go | 28 +- .../eventsourcing/view/refresh_token.go | 30 +- .../repository/eventsourcing/view/sequence.go | 22 +- .../repository/eventsourcing/view/token.go | 38 ++- .../repository/eventsourcing/view/user.go | 48 +-- .../eventsourcing/view/user_session.go | 28 +- .../eventstore/token_verifier.go | 9 +- .../eventsourcing/handler/user_membership.go | 58 ++-- .../repository/eventsourcing/spooler/lock.go | 7 +- .../eventsourcing/view/error_event.go | 4 +- .../repository/eventsourcing/view/sequence.go | 22 +- .../repository/eventsourcing/view/token.go | 20 +- .../eventsourcing/view/user_membership.go | 40 ++- .../instance_policy_password_age_test.go | 3 +- .../handler/crdb/current_sequence.go | 29 +- .../eventstore/handler/crdb/db_mock_test.go | 116 +++---- .../eventstore/handler/crdb/handler_stmt.go | 87 +++-- .../handler/crdb/handler_stmt_test.go | 154 ++++++--- internal/eventstore/handler/crdb/lock.go | 26 +- internal/eventstore/handler/crdb/lock_test.go | 43 ++- .../eventstore/handler/handler_projection.go | 10 +- .../handler/handler_projection_test.go | 4 +- .../eventstore/repository/search_query.go | 2 + internal/eventstore/repository/sql/crdb.go | 7 +- internal/eventstore/search_query.go | 39 ++- internal/eventstore/v1/eventstore.go | 8 - .../internal/repository/sql/db_mock_test.go | 16 +- .../v1/internal/repository/sql/filter_test.go | 18 +- .../v1/internal/repository/sql/query.go | 35 +- .../v1/internal/repository/sql/query_test.go | 42 ++- internal/eventstore/v1/locker/lock.go | 17 +- internal/eventstore/v1/locker/lock_test.go | 25 +- .../eventstore/v1/mock/eventstore.mock.go | 90 +---- .../eventstore/v1/models/aggregate_test.go | 4 +- internal/eventstore/v1/models/operation.go | 1 + internal/eventstore/v1/models/search_query.go | 302 +++++++++------- .../eventstore/v1/models/search_query_old.go | 63 +++- .../eventstore/v1/models/search_query_test.go | 323 ++++++++++++------ internal/eventstore/v1/query/handler.go | 9 +- .../eventstore/v1/spooler/mock/spooler.go | 23 +- internal/eventstore/v1/spooler/spooler.go | 24 +- .../eventstore/v1/spooler/spooler_test.go | 24 +- internal/iam/model/idp_provider_view.go | 1 + .../iam/repository/view/idp_provider_view.go | 28 +- internal/iam/repository/view/idp_view.go | 21 +- .../iam/repository/view/model/idp_config.go | 2 +- .../iam/repository/view/model/idp_provider.go | 3 +- .../view/model/idp_provider_query.go | 2 + .../iam/repository/view/model/label_policy.go | 2 +- internal/iam/repository/view/styling.go | 5 +- .../eventsourcing/handler/notification.go | 56 +-- .../eventsourcing/handler/notify_user.go | 36 +- .../repository/eventsourcing/spooler/lock.go | 7 +- .../eventsourcing/view/error_event.go | 4 +- .../eventsourcing/view/notification.go | 12 +- .../eventsourcing/view/notify_user.go | 24 +- .../repository/eventsourcing/view/sequence.go | 22 +- internal/org/repository/view/query.go | 8 +- .../view/org_project_mapping_view.go | 24 +- internal/project/repository/view/query.go | 11 +- internal/query/current_sequence.go | 6 + internal/query/projection/login_name.go | 2 +- internal/user/model/external_idp_view.go | 1 + internal/user/model/notify_user.go | 4 +- internal/user/model/user_session_view.go | 1 + .../user/repository/view/external_idp_view.go | 37 +- .../view/model/external_idp_query.go | 2 + .../repository/view/model/external_idps.go | 3 +- .../user/repository/view/model/notify_user.go | 3 +- .../view/model/notify_user_query.go | 2 + .../repository/view/model/refresh_token.go | 2 +- internal/user/repository/view/model/token.go | 2 +- internal/user/repository/view/model/user.go | 2 +- .../repository/view/model/user_membership.go | 2 +- .../repository/view/model/user_session.go | 3 +- .../view/model/user_session_query.go | 2 + internal/user/repository/view/notify_user.go | 26 +- internal/user/repository/view/query.go | 13 +- .../repository/view/refresh_token_view.go | 30 +- internal/user/repository/view/token_view.go | 57 +++- .../user/repository/view/user_session_view.go | 34 +- internal/user/repository/view/user_view.go | 82 ++++- .../repository/view/usermembership_view.go | 39 ++- internal/view/repository/failed_events.go | 18 +- internal/view/repository/sequence.go | 89 ++++- 120 files changed, 2078 insertions(+), 1209 deletions(-) diff --git a/cmd/admin/initialise/sql/09_events_table.sql b/cmd/admin/initialise/sql/09_events_table.sql index 0e672ca34b..f2df84f11d 100644 --- a/cmd/admin/initialise/sql/09_events_table.sql +++ b/cmd/admin/initialise/sql/09_events_table.sql @@ -17,7 +17,7 @@ CREATE TABLE eventstore.events ( , PRIMARY KEY (event_sequence DESC, instance_id) USING HASH WITH BUCKET_COUNT = 10 , INDEX agg_type_agg_id (aggregate_type, aggregate_id, instance_id) , INDEX agg_type (aggregate_type, instance_id) - , INDEX agg_type_seq (aggregate_type, event_sequence DESC, instance_id) + , INDEX agg_type_seq (aggregate_type, event_sequence DESC, instance_id) STORING (id, event_type, aggregate_id, aggregate_version, previous_aggregate_sequence, creation_date, event_data, editor_user, editor_service, resource_owner, previous_aggregate_type_sequence) , INDEX max_sequence (aggregate_type, aggregate_id, event_sequence DESC, instance_id) , CONSTRAINT previous_sequence_unique UNIQUE (previous_aggregate_sequence DESC, instance_id) diff --git a/cmd/admin/setup/01_sql/adminapi.sql b/cmd/admin/setup/01_sql/adminapi.sql index db5c3f37ba..57632660eb 100644 --- a/cmd/admin/setup/01_sql/adminapi.sql +++ b/cmd/admin/setup/01_sql/adminapi.sql @@ -4,8 +4,9 @@ CREATE TABLE adminapi.locks ( locker_id TEXT, locked_until TIMESTAMPTZ(3), view_name TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name) + PRIMARY KEY (view_name, instance_id) ); CREATE TABLE adminapi.current_sequences ( @@ -13,8 +14,9 @@ CREATE TABLE adminapi.current_sequences ( current_sequence BIGINT, event_timestamp TIMESTAMPTZ, last_successful_spooler_run TIMESTAMPTZ, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name) + PRIMARY KEY (view_name, instance_id) ); CREATE TABLE adminapi.failed_events ( @@ -22,8 +24,9 @@ CREATE TABLE adminapi.failed_events ( failed_sequence BIGINT, failure_count SMALLINT, err_msg TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name, failed_sequence) + PRIMARY KEY (view_name, failed_sequence, instance_id) ); CREATE TABLE adminapi.styling ( @@ -50,5 +53,5 @@ CREATE TABLE adminapi.styling ( hide_login_name_suffix BOOL NULL, instance_id STRING NOT NULL, - PRIMARY KEY (aggregate_id, label_policy_state) + PRIMARY KEY (aggregate_id, label_policy_state, instance_id) ); diff --git a/cmd/admin/setup/01_sql/auth.sql b/cmd/admin/setup/01_sql/auth.sql index 2f0e62d502..6392eab15c 100644 --- a/cmd/admin/setup/01_sql/auth.sql +++ b/cmd/admin/setup/01_sql/auth.sql @@ -4,8 +4,9 @@ CREATE TABLE auth.locks ( locker_id TEXT, locked_until TIMESTAMPTZ(3), view_name TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name) + PRIMARY KEY (view_name, instance_id) ); CREATE TABLE auth.current_sequences ( @@ -13,8 +14,9 @@ CREATE TABLE auth.current_sequences ( current_sequence BIGINT, event_timestamp TIMESTAMPTZ, last_successful_spooler_run TIMESTAMPTZ, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name) + PRIMARY KEY (view_name, instance_id) ); CREATE TABLE auth.failed_events ( @@ -22,8 +24,9 @@ CREATE TABLE auth.failed_events ( failed_sequence BIGINT, failure_count SMALLINT, err_msg TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name, failed_sequence) + PRIMARY KEY (view_name, failed_sequence, instance_id) ); CREATE TABLE auth.users ( @@ -68,9 +71,9 @@ CREATE TABLE auth.users ( avatar_key STRING NULL, passwordless_init_required BOOL NULL, password_init_required BOOL NULL, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (id) + PRIMARY KEY (id, instance_id) ); CREATE TABLE auth.user_sessions ( @@ -93,9 +96,9 @@ CREATE TABLE auth.user_sessions ( selected_idp_config_id STRING NULL, passwordless_verification TIMESTAMPTZ NULL, avatar_key STRING NULL, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (user_agent_id, user_id) + PRIMARY KEY (user_agent_id, user_id, instance_id) ); CREATE TABLE auth.user_external_idps ( @@ -108,9 +111,9 @@ CREATE TABLE auth.user_external_idps ( change_date TIMESTAMPTZ NULL, sequence INT8 NULL, resource_owner STRING NULL, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (external_user_id, idp_config_id) + PRIMARY KEY (external_user_id, idp_config_id, instance_id) ); CREATE TABLE auth.tokens ( @@ -128,9 +131,9 @@ CREATE TABLE auth.tokens ( preferred_language STRING NULL, refresh_token_id STRING NULL, is_pat BOOL NOT NULL DEFAULT false, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (id), + PRIMARY KEY (id, instance_id), INDEX user_user_agent_idx (user_id, user_agent_id) ); @@ -150,19 +153,19 @@ CREATE TABLE auth.refresh_tokens ( scopes STRING[] NULL, audience STRING[] NULL, amr STRING[] NULL, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (id), - UNIQUE INDEX unique_client_user_index (client_id ASC, user_agent_id ASC, user_id ASC) + PRIMARY KEY (id, instance_id), + UNIQUE INDEX unique_client_user_index (client_id ASC, user_agent_id ASC, user_id ASC, instance_id) ); CREATE TABLE auth.org_project_mapping ( org_id STRING NOT NULL, project_id STRING NOT NULL, project_grant_id STRING NULL, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (org_id, project_id) + PRIMARY KEY (org_id, project_id, instance_id) ); CREATE TABLE auth.idp_providers ( @@ -176,9 +179,9 @@ CREATE TABLE auth.idp_providers ( idp_provider_type INT2 NULL, idp_state INT2 NULL, styling_type INT2 NULL, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (aggregate_id, idp_config_id) + PRIMARY KEY (aggregate_id, idp_config_id, instance_id) ); CREATE TABLE auth.idp_configs ( @@ -204,9 +207,9 @@ CREATE TABLE auth.idp_configs ( jwt_endpoint STRING NULL, jwt_keys_endpoint STRING NULL, jwt_header_name STRING NULL, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (idp_config_id) + PRIMARY KEY (idp_config_id, instance_id) ); CREATE TABLE auth.auth_requests ( @@ -216,8 +219,8 @@ CREATE TABLE auth.auth_requests ( request_type INT2 NULL, creation_date TIMESTAMPTZ NULL, change_date TIMESTAMPTZ NULL, - instance_id STRING NULL, + instance_id STRING NOT NULL, - PRIMARY KEY (id), + PRIMARY KEY (id, instance_id), INDEX auth_code_idx (code) ); diff --git a/cmd/admin/setup/01_sql/authz.sql b/cmd/admin/setup/01_sql/authz.sql index 78bb3c81d1..ae7007cb47 100644 --- a/cmd/admin/setup/01_sql/authz.sql +++ b/cmd/admin/setup/01_sql/authz.sql @@ -4,8 +4,9 @@ CREATE TABLE authz.locks ( locker_id TEXT, locked_until TIMESTAMPTZ(3), view_name TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name) + PRIMARY KEY (view_name, instance_id) ); CREATE TABLE authz.current_sequences ( @@ -13,8 +14,9 @@ CREATE TABLE authz.current_sequences ( current_sequence BIGINT, event_timestamp TIMESTAMPTZ, last_successful_spooler_run TIMESTAMPTZ, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name) + PRIMARY KEY (view_name, instance_id) ); CREATE TABLE authz.failed_events ( @@ -22,8 +24,9 @@ CREATE TABLE authz.failed_events ( failed_sequence BIGINT, failure_count SMALLINT, err_msg TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name, failed_sequence) + PRIMARY KEY (view_name, failed_sequence, instance_id) ); CREATE TABLE authz.user_memberships ( diff --git a/cmd/admin/setup/01_sql/notification.sql b/cmd/admin/setup/01_sql/notification.sql index 6e3acf9b9e..b117a01ad9 100644 --- a/cmd/admin/setup/01_sql/notification.sql +++ b/cmd/admin/setup/01_sql/notification.sql @@ -4,8 +4,9 @@ CREATE TABLE notification.locks ( locker_id TEXT, locked_until TIMESTAMPTZ(3), view_name TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name) + PRIMARY KEY (view_name, instance_id) ); CREATE TABLE notification.current_sequences ( @@ -13,8 +14,9 @@ CREATE TABLE notification.current_sequences ( current_sequence BIGINT, event_timestamp TIMESTAMPTZ, last_successful_spooler_run TIMESTAMPTZ, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name) + PRIMARY KEY (view_name, instance_id) ); CREATE TABLE notification.failed_events ( @@ -22,8 +24,9 @@ CREATE TABLE notification.failed_events ( failed_sequence BIGINT, failure_count SMALLINT, err_msg TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (view_name, failed_sequence) + PRIMARY KEY (view_name, failed_sequence, instance_id) ); CREATE TABLE notification.notify_users ( diff --git a/cmd/admin/setup/01_sql/projections.sql b/cmd/admin/setup/01_sql/projections.sql index 03678a0ae8..2e3837eb06 100644 --- a/cmd/admin/setup/01_sql/projections.sql +++ b/cmd/admin/setup/01_sql/projections.sql @@ -2,17 +2,19 @@ CREATE TABLE projections.locks ( locker_id TEXT, locked_until TIMESTAMPTZ(3), projection_name TEXT, + instance_id TEXT NOT NULL, - PRIMARY KEY (projection_name) + PRIMARY KEY (projection_name, instance_id) ); CREATE TABLE projections.current_sequences ( projection_name TEXT, aggregate_type TEXT, current_sequence BIGINT, + instance_id TEXT NOT NULL, timestamp TIMESTAMPTZ, - PRIMARY KEY (projection_name, aggregate_type) + PRIMARY KEY (projection_name, aggregate_type, instance_id) ); CREATE TABLE projections.failed_events ( @@ -20,7 +22,7 @@ CREATE TABLE projections.failed_events ( failed_sequence BIGINT, failure_count SMALLINT, error TEXT, - instance_id TEXT, + instance_id TEXT NOT NULL, PRIMARY KEY (projection_name, failed_sequence, instance_id) ); diff --git a/cmd/admin/setup/setup.go b/cmd/admin/setup/setup.go index 5745ca8b34..6604958985 100644 --- a/cmd/admin/setup/setup.go +++ b/cmd/admin/setup/setup.go @@ -60,7 +60,7 @@ func Setup(config *Config, steps *Steps, masterKey string) { steps.S3DefaultInstance.db = dbClient steps.S3DefaultInstance.defaults = config.SystemDefaults steps.S3DefaultInstance.masterKey = masterKey - steps.S3DefaultInstance.domain = config.SystemDefaults.Domain + steps.S3DefaultInstance.domain = config.ExternalDomain steps.S3DefaultInstance.zitadelRoles = config.InternalAuthZ.RolePermissionMappings steps.S3DefaultInstance.userEncryptionKey = config.EncryptionKeys.User steps.S3DefaultInstance.InstanceSetup.Zitadel.IsDevMode = !config.ExternalSecure diff --git a/internal/admin/repository/administrator.go b/internal/admin/repository/administrator.go index 026a293c9e..1f776d3508 100644 --- a/internal/admin/repository/administrator.go +++ b/internal/admin/repository/administrator.go @@ -2,6 +2,7 @@ package repository import ( "context" + "github.com/caos/zitadel/internal/view/model" ) @@ -9,6 +10,5 @@ type AdministratorRepository interface { GetFailedEvents(context.Context) ([]*model.FailedEvent, error) RemoveFailedEvent(context.Context, *model.FailedEvent) error GetViews() ([]*model.View, error) - GetSpoolerDiv(db, viewName string) int64 ClearView(ctx context.Context, db, viewName string) error } diff --git a/internal/admin/repository/eventsourcing/eventstore/administrator.go b/internal/admin/repository/eventsourcing/eventstore/administrator.go index a80b42574d..1056cc0d93 100644 --- a/internal/admin/repository/eventsourcing/eventstore/administrator.go +++ b/internal/admin/repository/eventsourcing/eventstore/administrator.go @@ -2,14 +2,13 @@ package eventstore import ( "context" - "time" "github.com/caos/zitadel/internal/admin/repository/eventsourcing/view" view_model "github.com/caos/zitadel/internal/view/model" "github.com/caos/zitadel/internal/view/repository" ) -var dbList = []string{"management", "auth", "authz", "adminapi", "notification"} +var dbList = []string{"auth", "authz", "adminapi", "notification"} type AdministratorRepo struct { View *view.View @@ -47,16 +46,6 @@ func (repo *AdministratorRepo) GetViews() ([]*view_model.View, error) { return views, nil } -func (repo *AdministratorRepo) GetSpoolerDiv(database, view string) int64 { - sequence, err := repo.View.GetCurrentSequence(database, view) - if err != nil { - - return 0 - } - divDuration := time.Now().Sub(sequence.LastSuccessfulSpoolerRun) - return divDuration.Milliseconds() -} - func (repo *AdministratorRepo) ClearView(ctx context.Context, database, view string) error { return repo.View.ClearView(database, view) } diff --git a/internal/admin/repository/eventsourcing/handler/styling.go b/internal/admin/repository/eventsourcing/handler/styling.go index 75839570a2..e5d8eef4b0 100644 --- a/internal/admin/repository/eventsourcing/handler/styling.go +++ b/internal/admin/repository/eventsourcing/handler/styling.go @@ -67,8 +67,8 @@ func (_ *Styling) AggregateTypes() []models.AggregateType { return []models.AggregateType{org.AggregateType, instance.AggregateType} } -func (m *Styling) CurrentSequence() (uint64, error) { - sequence, err := m.view.GetLatestStylingSequence() +func (m *Styling) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := m.view.GetLatestStylingSequence(instanceID) if err != nil { return 0, err } @@ -76,13 +76,29 @@ func (m *Styling) CurrentSequence() (uint64, error) { } func (m *Styling) EventQuery() (*models.SearchQuery, error) { - sequence, err := m.view.GetLatestStylingSequence() + sequences, err := m.view.GetLatestStylingSequences() if err != nil { return nil, err } - return models.NewSearchQuery(). + query := models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(m.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). AggregateTypeFilter(m.AggregateTypes()...). - LatestSequenceFilter(sequence.CurrentSequence), nil + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (m *Styling) Reduce(event *models.Event) (err error) { @@ -123,7 +139,7 @@ func (m *Styling) processLabelPolicy(event *models.Event) (err error) { org.LabelPolicyFontRemovedEventType, instance.LabelPolicyAssetsRemovedEventType, org.LabelPolicyAssetsRemovedEventType: - policy, err = m.view.StylingByAggregateIDAndState(event.AggregateID, int32(domain.LabelPolicyStatePreview)) + policy, err = m.view.StylingByAggregateIDAndState(event.AggregateID, event.InstanceID, int32(domain.LabelPolicyStatePreview)) if err != nil { return err } @@ -131,7 +147,7 @@ func (m *Styling) processLabelPolicy(event *models.Event) (err error) { case instance.LabelPolicyActivatedEventType, org.LabelPolicyActivatedEventType: - policy, err = m.view.StylingByAggregateIDAndState(event.AggregateID, int32(domain.LabelPolicyStatePreview)) + policy, err = m.view.StylingByAggregateIDAndState(event.AggregateID, event.InstanceID, int32(domain.LabelPolicyStatePreview)) if err != nil { return err } diff --git a/internal/admin/repository/eventsourcing/spooler/lock.go b/internal/admin/repository/eventsourcing/spooler/lock.go index f1440f79af..812643a503 100644 --- a/internal/admin/repository/eventsourcing/spooler/lock.go +++ b/internal/admin/repository/eventsourcing/spooler/lock.go @@ -2,8 +2,9 @@ package spooler import ( "database/sql" - es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker" "time" + + es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker" ) const ( @@ -14,6 +15,6 @@ type locker struct { dbClient *sql.DB } -func (l *locker) Renew(lockerID, viewModel string, waitTime time.Duration) error { - return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, waitTime) +func (l *locker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error { + return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, instanceID, waitTime) } diff --git a/internal/admin/repository/eventsourcing/view/error_event.go b/internal/admin/repository/eventsourcing/view/error_event.go index 2cb7bd6754..4cc375b22f 100644 --- a/internal/admin/repository/eventsourcing/view/error_event.go +++ b/internal/admin/repository/eventsourcing/view/error_event.go @@ -17,8 +17,8 @@ func (v *View) RemoveFailedEvent(database string, failedEvent *repository.Failed return repository.RemoveFailedEvent(v.Db, database+"."+errColumn, failedEvent) } -func (v *View) latestFailedEvent(viewName string, sequence uint64) (*repository.FailedEvent, error) { - return repository.LatestFailedEvent(v.Db, errTable, viewName, sequence) +func (v *View) latestFailedEvent(viewName, instanceID string, sequence uint64) (*repository.FailedEvent, error) { + return repository.LatestFailedEvent(v.Db, errTable, viewName, instanceID, sequence) } func (v *View) AllFailedEvents(db string) ([]*repository.FailedEvent, error) { diff --git a/internal/admin/repository/eventsourcing/view/sequence.go b/internal/admin/repository/eventsourcing/view/sequence.go index b73461152b..02e8519e9c 100644 --- a/internal/admin/repository/eventsourcing/view/sequence.go +++ b/internal/admin/repository/eventsourcing/view/sequence.go @@ -12,11 +12,15 @@ const ( ) func (v *View) saveCurrentSequence(viewName string, event *models.Event) error { - return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.Sequence, event.CreationDate) + return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate) } -func (v *View) latestSequence(viewName string) (*repository.CurrentSequence, error) { - return repository.LatestSequence(v.Db, sequencesTable, viewName) +func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) { + return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID) +} + +func (v *View) latestSequences(viewName string) ([]*repository.CurrentSequence, error) { + return repository.LatestSequences(v.Db, sequencesTable, viewName) } func (v *View) AllCurrentSequences(db string) ([]*repository.CurrentSequence, error) { @@ -24,21 +28,23 @@ func (v *View) AllCurrentSequences(db string) ([]*repository.CurrentSequence, er } func (v *View) updateSpoolerRunSequence(viewName string) error { - currentSequence, err := repository.LatestSequence(v.Db, sequencesTable, viewName) + currentSequences, err := repository.LatestSequences(v.Db, sequencesTable, viewName) if err != nil { return err } - if currentSequence.ViewName == "" { - currentSequence.ViewName = viewName + for _, currentSequence := range currentSequences { + if currentSequence.ViewName == "" { + currentSequence.ViewName = viewName + } + currentSequence.LastSuccessfulSpoolerRun = time.Now() } - currentSequence.LastSuccessfulSpoolerRun = time.Now() - return repository.UpdateCurrentSequence(v.Db, sequencesTable, currentSequence) + return repository.UpdateCurrentSequences(v.Db, sequencesTable, currentSequences) } -func (v *View) GetCurrentSequence(db, viewName string) (*repository.CurrentSequence, error) { +func (v *View) GetCurrentSequence(db, viewName string) ([]*repository.CurrentSequence, error) { sequenceTable := db + ".current_sequences" fullView := db + "." + viewName - return repository.LatestSequence(v.Db, sequenceTable, fullView) + return repository.LatestSequences(v.Db, sequenceTable, fullView) } func (v *View) ClearView(db, viewName string) error { diff --git a/internal/admin/repository/eventsourcing/view/styling.go b/internal/admin/repository/eventsourcing/view/styling.go index 7d78172a25..811837bc48 100644 --- a/internal/admin/repository/eventsourcing/view/styling.go +++ b/internal/admin/repository/eventsourcing/view/styling.go @@ -11,8 +11,8 @@ const ( stylingTyble = "adminapi.styling" ) -func (v *View) StylingByAggregateIDAndState(aggregateID string, state int32) (*model.LabelPolicyView, error) { - return view.GetStylingByAggregateIDAndState(v.Db, stylingTyble, aggregateID, state) +func (v *View) StylingByAggregateIDAndState(aggregateID, instanceID string, state int32) (*model.LabelPolicyView, error) { + return view.GetStylingByAggregateIDAndState(v.Db, stylingTyble, aggregateID, instanceID, state) } func (v *View) PutStyling(policy *model.LabelPolicyView, event *models.Event) error { @@ -23,8 +23,12 @@ func (v *View) PutStyling(policy *model.LabelPolicyView, event *models.Event) er return v.ProcessedStylingSequence(event) } -func (v *View) GetLatestStylingSequence() (*global_view.CurrentSequence, error) { - return v.latestSequence(stylingTyble) +func (v *View) GetLatestStylingSequence(instanceID string) (*global_view.CurrentSequence, error) { + return v.latestSequence(stylingTyble, instanceID) +} + +func (v *View) GetLatestStylingSequences() ([]*global_view.CurrentSequence, error) { + return v.latestSequences(stylingTyble) } func (v *View) ProcessedStylingSequence(event *models.Event) error { @@ -35,8 +39,8 @@ func (v *View) UpdateStylingSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(stylingTyble) } -func (v *View) GetLatestStylingFailedEvent(sequence uint64) (*global_view.FailedEvent, error) { - return v.latestFailedEvent(stylingTyble, sequence) +func (v *View) GetLatestStylingFailedEvent(sequence uint64, instanceID string) (*global_view.FailedEvent, error) { + return v.latestFailedEvent(stylingTyble, instanceID, sequence) } func (v *View) ProcessedStylingFailedEvent(failedEvent *global_view.FailedEvent) error { diff --git a/internal/api/oidc/key.go b/internal/api/oidc/key.go index 6e04c21985..3ed2a875dd 100644 --- a/internal/api/oidc/key.go +++ b/internal/api/oidc/key.go @@ -8,13 +8,13 @@ import ( "github.com/caos/logging" "gopkg.in/square/go-jose.v2" - "github.com/caos/zitadel/internal/telemetry/tracing" - + "github.com/caos/zitadel/internal/api/authz" "github.com/caos/zitadel/internal/crypto" "github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/eventstore" "github.com/caos/zitadel/internal/query" "github.com/caos/zitadel/internal/repository/keypair" + "github.com/caos/zitadel/internal/telemetry/tracing" ) const ( @@ -154,7 +154,7 @@ func (o *OPStorage) lockAndGenerateSigningKeyPair(ctx context.Context, algorithm ctx, cancel := context.WithCancel(ctx) defer cancel() - errs := o.locker.Lock(ctx, o.signingKeyRotationCheck*2) + errs := o.locker.Lock(ctx, o.signingKeyRotationCheck*2, authz.GetInstance(ctx).InstanceID()) err, ok := <-errs if err != nil || !ok { if errors.IsErrorAlreadyExists(err) { diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index d9e29a03c8..f942fd9cb5 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -61,12 +61,12 @@ type privacyPolicyProvider interface { } type userSessionViewProvider interface { - UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error) - UserSessionsByAgentID(string) ([]*user_view_model.UserSessionView, error) + UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) + UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error) PrefixAvatarURL() string } type userViewProvider interface { - UserByID(string) (*user_view_model.UserView, error) + UserByID(string, string) (*user_view_model.UserView, error) PrefixAvatarURL() string } @@ -79,7 +79,7 @@ type lockoutPolicyViewProvider interface { } type idpProviderViewProvider interface { - IDPProvidersByAggregateIDAndState(string, iam_model.IDPConfigState) ([]*iam_view_model.IDPProviderView, error) + IDPProvidersByAggregateIDAndState(string, string, iam_model.IDPConfigState) ([]*iam_view_model.IDPProviderView, error) } type userEventProvider interface { @@ -102,7 +102,7 @@ type userGrantProvider interface { type projectProvider interface { ProjectByOIDCClientID(context.Context, string) (*query.Project, error) - OrgProjectMappingByIDs(orgID, projectID string) (*project_view_model.OrgProjectMapping, error) + OrgProjectMappingByIDs(orgID, projectID, instanceID string) (*project_view_model.OrgProjectMapping, error) } type applicationProvider interface { @@ -596,7 +596,7 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A } func (repo *AuthRequestRepo) tryUsingOnlyUserSession(request *domain.AuthRequest) error { - userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID) + userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID, request.InstanceID) if err != nil { return err } @@ -618,9 +618,9 @@ func (repo *AuthRequestRepo) checkLoginName(ctx context.Context, request *domain if request.RequestedOrgID != "" { preferredLoginName += "@" + request.RequestedPrimaryDomain } - user, err = repo.View.UserByLoginNameAndResourceOwner(preferredLoginName, request.RequestedOrgID) + user, err = repo.View.UserByLoginNameAndResourceOwner(preferredLoginName, request.RequestedOrgID, request.InstanceID) } else { - user, err = repo.View.UserByLoginName(loginName) + user, err = repo.View.UserByLoginName(loginName, request.InstanceID) if err == nil { err = repo.checkLoginPolicyWithResourceOwner(ctx, request, user) if err != nil { @@ -696,9 +696,9 @@ func (repo *AuthRequestRepo) checkSelectedExternalIDP(request *domain.AuthReques func (repo *AuthRequestRepo) checkExternalUserLogin(ctx context.Context, request *domain.AuthRequest, idpConfigID, externalUserID string) (err error) { externalIDP := new(user_view_model.ExternalIDPView) if request.RequestedOrgID != "" { - externalIDP, err = repo.View.ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(externalUserID, idpConfigID, request.RequestedOrgID) + externalIDP, err = repo.View.ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(externalUserID, idpConfigID, request.RequestedOrgID, request.InstanceID) } else { - externalIDP, err = repo.View.ExternalIDPByExternalUserIDAndIDPConfigID(externalUserID, idpConfigID) + externalIDP, err = repo.View.ExternalIDPByExternalUserIDAndIDPConfigID(externalUserID, idpConfigID, request.InstanceID) } if err != nil { return err @@ -828,7 +828,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth } func (repo *AuthRequestRepo) usersForUserSelection(request *domain.AuthRequest) ([]domain.UserSelection, error) { - userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID) + userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID, request.InstanceID) if err != nil { return nil, err } @@ -1044,13 +1044,13 @@ func setOrgID(orgViewProvider orgViewProvider, request *domain.AuthRequest) erro func getLoginPolicyIDPProviders(provider idpProviderViewProvider, iamID, orgID string, defaultPolicy bool) ([]*iam_model.IDPProviderView, error) { if defaultPolicy { - idpProviders, err := provider.IDPProvidersByAggregateIDAndState(iamID, iam_model.IDPConfigStateActive) + idpProviders, err := provider.IDPProvidersByAggregateIDAndState(iamID, iamID, iam_model.IDPConfigStateActive) if err != nil { return nil, err } return iam_view_model.IDPProviderViewsToModel(idpProviders), nil } - idpProviders, err := provider.IDPProvidersByAggregateIDAndState(orgID, iam_model.IDPConfigStateActive) + idpProviders, err := provider.IDPProvidersByAggregateIDAndState(orgID, iamID, iam_model.IDPConfigStateActive) if err != nil { return nil, err } @@ -1071,8 +1071,8 @@ func checkVerificationTime(verificationTime time.Time, lifetime time.Duration) b return verificationTime.Add(lifetime).After(time.Now().UTC()) } -func userSessionsByUserAgentID(provider userSessionViewProvider, agentID string) ([]*user_model.UserSessionView, error) { - session, err := provider.UserSessionsByAgentID(agentID) +func userSessionsByUserAgentID(provider userSessionViewProvider, agentID, instanceID string) ([]*user_model.UserSessionView, error) { + session, err := provider.UserSessionsByAgentID(agentID, instanceID) if err != nil { return nil, err } @@ -1080,7 +1080,7 @@ func userSessionsByUserAgentID(provider userSessionViewProvider, agentID string) } func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) { - session, err := provider.UserSessionByIDs(agentID, user.ID) + session, err := provider.UserSessionByIDs(agentID, user.ID, authz.GetInstance(ctx).InstanceID()) if err != nil { if !errors.IsNotFound(err) { return nil, err @@ -1156,7 +1156,7 @@ func activeUserByID(ctx context.Context, userViewProvider userViewProvider, user } func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider userEventProvider, userID string) (*user_model.UserView, error) { - user, viewErr := viewProvider.UserByID(userID) + user, viewErr := viewProvider.UserByID(userID, authz.GetInstance(ctx).InstanceID()) if viewErr != nil && !errors.IsNotFound(viewErr) { return nil, viewErr } else if user == nil { @@ -1254,7 +1254,7 @@ func projectRequired(ctx context.Context, request *domain.AuthRequest, projectPr if !project.HasProjectCheck { return false, nil } - _, err = projectProvider.OrgProjectMappingByIDs(request.UserOrgID, project.ID) + _, err = projectProvider.OrgProjectMappingByIDs(request.UserOrgID, project.ID, request.InstanceID) if errors.IsNotFound(err) { return true, nil } diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index 8d704f9e96..1f0db04f53 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -24,11 +24,11 @@ import ( type mockViewNoUserSession struct{} -func (m *mockViewNoUserSession) UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error) { +func (m *mockViewNoUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) { return nil, errors.ThrowNotFound(nil, "id", "user session not found") } -func (m *mockViewNoUserSession) UserSessionsByAgentID(string) ([]*user_view_model.UserSessionView, error) { +func (m *mockViewNoUserSession) UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error) { return nil, nil } @@ -38,11 +38,11 @@ func (m *mockViewNoUserSession) PrefixAvatarURL() string { type mockViewErrUserSession struct{} -func (m *mockViewErrUserSession) UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error) { +func (m *mockViewErrUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) { return nil, errors.ThrowInternal(nil, "id", "internal error") } -func (m *mockViewErrUserSession) UserSessionsByAgentID(string) ([]*user_view_model.UserSessionView, error) { +func (m *mockViewErrUserSession) UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error) { return nil, errors.ThrowInternal(nil, "id", "internal error") } @@ -65,7 +65,7 @@ type mockUser struct { ResourceOwner string } -func (m *mockViewUserSession) UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error) { +func (m *mockViewUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) { return &user_view_model.UserSessionView{ ExternalLoginVerification: m.ExternalLoginVerification, PasswordlessVerification: m.PasswordlessVerification, @@ -75,7 +75,7 @@ func (m *mockViewUserSession) UserSessionByIDs(string, string) (*user_view_model }, nil } -func (m *mockViewUserSession) UserSessionsByAgentID(string) ([]*user_view_model.UserSessionView, error) { +func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error) { sessions := make([]*user_view_model.UserSessionView, len(m.Users)) for i, user := range m.Users { sessions[i] = &user_view_model.UserSessionView{ @@ -93,7 +93,7 @@ func (m *mockViewUserSession) PrefixAvatarURL() string { type mockViewNoUser struct{} -func (m *mockViewNoUser) UserByID(string) (*user_view_model.UserView, error) { +func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, error) { return nil, errors.ThrowNotFound(nil, "id", "user not found") } @@ -156,7 +156,7 @@ func (m *mockLockoutPolicy) LockoutPolicyByOrg(context.Context, string) (*query. return m.policy, nil } -func (m *mockViewUser) UserByID(string) (*user_view_model.UserView, error) { +func (m *mockViewUser) UserByID(string, string) (*user_view_model.UserView, error) { return &user_view_model.UserView{ State: int32(user_model.UserStateActive), UserName: "UserName", @@ -232,7 +232,7 @@ func (m *mockProject) ProjectByOIDCClientID(ctx context.Context, s string) (*que return &query.Project{HasProjectCheck: m.projectCheck}, nil } -func (m *mockProject) OrgProjectMappingByIDs(orgID, projectID string) (*proj_view_model.OrgProjectMapping, error) { +func (m *mockProject) OrgProjectMappingByIDs(orgID, projectID, instanceID string) (*proj_view_model.OrgProjectMapping, error) { if m.hasProject { return &proj_view_model.OrgProjectMapping{OrgID: orgID, ProjectID: projectID}, nil } diff --git a/internal/auth/repository/eventsourcing/eventstore/org.go b/internal/auth/repository/eventsourcing/eventstore/org.go index 543e61cac8..2e710d40eb 100644 --- a/internal/auth/repository/eventsourcing/eventstore/org.go +++ b/internal/auth/repository/eventsourcing/eventstore/org.go @@ -23,7 +23,7 @@ type OrgRepository struct { } func (repo *OrgRepository) GetIDPConfigByID(ctx context.Context, idpConfigID string) (*iam_model.IDPConfigView, error) { - idpConfig, err := repo.View.IDPConfigByID(idpConfigID) + idpConfig, err := repo.View.IDPConfigByID(idpConfigID, authz.GetInstance(ctx).InstanceID()) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/eventstore/refresh_token.go b/internal/auth/repository/eventsourcing/eventstore/refresh_token.go index 8572a68677..57f6b6b167 100644 --- a/internal/auth/repository/eventsourcing/eventstore/refresh_token.go +++ b/internal/auth/repository/eventsourcing/eventstore/refresh_token.go @@ -6,16 +6,16 @@ import ( "github.com/caos/logging" + "github.com/caos/zitadel/internal/api/authz" + "github.com/caos/zitadel/internal/auth/repository/eventsourcing/view" "github.com/caos/zitadel/internal/crypto" "github.com/caos/zitadel/internal/domain" + "github.com/caos/zitadel/internal/errors" v1 "github.com/caos/zitadel/internal/eventstore/v1" "github.com/caos/zitadel/internal/eventstore/v1/models" - usr_view "github.com/caos/zitadel/internal/user/repository/view" - - "github.com/caos/zitadel/internal/auth/repository/eventsourcing/view" - "github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/telemetry/tracing" usr_model "github.com/caos/zitadel/internal/user/model" + usr_view "github.com/caos/zitadel/internal/user/repository/view" "github.com/caos/zitadel/internal/user/repository/view/model" ) @@ -31,7 +31,7 @@ func (r *RefreshTokenRepo) RefreshTokenByID(ctx context.Context, refreshToken st if err != nil { return nil, err } - tokenView, viewErr := r.View.RefreshTokenByID(tokenID) + tokenView, viewErr := r.View.RefreshTokenByID(tokenID, authz.GetInstance(ctx).InstanceID()) if viewErr != nil && !errors.IsNotFound(viewErr) { return nil, viewErr } @@ -41,7 +41,7 @@ func (r *RefreshTokenRepo) RefreshTokenByID(ctx context.Context, refreshToken st tokenView.UserID = userID } - events, esErr := r.getUserEvents(ctx, userID, tokenView.Sequence) + events, esErr := r.getUserEvents(ctx, userID, tokenView.InstanceID, tokenView.Sequence) if errors.IsNotFound(viewErr) && len(events) == 0 { return nil, errors.ThrowNotFound(nil, "EVENT-BHB52", "Errors.User.RefreshToken.Invalid") } @@ -68,7 +68,7 @@ func (r *RefreshTokenRepo) SearchMyRefreshTokens(ctx context.Context, userID str if err != nil { return nil, err } - sequence, err := r.View.GetLatestRefreshTokenSequence() + sequence, err := r.View.GetLatestRefreshTokenSequence(authz.GetInstance(ctx).InstanceID()) logging.Log("EVENT-GBdn4").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Warn("could not read latest refresh token sequence") request.Queries = append(request.Queries, &usr_model.RefreshTokenSearchQuery{Key: usr_model.RefreshTokenSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID}) tokens, count, err := r.View.SearchRefreshTokens(request) @@ -85,8 +85,8 @@ func (r *RefreshTokenRepo) SearchMyRefreshTokens(ctx context.Context, userID str }, nil } -func (r *RefreshTokenRepo) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) { - query, err := usr_view.UserByIDQuery(userID, sequence) +func (r *RefreshTokenRepo) getUserEvents(ctx context.Context, userID, instanceID string, sequence uint64) ([]*models.Event, error) { + query, err := usr_view.UserByIDQuery(userID, instanceID, sequence) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/eventstore/token.go b/internal/auth/repository/eventsourcing/eventstore/token.go index 765903228c..e799613e6d 100644 --- a/internal/auth/repository/eventsourcing/eventstore/token.go +++ b/internal/auth/repository/eventsourcing/eventstore/token.go @@ -2,18 +2,18 @@ package eventstore import ( "context" - "github.com/caos/zitadel/internal/eventstore/v1" "time" - "github.com/caos/zitadel/internal/eventstore/v1/models" - usr_view "github.com/caos/zitadel/internal/user/repository/view" - "github.com/caos/logging" + "github.com/caos/zitadel/internal/api/authz" "github.com/caos/zitadel/internal/auth/repository/eventsourcing/view" "github.com/caos/zitadel/internal/errors" + v1 "github.com/caos/zitadel/internal/eventstore/v1" + "github.com/caos/zitadel/internal/eventstore/v1/models" "github.com/caos/zitadel/internal/telemetry/tracing" usr_model "github.com/caos/zitadel/internal/user/model" + usr_view "github.com/caos/zitadel/internal/user/repository/view" "github.com/caos/zitadel/internal/user/repository/view/model" ) @@ -34,7 +34,7 @@ func (repo *TokenRepo) IsTokenValid(ctx context.Context, userID, tokenID string) } func (repo *TokenRepo) TokenByID(ctx context.Context, userID, tokenID string) (*usr_model.TokenView, error) { - token, viewErr := repo.View.TokenByID(tokenID) + token, viewErr := repo.View.TokenByID(tokenID, authz.GetInstance(ctx).InstanceID()) if viewErr != nil && !errors.IsNotFound(viewErr) { return nil, viewErr } @@ -44,7 +44,7 @@ func (repo *TokenRepo) TokenByID(ctx context.Context, userID, tokenID string) (* token.UserID = userID } - events, esErr := repo.getUserEvents(ctx, userID, token.Sequence) + events, esErr := repo.getUserEvents(ctx, userID, token.InstanceID, token.Sequence) if errors.IsNotFound(viewErr) && len(events) == 0 { return nil, errors.ThrowNotFound(nil, "EVENT-4T90g", "Errors.Token.NotFound") } @@ -66,8 +66,8 @@ func (repo *TokenRepo) TokenByID(ctx context.Context, userID, tokenID string) (* return model.TokenViewToModel(token), nil } -func (r *TokenRepo) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) { - query, err := usr_view.UserByIDQuery(userID, sequence) +func (r *TokenRepo) getUserEvents(ctx context.Context, userID, instanceID string, sequence uint64) ([]*models.Event, error) { + query, err := usr_view.UserByIDQuery(userID, instanceID, sequence) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/eventstore/user.go b/internal/auth/repository/eventsourcing/eventstore/user.go index 3296482590..3ab69f99b6 100644 --- a/internal/auth/repository/eventsourcing/eventstore/user.go +++ b/internal/auth/repository/eventsourcing/eventstore/user.go @@ -3,6 +3,7 @@ package eventstore import ( "context" + "github.com/caos/zitadel/internal/api/authz" "github.com/caos/zitadel/internal/auth/repository/eventsourcing/view" "github.com/caos/zitadel/internal/config/systemdefaults" "github.com/caos/zitadel/internal/domain" @@ -26,7 +27,7 @@ func (repo *UserRepo) Health(ctx context.Context) error { } func (repo *UserRepo) UserSessionUserIDsByAgentID(ctx context.Context, agentID string) ([]string, error) { - userSessions, err := repo.View.UserSessionsByAgentID(agentID) + userSessions, err := repo.View.UserSessionsByAgentID(agentID, authz.GetInstance(ctx).InstanceID()) if err != nil { return nil, err } @@ -44,7 +45,7 @@ func (repo *UserRepo) UserEventsByID(ctx context.Context, id string, sequence ui } func (r *UserRepo) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) { - query, err := usr_view.UserByIDQuery(userID, sequence) + query, err := usr_view.UserByIDQuery(userID, authz.GetInstance(ctx).InstanceID(), sequence) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/eventstore/user_session.go b/internal/auth/repository/eventsourcing/eventstore/user_session.go index 3092044b2d..6384714e8a 100644 --- a/internal/auth/repository/eventsourcing/eventstore/user_session.go +++ b/internal/auth/repository/eventsourcing/eventstore/user_session.go @@ -14,7 +14,7 @@ type UserSessionRepo struct { } func (repo *UserSessionRepo) GetMyUserSessions(ctx context.Context) ([]*usr_model.UserSessionView, error) { - userSessions, err := repo.View.UserSessionsByAgentID(authz.GetCtxData(ctx).AgentID) + userSessions, err := repo.View.UserSessionsByAgentID(authz.GetCtxData(ctx).AgentID, authz.GetInstance(ctx).InstanceID()) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/handler/idp_config.go b/internal/auth/repository/eventsourcing/handler/idp_config.go index a49211e561..b8e947804c 100644 --- a/internal/auth/repository/eventsourcing/handler/idp_config.go +++ b/internal/auth/repository/eventsourcing/handler/idp_config.go @@ -54,8 +54,8 @@ func (_ *IDPConfig) AggregateTypes() []models.AggregateType { return []models.AggregateType{org.AggregateType, instance.AggregateType} } -func (i *IDPConfig) CurrentSequence() (uint64, error) { - sequence, err := i.view.GetLatestIDPConfigSequence() +func (i *IDPConfig) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := i.view.GetLatestIDPConfigSequence(instanceID) if err != nil { return 0, err } @@ -63,13 +63,30 @@ func (i *IDPConfig) CurrentSequence() (uint64, error) { } func (i *IDPConfig) EventQuery() (*models.SearchQuery, error) { - sequence, err := i.view.GetLatestIDPConfigSequence() + sequences, err := i.view.GetLatestIDPConfigSequences() if err != nil { return nil, err } - return models.NewSearchQuery(). + + query := models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(i.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). AggregateTypeFilter(i.AggregateTypes()...). - LatestSequenceFilter(sequence.CurrentSequence), nil + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (i *IDPConfig) Reduce(event *models.Event) (err error) { @@ -97,7 +114,7 @@ func (i *IDPConfig) processIdpConfig(providerType iam_model.IDPProviderType, eve if err != nil { return err } - idp, err = i.view.IDPConfigByID(idp.IDPConfigID) + idp, err = i.view.IDPConfigByID(idp.IDPConfigID, idp.InstanceID) if err != nil { return err } @@ -108,7 +125,7 @@ func (i *IDPConfig) processIdpConfig(providerType iam_model.IDPProviderType, eve if err != nil { return err } - idp, err = i.view.IDPConfigByID(idp.IDPConfigID) + idp, err = i.view.IDPConfigByID(idp.IDPConfigID, idp.InstanceID) if err != nil { return err } diff --git a/internal/auth/repository/eventsourcing/handler/idp_providers.go b/internal/auth/repository/eventsourcing/handler/idp_providers.go index b1026af09d..4c955f7b77 100644 --- a/internal/auth/repository/eventsourcing/handler/idp_providers.go +++ b/internal/auth/repository/eventsourcing/handler/idp_providers.go @@ -4,6 +4,7 @@ import ( "context" "github.com/caos/logging" + "github.com/caos/zitadel/internal/config/systemdefaults" "github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/eventstore" @@ -67,8 +68,8 @@ func (_ *IDPProvider) AggregateTypes() []models.AggregateType { return []es_models.AggregateType{instance.AggregateType, org.AggregateType} } -func (i *IDPProvider) CurrentSequence() (uint64, error) { - sequence, err := i.view.GetLatestIDPProviderSequence() +func (i *IDPProvider) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := i.view.GetLatestIDPProviderSequence(instanceID) if err != nil { return 0, err } @@ -76,13 +77,29 @@ func (i *IDPProvider) CurrentSequence() (uint64, error) { } func (i *IDPProvider) EventQuery() (*models.SearchQuery, error) { - sequence, err := i.view.GetLatestIDPProviderSequence() + sequences, err := i.view.GetLatestIDPProviderSequences() if err != nil { return nil, err } - return es_models.NewSearchQuery(). + query := es_models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(i.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). AggregateTypeFilter(i.AggregateTypes()...). - LatestSequenceFilter(sequence.CurrentSequence), nil + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (i *IDPProvider) Reduce(event *models.Event) (err error) { @@ -108,7 +125,7 @@ func (i *IDPProvider) processIdpProvider(event *models.Event) (err error) { if err != nil { return err } - return i.view.DeleteIDPProvider(event.AggregateID, provider.IDPConfigID, event) + return i.view.DeleteIDPProvider(event.AggregateID, provider.IDPConfigID, event.InstanceID, event) case instance.IDPConfigChangedEventType, org.IDPConfigChangedEventType: esConfig := new(iam_view_model.IDPConfigView) providerType := iam_model.IDPProviderTypeSystem @@ -116,7 +133,7 @@ func (i *IDPProvider) processIdpProvider(event *models.Event) (err error) { providerType = iam_model.IDPProviderTypeOrg } esConfig.AppendEvent(providerType, event) - providers, err := i.view.IDPProvidersByIDPConfigID(esConfig.IDPConfigID) + providers, err := i.view.IDPProvidersByIDPConfigID(esConfig.IDPConfigID, esConfig.InstanceID) if err != nil { return err } @@ -134,7 +151,7 @@ func (i *IDPProvider) processIdpProvider(event *models.Event) (err error) { } return i.view.PutIDPProviders(event, providers...) case org.LoginPolicyRemovedEventType: - return i.view.DeleteIDPProvidersByAggregateID(event.AggregateID, event) + return i.view.DeleteIDPProvidersByAggregateID(event.AggregateID, event.InstanceID, event) default: return i.view.ProcessedIDPProviderSequence(event) } diff --git a/internal/auth/repository/eventsourcing/handler/org_project_mapping.go b/internal/auth/repository/eventsourcing/handler/org_project_mapping.go index ea2e98f35c..7753f65ed4 100644 --- a/internal/auth/repository/eventsourcing/handler/org_project_mapping.go +++ b/internal/auth/repository/eventsourcing/handler/org_project_mapping.go @@ -8,7 +8,6 @@ import ( es_models "github.com/caos/zitadel/internal/eventstore/v1/models" "github.com/caos/zitadel/internal/eventstore/v1/query" "github.com/caos/zitadel/internal/eventstore/v1/spooler" - proj_view "github.com/caos/zitadel/internal/project/repository/view" view_model "github.com/caos/zitadel/internal/project/repository/view/model" "github.com/caos/zitadel/internal/repository/project" ) @@ -55,8 +54,8 @@ func (_ *OrgProjectMapping) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{project.AggregateType} } -func (p *OrgProjectMapping) CurrentSequence() (uint64, error) { - sequence, err := p.view.GetLatestOrgProjectMappingSequence() +func (p *OrgProjectMapping) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := p.view.GetLatestOrgProjectMappingSequence(instanceID) if err != nil { return 0, err } @@ -64,11 +63,29 @@ func (p *OrgProjectMapping) CurrentSequence() (uint64, error) { } func (p *OrgProjectMapping) EventQuery() (*es_models.SearchQuery, error) { - sequence, err := p.view.GetLatestOrgProjectMappingSequence() + sequences, err := p.view.GetLatestOrgProjectMappingSequences() if err != nil { return nil, err } - return proj_view.ProjectQuery(sequence.CurrentSequence), nil + query := es_models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(p.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). + AggregateTypeFilter(p.AggregateTypes()...). + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (p *OrgProjectMapping) Reduce(event *es_models.Event) (err error) { @@ -79,7 +96,7 @@ func (p *OrgProjectMapping) Reduce(event *es_models.Event) (err error) { mapping.ProjectID = event.AggregateID mapping.InstanceID = event.InstanceID case project.ProjectRemovedType: - err := p.view.DeleteOrgProjectMappingsByProjectID(event.AggregateID) + err := p.view.DeleteOrgProjectMappingsByProjectID(event.AggregateID, event.InstanceID) if err == nil { return p.view.ProcessedOrgProjectMappingSequence(event) } @@ -93,7 +110,7 @@ func (p *OrgProjectMapping) Reduce(event *es_models.Event) (err error) { case project.GrantRemovedType: projectGrant := new(view_model.ProjectGrant) projectGrant.SetData(event) - err := p.view.DeleteOrgProjectMappingsByProjectGrantID(event.AggregateID) + err := p.view.DeleteOrgProjectMappingsByProjectGrantID(event.AggregateID, event.InstanceID) if err == nil { return p.view.ProcessedOrgProjectMappingSequence(event) } diff --git a/internal/auth/repository/eventsourcing/handler/refresh_token.go b/internal/auth/repository/eventsourcing/handler/refresh_token.go index c44bc9e71c..926a045860 100644 --- a/internal/auth/repository/eventsourcing/handler/refresh_token.go +++ b/internal/auth/repository/eventsourcing/handler/refresh_token.go @@ -58,8 +58,8 @@ func (t *RefreshToken) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{user.AggregateType, project.AggregateType} } -func (t *RefreshToken) CurrentSequence() (uint64, error) { - sequence, err := t.view.GetLatestRefreshTokenSequence() +func (t *RefreshToken) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := t.view.GetLatestRefreshTokenSequence(instanceID) if err != nil { return 0, err } @@ -67,13 +67,29 @@ func (t *RefreshToken) CurrentSequence() (uint64, error) { } func (t *RefreshToken) EventQuery() (*es_models.SearchQuery, error) { - sequence, err := t.view.GetLatestRefreshTokenSequence() + sequences, err := t.view.GetLatestRefreshTokenSequences() if err != nil { return nil, err } - return es_models.NewSearchQuery(). - AggregateTypeFilter(user.AggregateType, project.AggregateType). - LatestSequenceFilter(sequence.CurrentSequence), nil + query := es_models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(t.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). + AggregateTypeFilter(t.AggregateTypes()...). + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (t *RefreshToken) Reduce(event *es_models.Event) (err error) { @@ -91,7 +107,7 @@ func (t *RefreshToken) Reduce(event *es_models.Event) (err error) { logging.Log("EVEN-DBbn4").WithError(err).Error("could not unmarshal event data") return caos_errs.ThrowInternal(nil, "MODEL-BHn75", "could not unmarshal data") } - token, err := t.view.RefreshTokenByID(e.TokenID) + token, err := t.view.RefreshTokenByID(e.TokenID, event.InstanceID) if err != nil { return err } @@ -106,11 +122,11 @@ func (t *RefreshToken) Reduce(event *es_models.Event) (err error) { logging.Log("EVEN-BDbh3").WithError(err).Error("could not unmarshal event data") return caos_errs.ThrowInternal(nil, "MODEL-Bz653", "could not unmarshal data") } - return t.view.DeleteRefreshToken(e.TokenID, event) + return t.view.DeleteRefreshToken(e.TokenID, event.InstanceID, event) case user.UserLockedType, user.UserDeactivatedType, user.UserRemovedType: - return t.view.DeleteUserRefreshTokens(event.AggregateID, event) + return t.view.DeleteUserRefreshTokens(event.AggregateID, event.InstanceID, event) default: return t.view.ProcessedRefreshTokenSequence(event) } diff --git a/internal/auth/repository/eventsourcing/handler/token.go b/internal/auth/repository/eventsourcing/handler/token.go index 15d9f99903..e2c39f4e47 100644 --- a/internal/auth/repository/eventsourcing/handler/token.go +++ b/internal/auth/repository/eventsourcing/handler/token.go @@ -64,8 +64,8 @@ func (_ *Token) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{user.AggregateType, project.AggregateType} } -func (p *Token) CurrentSequence() (uint64, error) { - sequence, err := p.view.GetLatestTokenSequence() +func (p *Token) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := p.view.GetLatestTokenSequence(instanceID) if err != nil { return 0, err } @@ -73,13 +73,29 @@ func (p *Token) CurrentSequence() (uint64, error) { } func (t *Token) EventQuery() (*es_models.SearchQuery, error) { - sequence, err := t.view.GetLatestTokenSequence() + sequences, err := t.view.GetLatestTokenSequences() if err != nil { return nil, err } - return es_models.NewSearchQuery(). - AggregateTypeFilter(user.AggregateType, project.AggregateType). - LatestSequenceFilter(sequence.CurrentSequence), nil + query := es_models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(t.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). + AggregateTypeFilter(t.AggregateTypes()...). + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (t *Token) Reduce(event *es_models.Event) (err error) { @@ -96,7 +112,7 @@ func (t *Token) Reduce(event *es_models.Event) (err error) { user.HumanProfileChangedType: user := new(view_model.UserView) user.AppendEvent(event) - tokens, err := t.view.TokensByUserID(event.AggregateID) + tokens, err := t.view.TokensByUserID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -110,24 +126,24 @@ func (t *Token) Reduce(event *es_models.Event) (err error) { if err != nil { return err } - return t.view.DeleteSessionTokens(id, event.AggregateID, event) + return t.view.DeleteSessionTokens(id, event.AggregateID, event.InstanceID, event) case user.UserLockedType, user.UserDeactivatedType, user.UserRemovedType: - return t.view.DeleteUserTokens(event.AggregateID, event) + return t.view.DeleteUserTokens(event.AggregateID, event.InstanceID, event) case user_repo.UserTokenRemovedType, user_repo.PersonalAccessTokenRemovedType: id, err := tokenIDFromRemovedEvent(event) if err != nil { return err } - return t.view.DeleteToken(id, event) + return t.view.DeleteToken(id, event.InstanceID, event) case user_repo.HumanRefreshTokenRemovedType: id, err := refreshTokenIDFromRemovedEvent(event) if err != nil { return err } - return t.view.DeleteTokensFromRefreshToken(id, event) + return t.view.DeleteTokensFromRefreshToken(id, event.InstanceID, event) case project.ApplicationDeactivatedType, project.ApplicationRemovedType: application, err := applicationFromSession(event) @@ -137,7 +153,7 @@ func (t *Token) Reduce(event *es_models.Event) (err error) { return t.view.DeleteApplicationTokens(event, application.AppID) case project.ProjectDeactivatedType, project.ProjectRemovedType: - project, err := t.getProjectByID(context.Background(), event.AggregateID) + project, err := t.getProjectByID(context.Background(), event.AggregateID, event.InstanceID) if err != nil { return err } @@ -196,8 +212,8 @@ func (t *Token) OnSuccess() error { return spooler.HandleSuccess(t.view.UpdateTokenSpoolerRunTimestamp) } -func (t *Token) getProjectByID(ctx context.Context, projID string) (*proj_model.Project, error) { - query, err := proj_view.ProjectByIDQuery(projID, 0) +func (t *Token) getProjectByID(ctx context.Context, projID, instanceID string) (*proj_model.Project, error) { + query, err := proj_view.ProjectByIDQuery(projID, instanceID, 0) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/handler/user.go b/internal/auth/repository/eventsourcing/handler/user.go index 8e1b45e295..ae1938027d 100644 --- a/internal/auth/repository/eventsourcing/handler/user.go +++ b/internal/auth/repository/eventsourcing/handler/user.go @@ -65,8 +65,8 @@ func (_ *User) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{user_repo.AggregateType, org.AggregateType} } -func (u *User) CurrentSequence() (uint64, error) { - sequence, err := u.view.GetLatestUserSequence() +func (u *User) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := u.view.GetLatestUserSequence(instanceID) if err != nil { return 0, err } @@ -74,13 +74,29 @@ func (u *User) CurrentSequence() (uint64, error) { } func (u *User) EventQuery() (*es_models.SearchQuery, error) { - sequence, err := u.view.GetLatestUserSequence() + sequences, err := u.view.GetLatestUserSequences() if err != nil { return nil, err } - return es_models.NewSearchQuery(). + query := es_models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(u.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). AggregateTypeFilter(u.AggregateTypes()...). - LatestSequenceFilter(sequence.CurrentSequence), nil + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (u *User) Reduce(event *es_models.Event) (err error) { @@ -146,14 +162,14 @@ func (u *User) ProcessUser(event *es_models.Event) (err error) { user_repo.HumanPasswordChangedType, user_repo.HumanPasswordlessInitCodeAddedType, user_repo.HumanPasswordlessInitCodeRequestedType: - user, err = u.view.UserByID(event.AggregateID) + user, err = u.view.UserByID(event.AggregateID, event.InstanceID) if err != nil { return err } err = user.AppendEvent(event) case user_repo.UserDomainClaimedType, user_repo.UserUserNameChangedType: - user, err = u.view.UserByID(event.AggregateID) + user, err = u.view.UserByID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -163,7 +179,7 @@ func (u *User) ProcessUser(event *es_models.Event) (err error) { } err = u.fillLoginNames(user) case user_repo.UserRemovedType: - return u.view.DeleteUser(event.AggregateID, event) + return u.view.DeleteUser(event.AggregateID, event.InstanceID, event) default: return u.view.ProcessedUserSequence(event) } @@ -203,7 +219,7 @@ func (u *User) fillLoginNamesOnOrgUsers(event *es_models.Event) error { if err != nil { return err } - users, err := u.view.UsersByOrgID(event.AggregateID) + users, err := u.view.UsersByOrgID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -221,7 +237,7 @@ func (u *User) fillPreferredLoginNamesOnOrgUsers(event *es_models.Event) error { if !userLoginMustBeDomain { return nil } - users, err := u.view.UsersByOrgID(event.AggregateID) + users, err := u.view.UsersByOrgID(event.AggregateID, event.InstanceID) if err != nil { return err } diff --git a/internal/auth/repository/eventsourcing/handler/user_external_idps.go b/internal/auth/repository/eventsourcing/handler/user_external_idps.go index 7478500e0e..974a9af529 100644 --- a/internal/auth/repository/eventsourcing/handler/user_external_idps.go +++ b/internal/auth/repository/eventsourcing/handler/user_external_idps.go @@ -69,8 +69,8 @@ func (_ *ExternalIDP) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{user.AggregateType, instance.AggregateType, org.AggregateType} } -func (i *ExternalIDP) CurrentSequence() (uint64, error) { - sequence, err := i.view.GetLatestExternalIDPSequence() +func (i *ExternalIDP) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := i.view.GetLatestExternalIDPSequence(instanceID) if err != nil { return 0, err } @@ -78,13 +78,29 @@ func (i *ExternalIDP) CurrentSequence() (uint64, error) { } func (i *ExternalIDP) EventQuery() (*es_models.SearchQuery, error) { - sequence, err := i.view.GetLatestExternalIDPSequence() + sequences, err := i.view.GetLatestExternalIDPSequences() if err != nil { return nil, err } - return es_models.NewSearchQuery(). + query := es_models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(i.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). AggregateTypeFilter(i.AggregateTypes()...). - LatestSequenceFilter(sequence.CurrentSequence), nil + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (i *ExternalIDP) Reduce(event *es_models.Event) (err error) { @@ -111,9 +127,9 @@ func (i *ExternalIDP) processUser(event *es_models.Event) (err error) { if err != nil { return err } - return i.view.DeleteExternalIDP(externalIDP.ExternalUserID, externalIDP.IDPConfigID, event) + return i.view.DeleteExternalIDP(externalIDP.ExternalUserID, externalIDP.IDPConfigID, externalIDP.InstanceID, event) case user.UserRemovedType: - return i.view.DeleteExternalIDPsByUserID(event.AggregateID, event) + return i.view.DeleteExternalIDPsByUserID(event.AggregateID, event.InstanceID, event) default: return i.view.ProcessedExternalIDPSequence(event) } @@ -133,7 +149,7 @@ func (i *ExternalIDP) processIdpConfig(event *es_models.Event) (err error) { } else { configView.AppendEvent(iam_model.IDPProviderTypeOrg, event) } - exterinalIDPs, err := i.view.ExternalIDPsByIDPConfigID(configView.IDPConfigID) + exterinalIDPs, err := i.view.ExternalIDPsByIDPConfigID(configView.IDPConfigID, configView.InstanceID) if err != nil { return err } diff --git a/internal/auth/repository/eventsourcing/handler/user_session.go b/internal/auth/repository/eventsourcing/handler/user_session.go index d57d76b40e..04e96b66f0 100644 --- a/internal/auth/repository/eventsourcing/handler/user_session.go +++ b/internal/auth/repository/eventsourcing/handler/user_session.go @@ -11,7 +11,6 @@ import ( "github.com/caos/zitadel/internal/eventstore/v1/query" "github.com/caos/zitadel/internal/eventstore/v1/spooler" "github.com/caos/zitadel/internal/repository/user" - "github.com/caos/zitadel/internal/user/repository/view" view_model "github.com/caos/zitadel/internal/user/repository/view/model" ) @@ -57,8 +56,8 @@ func (_ *UserSession) AggregateTypes() []models.AggregateType { return []models.AggregateType{user.AggregateType} } -func (u *UserSession) CurrentSequence() (uint64, error) { - sequence, err := u.view.GetLatestUserSessionSequence() +func (u *UserSession) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := u.view.GetLatestUserSessionSequence(instanceID) if err != nil { return 0, err } @@ -66,11 +65,29 @@ func (u *UserSession) CurrentSequence() (uint64, error) { } func (u *UserSession) EventQuery() (*models.SearchQuery, error) { - sequence, err := u.view.GetLatestUserSessionSequence() + sequences, err := u.view.GetLatestUserSessionSequences() if err != nil { return nil, err } - return view.UserQuery(sequence.CurrentSequence), nil + query := models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(u.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). + AggregateTypeFilter(u.AggregateTypes()...). + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (u *UserSession) Reduce(event *models.Event) (err error) { @@ -95,7 +112,7 @@ func (u *UserSession) Reduce(event *models.Event) (err error) { if err != nil { return err } - session, err = u.view.UserSessionByIDs(eventData.UserAgentID, event.AggregateID) + session, err = u.view.UserSessionByIDs(eventData.UserAgentID, event.AggregateID, event.InstanceID) if err != nil { if !errors.IsNotFound(err) { return err @@ -126,7 +143,7 @@ func (u *UserSession) Reduce(event *models.Event) (err error) { user.UserIDPLinkCascadeRemovedType, user.HumanPasswordlessTokenRemovedType, user.HumanU2FTokenRemovedType: - sessions, err := u.view.UserSessionsByUserID(event.AggregateID) + sessions, err := u.view.UserSessionsByUserID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -143,7 +160,7 @@ func (u *UserSession) Reduce(event *models.Event) (err error) { } return u.view.PutUserSessions(sessions, event) case user.UserRemovedType: - return u.view.DeleteUserSessions(event.AggregateID, event) + return u.view.DeleteUserSessions(event.AggregateID, event.InstanceID, event) default: return u.view.ProcessedUserSessionSequence(event) } @@ -169,7 +186,7 @@ func (u *UserSession) updateSession(session *view_model.UserSessionView, event * } func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, id string) error { - user, err := u.view.UserByID(id) + user, err := u.view.UserByID(id, session.InstanceID) if err != nil { return err } diff --git a/internal/auth/repository/eventsourcing/spooler/lock.go b/internal/auth/repository/eventsourcing/spooler/lock.go index a7b0bb41ad..b29d445353 100644 --- a/internal/auth/repository/eventsourcing/spooler/lock.go +++ b/internal/auth/repository/eventsourcing/spooler/lock.go @@ -19,6 +19,6 @@ func NewLocker(client *sql.DB) *locker { return &locker{dbClient: client} } -func (l *locker) Renew(lockerID, viewModel string, waitTime time.Duration) error { - return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, waitTime) +func (l *locker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error { + return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, instanceID, waitTime) } diff --git a/internal/auth/repository/eventsourcing/view/error_event.go b/internal/auth/repository/eventsourcing/view/error_event.go index d1c322ca6f..81bbcaa6b2 100644 --- a/internal/auth/repository/eventsourcing/view/error_event.go +++ b/internal/auth/repository/eventsourcing/view/error_event.go @@ -12,6 +12,6 @@ func (v *View) saveFailedEvent(failedEvent *repository.FailedEvent) error { return repository.SaveFailedEvent(v.Db, errTable, failedEvent) } -func (v *View) latestFailedEvent(viewName string, sequence uint64) (*repository.FailedEvent, error) { - return repository.LatestFailedEvent(v.Db, errTable, viewName, sequence) +func (v *View) latestFailedEvent(viewName, instanceID string, sequence uint64) (*repository.FailedEvent, error) { + return repository.LatestFailedEvent(v.Db, errTable, viewName, instanceID, sequence) } diff --git a/internal/auth/repository/eventsourcing/view/external_idps.go b/internal/auth/repository/eventsourcing/view/external_idps.go index 7fa21cc500..b2dcedcf4c 100644 --- a/internal/auth/repository/eventsourcing/view/external_idps.go +++ b/internal/auth/repository/eventsourcing/view/external_idps.go @@ -12,16 +12,16 @@ const ( externalIDPTable = "auth.user_external_idps" ) -func (v *View) ExternalIDPByExternalUserIDAndIDPConfigID(externalUserID, idpConfigID string) (*model.ExternalIDPView, error) { - return view.ExternalIDPByExternalUserIDAndIDPConfigID(v.Db, externalIDPTable, externalUserID, idpConfigID) +func (v *View) ExternalIDPByExternalUserIDAndIDPConfigID(externalUserID, idpConfigID, instanceID string) (*model.ExternalIDPView, error) { + return view.ExternalIDPByExternalUserIDAndIDPConfigID(v.Db, externalIDPTable, externalUserID, idpConfigID, instanceID) } -func (v *View) ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(externalUserID, idpConfigID, resourceOwner string) (*model.ExternalIDPView, error) { - return view.ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(v.Db, externalIDPTable, externalUserID, idpConfigID, resourceOwner) +func (v *View) ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(externalUserID, idpConfigID, resourceOwner, instanceID string) (*model.ExternalIDPView, error) { + return view.ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(v.Db, externalIDPTable, externalUserID, idpConfigID, resourceOwner, instanceID) } -func (v *View) ExternalIDPsByIDPConfigID(idpConfigID string) ([]*model.ExternalIDPView, error) { - return view.ExternalIDPsByIDPConfigID(v.Db, externalIDPTable, idpConfigID) +func (v *View) ExternalIDPsByIDPConfigID(idpConfigID, instanceID string) ([]*model.ExternalIDPView, error) { + return view.ExternalIDPsByIDPConfigID(v.Db, externalIDPTable, idpConfigID, instanceID) } func (v *View) PutExternalIDP(externalIDP *model.ExternalIDPView, event *models.Event) error { @@ -40,24 +40,28 @@ func (v *View) PutExternalIDPs(event *models.Event, externalIDPs ...*model.Exter return v.ProcessedExternalIDPSequence(event) } -func (v *View) DeleteExternalIDP(externalUserID, idpConfigID string, event *models.Event) error { - err := view.DeleteExternalIDP(v.Db, externalIDPTable, externalUserID, idpConfigID) +func (v *View) DeleteExternalIDP(externalUserID, idpConfigID, instanceID string, event *models.Event) error { + err := view.DeleteExternalIDP(v.Db, externalIDPTable, externalUserID, idpConfigID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedExternalIDPSequence(event) } -func (v *View) DeleteExternalIDPsByUserID(userID string, event *models.Event) error { - err := view.DeleteExternalIDPsByUserID(v.Db, externalIDPTable, userID) +func (v *View) DeleteExternalIDPsByUserID(userID, instanceID string, event *models.Event) error { + err := view.DeleteExternalIDPsByUserID(v.Db, externalIDPTable, userID, instanceID) if err != nil { return err } return v.ProcessedExternalIDPSequence(event) } -func (v *View) GetLatestExternalIDPSequence() (*global_view.CurrentSequence, error) { - return v.latestSequence(externalIDPTable) +func (v *View) GetLatestExternalIDPSequence(instanceID string) (*global_view.CurrentSequence, error) { + return v.latestSequence(externalIDPTable, instanceID) +} + +func (v *View) GetLatestExternalIDPSequences() ([]*global_view.CurrentSequence, error) { + return v.latestSequences(externalIDPTable) } func (v *View) ProcessedExternalIDPSequence(event *models.Event) error { @@ -68,8 +72,8 @@ func (v *View) UpdateExternalIDPSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(externalIDPTable) } -func (v *View) GetLatestExternalIDPFailedEvent(sequence uint64) (*global_view.FailedEvent, error) { - return v.latestFailedEvent(externalIDPTable, sequence) +func (v *View) GetLatestExternalIDPFailedEvent(sequence uint64, instanceID string) (*global_view.FailedEvent, error) { + return v.latestFailedEvent(externalIDPTable, instanceID, sequence) } func (v *View) ProcessedExternalIDPFailedEvent(failedEvent *global_view.FailedEvent) error { diff --git a/internal/auth/repository/eventsourcing/view/idp_configs.go b/internal/auth/repository/eventsourcing/view/idp_configs.go index 5e0d604c6f..60e1ebc3d1 100644 --- a/internal/auth/repository/eventsourcing/view/idp_configs.go +++ b/internal/auth/repository/eventsourcing/view/idp_configs.go @@ -13,12 +13,12 @@ const ( idpConfigTable = "auth.idp_configs" ) -func (v *View) IDPConfigByID(idpID string) (*iam_es_model.IDPConfigView, error) { - return view.IDPByID(v.Db, idpConfigTable, idpID) +func (v *View) IDPConfigByID(idpID, instanceID string) (*iam_es_model.IDPConfigView, error) { + return view.IDPByID(v.Db, idpConfigTable, idpID, instanceID) } -func (v *View) GetIDPConfigsByAggregateID(aggregateID string) ([]*iam_es_model.IDPConfigView, error) { - return view.GetIDPConfigsByAggregateID(v.Db, idpConfigTable, aggregateID) +func (v *View) GetIDPConfigsByAggregateID(aggregateID, instanceID string) ([]*iam_es_model.IDPConfigView, error) { + return view.GetIDPConfigsByAggregateID(v.Db, idpConfigTable, aggregateID, instanceID) } func (v *View) SearchIDPConfigs(request *iam_model.IDPConfigSearchRequest) ([]*iam_es_model.IDPConfigView, uint64, error) { @@ -34,15 +34,19 @@ func (v *View) PutIDPConfig(idp *iam_es_model.IDPConfigView, event *models.Event } func (v *View) DeleteIDPConfig(idpID string, event *models.Event) error { - err := view.DeleteIDP(v.Db, idpConfigTable, idpID) + err := view.DeleteIDP(v.Db, idpConfigTable, idpID, event.InstanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedIDPConfigSequence(event) } -func (v *View) GetLatestIDPConfigSequence() (*global_view.CurrentSequence, error) { - return v.latestSequence(idpConfigTable) +func (v *View) GetLatestIDPConfigSequence(instanceID string) (*global_view.CurrentSequence, error) { + return v.latestSequence(idpConfigTable, instanceID) +} + +func (v *View) GetLatestIDPConfigSequences() ([]*global_view.CurrentSequence, error) { + return v.latestSequences(idpConfigTable) } func (v *View) ProcessedIDPConfigSequence(event *models.Event) error { @@ -53,8 +57,8 @@ func (v *View) UpdateIDPConfigSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(idpConfigTable) } -func (v *View) GetLatestIDPConfigFailedEvent(sequence uint64) (*global_view.FailedEvent, error) { - return v.latestFailedEvent(idpConfigTable, sequence) +func (v *View) GetLatestIDPConfigFailedEvent(sequence uint64, instanceID string) (*global_view.FailedEvent, error) { + return v.latestFailedEvent(idpConfigTable, instanceID, sequence) } func (v *View) ProcessedIDPConfigFailedEvent(failedEvent *global_view.FailedEvent) error { diff --git a/internal/auth/repository/eventsourcing/view/idp_providers.go b/internal/auth/repository/eventsourcing/view/idp_providers.go index 1b720ccdc2..4bf0162489 100644 --- a/internal/auth/repository/eventsourcing/view/idp_providers.go +++ b/internal/auth/repository/eventsourcing/view/idp_providers.go @@ -13,16 +13,16 @@ const ( idpProviderTable = "auth.idp_providers" ) -func (v *View) IDPProviderByAggregateAndIDPConfigID(aggregateID, idpConfigID string) (*model.IDPProviderView, error) { - return view.GetIDPProviderByAggregateIDAndConfigID(v.Db, idpProviderTable, aggregateID, idpConfigID) +func (v *View) IDPProviderByAggregateAndIDPConfigID(aggregateID, idpConfigID, instanceID string) (*model.IDPProviderView, error) { + return view.GetIDPProviderByAggregateIDAndConfigID(v.Db, idpProviderTable, aggregateID, idpConfigID, instanceID) } -func (v *View) IDPProvidersByIDPConfigID(idpConfigID string) ([]*model.IDPProviderView, error) { - return view.IDPProvidersByIdpConfigID(v.Db, idpProviderTable, idpConfigID) +func (v *View) IDPProvidersByIDPConfigID(idpConfigID, instanceID string) ([]*model.IDPProviderView, error) { + return view.IDPProvidersByIdpConfigID(v.Db, idpProviderTable, idpConfigID, instanceID) } -func (v *View) IDPProvidersByAggregateIDAndState(aggregateID string, idpConfigState iam_model.IDPConfigState) ([]*model.IDPProviderView, error) { - return view.IDPProvidersByAggregateIDAndState(v.Db, idpProviderTable, aggregateID, idpConfigState) +func (v *View) IDPProvidersByAggregateIDAndState(aggregateID, instanceID string, idpConfigState iam_model.IDPConfigState) ([]*model.IDPProviderView, error) { + return view.IDPProvidersByAggregateIDAndState(v.Db, idpProviderTable, aggregateID, instanceID, idpConfigState) } func (v *View) SearchIDPProviders(request *iam_model.IDPProviderSearchRequest) ([]*model.IDPProviderView, uint64, error) { @@ -45,24 +45,28 @@ func (v *View) PutIDPProviders(event *models.Event, providers ...*model.IDPProvi return v.ProcessedIDPProviderSequence(event) } -func (v *View) DeleteIDPProvider(aggregateID, idpConfigID string, event *models.Event) error { - err := view.DeleteIDPProvider(v.Db, idpProviderTable, aggregateID, idpConfigID) +func (v *View) DeleteIDPProvider(aggregateID, idpConfigID, instanceID string, event *models.Event) error { + err := view.DeleteIDPProvider(v.Db, idpProviderTable, aggregateID, idpConfigID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedIDPProviderSequence(event) } -func (v *View) DeleteIDPProvidersByAggregateID(aggregateID string, event *models.Event) error { - err := view.DeleteIDPProvidersByAggregateID(v.Db, idpProviderTable, aggregateID) +func (v *View) DeleteIDPProvidersByAggregateID(aggregateID, instanceID string, event *models.Event) error { + err := view.DeleteIDPProvidersByAggregateID(v.Db, idpProviderTable, aggregateID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedIDPProviderSequence(event) } -func (v *View) GetLatestIDPProviderSequence() (*global_view.CurrentSequence, error) { - return v.latestSequence(idpProviderTable) +func (v *View) GetLatestIDPProviderSequence(instanceID string) (*global_view.CurrentSequence, error) { + return v.latestSequence(idpProviderTable, instanceID) +} + +func (v *View) GetLatestIDPProviderSequences() ([]*global_view.CurrentSequence, error) { + return v.latestSequences(idpProviderTable) } func (v *View) ProcessedIDPProviderSequence(event *models.Event) error { @@ -73,8 +77,8 @@ func (v *View) UpdateIDPProviderSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(idpProviderTable) } -func (v *View) GetLatestIDPProviderFailedEvent(sequence uint64) (*global_view.FailedEvent, error) { - return v.latestFailedEvent(idpProviderTable, sequence) +func (v *View) GetLatestIDPProviderFailedEvent(sequence uint64, instanceID string) (*global_view.FailedEvent, error) { + return v.latestFailedEvent(idpProviderTable, instanceID, sequence) } func (v *View) ProcessedIDPProviderFailedEvent(failedEvent *global_view.FailedEvent) error { diff --git a/internal/auth/repository/eventsourcing/view/org_project_mapping.go b/internal/auth/repository/eventsourcing/view/org_project_mapping.go index 8103cf8c5c..ba39e37c78 100644 --- a/internal/auth/repository/eventsourcing/view/org_project_mapping.go +++ b/internal/auth/repository/eventsourcing/view/org_project_mapping.go @@ -12,8 +12,8 @@ const ( orgPrgojectMappingTable = "auth.org_project_mapping" ) -func (v *View) OrgProjectMappingByIDs(orgID, projectID string) (*model.OrgProjectMapping, error) { - return view.OrgProjectMappingByIDs(v.Db, orgPrgojectMappingTable, orgID, projectID) +func (v *View) OrgProjectMappingByIDs(orgID, projectID, instanceID string) (*model.OrgProjectMapping, error) { + return view.OrgProjectMappingByIDs(v.Db, orgPrgojectMappingTable, orgID, projectID, instanceID) } func (v *View) PutOrgProjectMapping(mapping *model.OrgProjectMapping, event *models.Event) error { @@ -24,24 +24,28 @@ func (v *View) PutOrgProjectMapping(mapping *model.OrgProjectMapping, event *mod return v.ProcessedOrgProjectMappingSequence(event) } -func (v *View) DeleteOrgProjectMapping(orgID, projectID string, event *models.Event) error { - err := view.DeleteOrgProjectMapping(v.Db, orgPrgojectMappingTable, orgID, projectID) +func (v *View) DeleteOrgProjectMapping(orgID, projectID, instanceID string, event *models.Event) error { + err := view.DeleteOrgProjectMapping(v.Db, orgPrgojectMappingTable, orgID, projectID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedOrgProjectMappingSequence(event) } -func (v *View) DeleteOrgProjectMappingsByProjectID(projectID string) error { - return view.DeleteOrgProjectMappingsByProjectID(v.Db, orgPrgojectMappingTable, projectID) +func (v *View) DeleteOrgProjectMappingsByProjectID(projectID, instanceID string) error { + return view.DeleteOrgProjectMappingsByProjectID(v.Db, orgPrgojectMappingTable, projectID, instanceID) } -func (v *View) DeleteOrgProjectMappingsByProjectGrantID(projectGrantID string) error { - return view.DeleteOrgProjectMappingsByProjectGrantID(v.Db, orgPrgojectMappingTable, projectGrantID) +func (v *View) DeleteOrgProjectMappingsByProjectGrantID(projectGrantID, instanceID string) error { + return view.DeleteOrgProjectMappingsByProjectGrantID(v.Db, orgPrgojectMappingTable, projectGrantID, instanceID) } -func (v *View) GetLatestOrgProjectMappingSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(orgPrgojectMappingTable) +func (v *View) GetLatestOrgProjectMappingSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(orgPrgojectMappingTable, instanceID) +} + +func (v *View) GetLatestOrgProjectMappingSequences() ([]*repository.CurrentSequence, error) { + return v.latestSequences(orgPrgojectMappingTable) } func (v *View) ProcessedOrgProjectMappingSequence(event *models.Event) error { @@ -52,8 +56,8 @@ func (v *View) UpdateOrgProjectMappingSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(orgPrgojectMappingTable) } -func (v *View) GetLatestOrgProjectMappingFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(orgPrgojectMappingTable, sequence) +func (v *View) GetLatestOrgProjectMappingFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(orgPrgojectMappingTable, instanceID, sequence) } func (v *View) ProcessedOrgProjectMappingFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/auth/repository/eventsourcing/view/refresh_token.go b/internal/auth/repository/eventsourcing/view/refresh_token.go index 01acb1ba3e..87534cb6a7 100644 --- a/internal/auth/repository/eventsourcing/view/refresh_token.go +++ b/internal/auth/repository/eventsourcing/view/refresh_token.go @@ -13,12 +13,12 @@ const ( refreshTokenTable = "auth.refresh_tokens" ) -func (v *View) RefreshTokenByID(tokenID string) (*model.RefreshTokenView, error) { - return usr_view.RefreshTokenByID(v.Db, refreshTokenTable, tokenID) +func (v *View) RefreshTokenByID(tokenID, instanceID string) (*model.RefreshTokenView, error) { + return usr_view.RefreshTokenByID(v.Db, refreshTokenTable, tokenID, instanceID) } -func (v *View) RefreshTokensByUserID(userID string) ([]*model.RefreshTokenView, error) { - return usr_view.RefreshTokensByUserID(v.Db, refreshTokenTable, userID) +func (v *View) RefreshTokensByUserID(userID, instanceID string) ([]*model.RefreshTokenView, error) { + return usr_view.RefreshTokensByUserID(v.Db, refreshTokenTable, userID, instanceID) } func (v *View) SearchRefreshTokens(request *user_model.RefreshTokenSearchRequest) ([]*model.RefreshTokenView, uint64, error) { @@ -41,16 +41,16 @@ func (v *View) PutRefreshTokens(token []*model.RefreshTokenView, event *models.E return v.ProcessedRefreshTokenSequence(event) } -func (v *View) DeleteRefreshToken(tokenID string, event *models.Event) error { - err := usr_view.DeleteRefreshToken(v.Db, refreshTokenTable, tokenID) +func (v *View) DeleteRefreshToken(tokenID, instanceID string, event *models.Event) error { + err := usr_view.DeleteRefreshToken(v.Db, refreshTokenTable, tokenID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedRefreshTokenSequence(event) } -func (v *View) DeleteUserRefreshTokens(userID string, event *models.Event) error { - err := usr_view.DeleteUserRefreshTokens(v.Db, refreshTokenTable, userID) +func (v *View) DeleteUserRefreshTokens(userID, instanceID string, event *models.Event) error { + err := usr_view.DeleteUserRefreshTokens(v.Db, refreshTokenTable, userID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } @@ -58,15 +58,19 @@ func (v *View) DeleteUserRefreshTokens(userID string, event *models.Event) error } func (v *View) DeleteApplicationRefreshTokens(event *models.Event, ids ...string) error { - err := usr_view.DeleteApplicationTokens(v.Db, refreshTokenTable, ids) + err := usr_view.DeleteApplicationTokens(v.Db, refreshTokenTable, event.InstanceID, ids) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedRefreshTokenSequence(event) } -func (v *View) GetLatestRefreshTokenSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(refreshTokenTable) +func (v *View) GetLatestRefreshTokenSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(refreshTokenTable, instanceID) +} + +func (v *View) GetLatestRefreshTokenSequences() ([]*repository.CurrentSequence, error) { + return v.latestSequences(refreshTokenTable) } func (v *View) ProcessedRefreshTokenSequence(event *models.Event) error { @@ -77,8 +81,8 @@ func (v *View) UpdateRefreshTokenSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(refreshTokenTable) } -func (v *View) GetLatestRefreshTokenFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(refreshTokenTable, sequence) +func (v *View) GetLatestRefreshTokenFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(refreshTokenTable, instanceID, sequence) } func (v *View) ProcessedRefreshTokenFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/auth/repository/eventsourcing/view/sequence.go b/internal/auth/repository/eventsourcing/view/sequence.go index be0830c060..040171ec50 100644 --- a/internal/auth/repository/eventsourcing/view/sequence.go +++ b/internal/auth/repository/eventsourcing/view/sequence.go @@ -12,21 +12,27 @@ const ( ) func (v *View) saveCurrentSequence(viewName string, event *models.Event) error { - return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.Sequence, event.CreationDate) + return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate) } -func (v *View) latestSequence(viewName string) (*repository.CurrentSequence, error) { - return repository.LatestSequence(v.Db, sequencesTable, viewName) +func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) { + return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID) +} + +func (v *View) latestSequences(viewName string) ([]*repository.CurrentSequence, error) { + return repository.LatestSequences(v.Db, sequencesTable, viewName) } func (v *View) updateSpoolerRunSequence(viewName string) error { - currentSequence, err := repository.LatestSequence(v.Db, sequencesTable, viewName) + currentSequences, err := repository.LatestSequences(v.Db, sequencesTable, viewName) if err != nil { return err } - if currentSequence.ViewName == "" { - currentSequence.ViewName = viewName + for _, currentSequence := range currentSequences { + if currentSequence.ViewName == "" { + currentSequence.ViewName = viewName + } + currentSequence.LastSuccessfulSpoolerRun = time.Now() } - currentSequence.LastSuccessfulSpoolerRun = time.Now() - return repository.UpdateCurrentSequence(v.Db, sequencesTable, currentSequence) + return repository.UpdateCurrentSequences(v.Db, sequencesTable, currentSequences) } diff --git a/internal/auth/repository/eventsourcing/view/token.go b/internal/auth/repository/eventsourcing/view/token.go index e160d11964..573602a420 100644 --- a/internal/auth/repository/eventsourcing/view/token.go +++ b/internal/auth/repository/eventsourcing/view/token.go @@ -12,12 +12,12 @@ const ( tokenTable = "auth.tokens" ) -func (v *View) TokenByID(tokenID string) (*model.TokenView, error) { - return usr_view.TokenByID(v.Db, tokenTable, tokenID) +func (v *View) TokenByID(tokenID, instanceID string) (*model.TokenView, error) { + return usr_view.TokenByID(v.Db, tokenTable, tokenID, instanceID) } -func (v *View) TokensByUserID(userID string) ([]*model.TokenView, error) { - return usr_view.TokensByUserID(v.Db, tokenTable, userID) +func (v *View) TokensByUserID(userID, instanceID string) ([]*model.TokenView, error) { + return usr_view.TokensByUserID(v.Db, tokenTable, userID, instanceID) } func (v *View) PutToken(token *model.TokenView, event *models.Event) error { @@ -36,24 +36,24 @@ func (v *View) PutTokens(token []*model.TokenView, event *models.Event) error { return v.ProcessedTokenSequence(event) } -func (v *View) DeleteToken(tokenID string, event *models.Event) error { - err := usr_view.DeleteToken(v.Db, tokenTable, tokenID) +func (v *View) DeleteToken(tokenID, instanceID string, event *models.Event) error { + err := usr_view.DeleteToken(v.Db, tokenTable, tokenID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedTokenSequence(event) } -func (v *View) DeleteSessionTokens(agentID, userID string, event *models.Event) error { - err := usr_view.DeleteSessionTokens(v.Db, tokenTable, agentID, userID) +func (v *View) DeleteSessionTokens(agentID, userID, instanceID string, event *models.Event) error { + err := usr_view.DeleteSessionTokens(v.Db, tokenTable, agentID, userID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedTokenSequence(event) } -func (v *View) DeleteUserTokens(userID string, event *models.Event) error { - err := usr_view.DeleteUserTokens(v.Db, tokenTable, userID) +func (v *View) DeleteUserTokens(userID, instanceID string, event *models.Event) error { + err := usr_view.DeleteUserTokens(v.Db, tokenTable, userID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } @@ -61,23 +61,27 @@ func (v *View) DeleteUserTokens(userID string, event *models.Event) error { } func (v *View) DeleteApplicationTokens(event *models.Event, ids ...string) error { - err := usr_view.DeleteApplicationTokens(v.Db, tokenTable, ids) + err := usr_view.DeleteApplicationTokens(v.Db, tokenTable, event.InstanceID, ids) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedTokenSequence(event) } -func (v *View) DeleteTokensFromRefreshToken(refreshTokenID string, event *models.Event) error { - err := usr_view.DeleteTokensFromRefreshToken(v.Db, tokenTable, refreshTokenID) +func (v *View) DeleteTokensFromRefreshToken(refreshTokenID, instanceID string, event *models.Event) error { + err := usr_view.DeleteTokensFromRefreshToken(v.Db, tokenTable, refreshTokenID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedTokenSequence(event) } -func (v *View) GetLatestTokenSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(tokenTable) +func (v *View) GetLatestTokenSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(tokenTable, instanceID) +} + +func (v *View) GetLatestTokenSequences() ([]*repository.CurrentSequence, error) { + return v.latestSequences(tokenTable) } func (v *View) ProcessedTokenSequence(event *models.Event) error { @@ -88,8 +92,8 @@ func (v *View) UpdateTokenSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(tokenTable) } -func (v *View) GetLatestTokenFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(tokenTable, sequence) +func (v *View) GetLatestTokenFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(tokenTable, instanceID, sequence) } func (v *View) ProcessedTokenFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/auth/repository/eventsourcing/view/user.go b/internal/auth/repository/eventsourcing/view/user.go index 85b25ea232..1d81d56bb8 100644 --- a/internal/auth/repository/eventsourcing/view/user.go +++ b/internal/auth/repository/eventsourcing/view/user.go @@ -13,40 +13,40 @@ const ( userTable = "auth.users" ) -func (v *View) UserByID(userID string) (*model.UserView, error) { - return view.UserByID(v.Db, userTable, userID) +func (v *View) UserByID(userID, instanceID string) (*model.UserView, error) { + return view.UserByID(v.Db, userTable, userID, instanceID) } -func (v *View) UserByUsername(userName string) (*model.UserView, error) { - return view.UserByUserName(v.Db, userTable, userName) +func (v *View) UserByUsername(userName, instanceID string) (*model.UserView, error) { + return view.UserByUserName(v.Db, userTable, userName, instanceID) } -func (v *View) UserByLoginName(loginName string) (*model.UserView, error) { - return view.UserByLoginName(v.Db, userTable, loginName) +func (v *View) UserByLoginName(loginName, instanceID string) (*model.UserView, error) { + return view.UserByLoginName(v.Db, userTable, loginName, instanceID) } -func (v *View) UserByLoginNameAndResourceOwner(loginName, resourceOwner string) (*model.UserView, error) { - return view.UserByLoginNameAndResourceOwner(v.Db, userTable, loginName, resourceOwner) +func (v *View) UserByLoginNameAndResourceOwner(loginName, resourceOwner, instanceID string) (*model.UserView, error) { + return view.UserByLoginNameAndResourceOwner(v.Db, userTable, loginName, resourceOwner, instanceID) } -func (v *View) UsersByOrgID(orgID string) ([]*model.UserView, error) { - return view.UsersByOrgID(v.Db, userTable, orgID) +func (v *View) UsersByOrgID(orgID, instanceID string) ([]*model.UserView, error) { + return view.UsersByOrgID(v.Db, userTable, orgID, instanceID) } -func (v *View) UserIDsByDomain(domain string) ([]string, error) { - return view.UserIDsByDomain(v.Db, userTable, domain) +func (v *View) UserIDsByDomain(domain, instanceID string) ([]string, error) { + return view.UserIDsByDomain(v.Db, userTable, domain, instanceID) } func (v *View) SearchUsers(request *usr_model.UserSearchRequest) ([]*model.UserView, uint64, error) { return view.SearchUsers(v.Db, userTable, request) } -func (v *View) GetGlobalUserByLoginName(email string) (*model.UserView, error) { - return view.GetGlobalUserByLoginName(v.Db, userTable, email) +func (v *View) GetGlobalUserByLoginName(email, instanceID string) (*model.UserView, error) { + return view.GetGlobalUserByLoginName(v.Db, userTable, email, instanceID) } -func (v *View) UserMFAs(userID string) ([]*usr_model.MultiFactor, error) { - return view.UserMFAs(v.Db, userTable, userID) +func (v *View) UserMFAs(userID, instanceID string) ([]*usr_model.MultiFactor, error) { + return view.UserMFAs(v.Db, userTable, userID, instanceID) } func (v *View) PutUser(user *model.UserView, event *models.Event) error { @@ -65,16 +65,20 @@ func (v *View) PutUsers(users []*model.UserView, event *models.Event) error { return v.ProcessedUserSequence(event) } -func (v *View) DeleteUser(userID string, event *models.Event) error { - err := view.DeleteUser(v.Db, userTable, userID) +func (v *View) DeleteUser(userID, instanceID string, event *models.Event) error { + err := view.DeleteUser(v.Db, userTable, userID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedUserSequence(event) } -func (v *View) GetLatestUserSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(userTable) +func (v *View) GetLatestUserSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(userTable, instanceID) +} + +func (v *View) GetLatestUserSequences() ([]*repository.CurrentSequence, error) { + return v.latestSequences(userTable) } func (v *View) ProcessedUserSequence(event *models.Event) error { @@ -85,8 +89,8 @@ func (v *View) UpdateUserSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(userTable) } -func (v *View) GetLatestUserFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(userTable, sequence) +func (v *View) GetLatestUserFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(userTable, instanceID, sequence) } func (v *View) ProcessedUserFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/auth/repository/eventsourcing/view/user_session.go b/internal/auth/repository/eventsourcing/view/user_session.go index 03c47e7a50..8aca8f0027 100644 --- a/internal/auth/repository/eventsourcing/view/user_session.go +++ b/internal/auth/repository/eventsourcing/view/user_session.go @@ -12,16 +12,16 @@ const ( userSessionTable = "auth.user_sessions" ) -func (v *View) UserSessionByIDs(agentID, userID string) (*model.UserSessionView, error) { - return view.UserSessionByIDs(v.Db, userSessionTable, agentID, userID) +func (v *View) UserSessionByIDs(agentID, userID, instanceID string) (*model.UserSessionView, error) { + return view.UserSessionByIDs(v.Db, userSessionTable, agentID, userID, instanceID) } -func (v *View) UserSessionsByUserID(userID string) ([]*model.UserSessionView, error) { - return view.UserSessionsByUserID(v.Db, userSessionTable, userID) +func (v *View) UserSessionsByUserID(userID, instanceID string) ([]*model.UserSessionView, error) { + return view.UserSessionsByUserID(v.Db, userSessionTable, userID, instanceID) } -func (v *View) UserSessionsByAgentID(agentID string) ([]*model.UserSessionView, error) { - return view.UserSessionsByAgentID(v.Db, userSessionTable, agentID) +func (v *View) UserSessionsByAgentID(agentID, instanceID string) ([]*model.UserSessionView, error) { + return view.UserSessionsByAgentID(v.Db, userSessionTable, agentID, instanceID) } func (v *View) ActiveUserSessionsCount() (uint64, error) { @@ -44,16 +44,20 @@ func (v *View) PutUserSessions(userSession []*model.UserSessionView, event *mode return v.ProcessedUserSessionSequence(event) } -func (v *View) DeleteUserSessions(userID string, event *models.Event) error { - err := view.DeleteUserSessions(v.Db, userSessionTable, userID) +func (v *View) DeleteUserSessions(userID, instanceID string, event *models.Event) error { + err := view.DeleteUserSessions(v.Db, userSessionTable, userID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedUserSessionSequence(event) } -func (v *View) GetLatestUserSessionSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(userSessionTable) +func (v *View) GetLatestUserSessionSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(userSessionTable, instanceID) +} + +func (v *View) GetLatestUserSessionSequences() ([]*repository.CurrentSequence, error) { + return v.latestSequences(userSessionTable) } func (v *View) ProcessedUserSessionSequence(event *models.Event) error { @@ -64,8 +68,8 @@ func (v *View) UpdateUserSessionSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(userSessionTable) } -func (v *View) GetLatestUserSessionFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(userSessionTable, sequence) +func (v *View) GetLatestUserSessionFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(userSessionTable, instanceID, sequence) } func (v *View) ProcessedUserSessionFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index f83a6f899d..986e63d33a 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -8,6 +8,7 @@ import ( "github.com/caos/logging" + "github.com/caos/zitadel/internal/api/authz" "github.com/caos/zitadel/internal/authz/repository/eventsourcing/view" "github.com/caos/zitadel/internal/crypto" "github.com/caos/zitadel/internal/domain" @@ -30,7 +31,7 @@ type TokenVerifierRepo struct { } func (repo *TokenVerifierRepo) tokenByID(ctx context.Context, tokenID, userID string) (*usr_model.TokenView, error) { - token, viewErr := repo.View.TokenByID(tokenID) + token, viewErr := repo.View.TokenByID(tokenID, authz.GetInstance(ctx).InstanceID()) if viewErr != nil && !caos_errs.IsNotFound(viewErr) { return nil, viewErr } @@ -40,7 +41,7 @@ func (repo *TokenVerifierRepo) tokenByID(ctx context.Context, tokenID, userID st token.UserID = userID } - events, esErr := repo.getUserEvents(ctx, userID, token.Sequence) + events, esErr := repo.getUserEvents(ctx, userID, token.InstanceID, token.Sequence) if caos_errs.IsNotFound(viewErr) && len(events) == 0 { return nil, caos_errs.ThrowNotFound(nil, "EVENT-4T90g", "Errors.Token.NotFound") } @@ -251,8 +252,8 @@ func (repo *TokenVerifierRepo) VerifierClientID(ctx context.Context, appName str return clientID, app.ProjectID, nil } -func (r *TokenVerifierRepo) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) { - query, err := usr_view.UserByIDQuery(userID, sequence) +func (r *TokenVerifierRepo) getUserEvents(ctx context.Context, userID, instanceID string, sequence uint64) ([]*models.Event, error) { + query, err := usr_view.UserByIDQuery(userID, instanceID, sequence) if err != nil { return nil, err } diff --git a/internal/authz/repository/eventsourcing/handler/user_membership.go b/internal/authz/repository/eventsourcing/handler/user_membership.go index 89c70b74e3..af8591bb61 100644 --- a/internal/authz/repository/eventsourcing/handler/user_membership.go +++ b/internal/authz/repository/eventsourcing/handler/user_membership.go @@ -68,8 +68,8 @@ func (_ *UserMembership) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{instance.AggregateType, org.AggregateType, project.AggregateType, user.AggregateType} } -func (m *UserMembership) CurrentSequence() (uint64, error) { - sequence, err := m.view.GetLatestUserMembershipSequence() +func (m *UserMembership) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := m.view.GetLatestUserMembershipSequence(instanceID) if err != nil { return 0, err } @@ -77,13 +77,29 @@ func (m *UserMembership) CurrentSequence() (uint64, error) { } func (m *UserMembership) EventQuery() (*es_models.SearchQuery, error) { - sequence, err := m.view.GetLatestUserMembershipSequence() + sequences, err := m.view.GetLatestUserMembershipSequences() if err != nil { return nil, err } - return es_models.NewSearchQuery(). + query := es_models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(m.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). AggregateTypeFilter(m.AggregateTypes()...). - LatestSequenceFilter(sequence.CurrentSequence), nil + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (m *UserMembership) Reduce(event *es_models.Event) (err error) { @@ -110,14 +126,14 @@ func (m *UserMembership) processIAM(event *es_models.Event) (err error) { case instance.MemberAddedEventType: m.fillIamDisplayName(member) case instance.MemberChangedEventType: - member, err = m.view.UserMembershipByIDs(member.UserID, event.AggregateID, event.AggregateID, usr_model.MemberTypeIam) + member, err = m.view.UserMembershipByIDs(member.UserID, event.AggregateID, event.AggregateID, event.InstanceID, usr_model.MemberTypeIam) if err != nil { return err } err = member.AppendEvent(event) case instance.MemberRemovedEventType, instance.MemberCascadeRemovedEventType: - return m.view.DeleteUserMembership(member.UserID, event.AggregateID, event.AggregateID, usr_model.MemberTypeIam, event) + return m.view.DeleteUserMembership(member.UserID, event.AggregateID, event.AggregateID, event.InstanceID, usr_model.MemberTypeIam, event) default: return m.view.ProcessedUserMembershipSequence(event) } @@ -142,14 +158,14 @@ func (m *UserMembership) processOrg(event *es_models.Event) (err error) { case org.MemberAddedEventType: err = m.fillOrgName(member) case org.MemberChangedEventType: - member, err = m.view.UserMembershipByIDs(member.UserID, event.AggregateID, event.AggregateID, usr_model.MemberTypeOrganisation) + member, err = m.view.UserMembershipByIDs(member.UserID, event.AggregateID, event.AggregateID, event.InstanceID, usr_model.MemberTypeOrganisation) if err != nil { return err } err = member.AppendEvent(event) case org.MemberRemovedEventType, org.MemberCascadeRemovedEventType: - return m.view.DeleteUserMembership(member.UserID, event.AggregateID, event.AggregateID, usr_model.MemberTypeOrganisation, event) + return m.view.DeleteUserMembership(member.UserID, event.AggregateID, event.AggregateID, event.InstanceID, usr_model.MemberTypeOrganisation, event) case org.OrgChangedEventType: return m.updateOrgName(event) default: @@ -179,7 +195,7 @@ func (m *UserMembership) updateOrgName(event *es_models.Event) error { return err } - memberships, err := m.view.UserMembershipsByResourceOwner(event.ResourceOwner) + memberships, err := m.view.UserMembershipsByResourceOwner(event.ResourceOwner, event.InstanceID) if err != nil { return err } @@ -206,28 +222,28 @@ func (m *UserMembership) processProject(event *es_models.Event) (err error) { } err = m.fillOrgName(member) case project.MemberChangedType: - member, err = m.view.UserMembershipByIDs(member.UserID, event.AggregateID, event.AggregateID, usr_model.MemberTypeProject) + member, err = m.view.UserMembershipByIDs(member.UserID, event.AggregateID, event.AggregateID, event.InstanceID, usr_model.MemberTypeProject) if err != nil { return err } err = member.AppendEvent(event) case project.MemberRemovedType, project.MemberCascadeRemovedType: - return m.view.DeleteUserMembership(member.UserID, event.AggregateID, event.AggregateID, usr_model.MemberTypeProject, event) + return m.view.DeleteUserMembership(member.UserID, event.AggregateID, event.AggregateID, event.InstanceID, usr_model.MemberTypeProject, event) case project.GrantMemberChangedType: - member, err = m.view.UserMembershipByIDs(member.UserID, event.AggregateID, member.ObjectID, usr_model.MemberTypeProjectGrant) + member, err = m.view.UserMembershipByIDs(member.UserID, event.AggregateID, member.ObjectID, event.InstanceID, usr_model.MemberTypeProjectGrant) if err != nil { return err } err = member.AppendEvent(event) case project.GrantMemberRemovedType, project.GrantMemberCascadeRemovedType: - return m.view.DeleteUserMembership(member.UserID, event.AggregateID, member.ObjectID, usr_model.MemberTypeProjectGrant, event) + return m.view.DeleteUserMembership(member.UserID, event.AggregateID, member.ObjectID, member.InstanceID, usr_model.MemberTypeProjectGrant, event) case project.ProjectChangedType: return m.updateProjectDisplayName(event) case project.ProjectRemovedType: - return m.view.DeleteUserMembershipsByAggregateID(event.AggregateID, event) + return m.view.DeleteUserMembershipsByAggregateID(event.AggregateID, event.InstanceID, event) case project.GrantRemovedType: - return m.view.DeleteUserMembershipsByAggregateIDAndObjectID(event.AggregateID, member.ObjectID, event) + return m.view.DeleteUserMembershipsByAggregateIDAndObjectID(event.AggregateID, member.ObjectID, member.InstanceID, event) default: return m.view.ProcessedUserMembershipSequence(event) } @@ -238,7 +254,7 @@ func (m *UserMembership) processProject(event *es_models.Event) (err error) { } func (m *UserMembership) fillProjectDisplayName(member *usr_es_model.UserMembershipView) (err error) { - project, err := m.getProjectByID(context.Background(), member.AggregateID) + project, err := m.getProjectByID(context.Background(), member.AggregateID, member.InstanceID) if err != nil { return err } @@ -256,7 +272,7 @@ func (m *UserMembership) updateProjectDisplayName(event *es_models.Event) error return m.view.ProcessedUserMembershipSequence(event) } - memberships, err := m.view.UserMembershipsByAggregateID(event.AggregateID) + memberships, err := m.view.UserMembershipsByAggregateID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -269,7 +285,7 @@ func (m *UserMembership) updateProjectDisplayName(event *es_models.Event) error func (m *UserMembership) processUser(event *es_models.Event) (err error) { switch eventstore.EventType(event.Type) { case user.UserRemovedType: - return m.view.DeleteUserMembershipsByUserID(event.AggregateID, event) + return m.view.DeleteUserMembershipsByUserID(event.AggregateID, event.InstanceID, event) default: return m.view.ProcessedUserMembershipSequence(event) } @@ -306,8 +322,8 @@ func (u *UserMembership) getOrgByID(ctx context.Context, orgID string) (*org_mod return org_es_model.OrgToModel(esOrg), nil } -func (u *UserMembership) getProjectByID(ctx context.Context, projID string) (*proj_model.Project, error) { - query, err := proj_view.ProjectByIDQuery(projID, 0) +func (u *UserMembership) getProjectByID(ctx context.Context, projID, instanceID string) (*proj_model.Project, error) { + query, err := proj_view.ProjectByIDQuery(projID, instanceID, 0) if err != nil { return nil, err } diff --git a/internal/authz/repository/eventsourcing/spooler/lock.go b/internal/authz/repository/eventsourcing/spooler/lock.go index 1c58c9392f..4e60b71390 100644 --- a/internal/authz/repository/eventsourcing/spooler/lock.go +++ b/internal/authz/repository/eventsourcing/spooler/lock.go @@ -2,8 +2,9 @@ package spooler import ( "database/sql" - es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker" "time" + + es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker" ) const ( @@ -14,6 +15,6 @@ type locker struct { dbClient *sql.DB } -func (l *locker) Renew(lockerID, viewModel string, waitTime time.Duration) error { - return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, waitTime) +func (l *locker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error { + return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, instanceID, waitTime) } diff --git a/internal/authz/repository/eventsourcing/view/error_event.go b/internal/authz/repository/eventsourcing/view/error_event.go index 6f8488dd54..68aae4d431 100644 --- a/internal/authz/repository/eventsourcing/view/error_event.go +++ b/internal/authz/repository/eventsourcing/view/error_event.go @@ -12,6 +12,6 @@ func (v *View) saveFailedEvent(failedEvent *repository.FailedEvent) error { return repository.SaveFailedEvent(v.Db, errTable, failedEvent) } -func (v *View) latestFailedEvent(viewName string, sequence uint64) (*repository.FailedEvent, error) { - return repository.LatestFailedEvent(v.Db, errTable, viewName, sequence) +func (v *View) latestFailedEvent(viewName, instanceID string, sequence uint64) (*repository.FailedEvent, error) { + return repository.LatestFailedEvent(v.Db, errTable, viewName, instanceID, sequence) } diff --git a/internal/authz/repository/eventsourcing/view/sequence.go b/internal/authz/repository/eventsourcing/view/sequence.go index a19fa854c2..84f7518372 100644 --- a/internal/authz/repository/eventsourcing/view/sequence.go +++ b/internal/authz/repository/eventsourcing/view/sequence.go @@ -12,21 +12,27 @@ const ( ) func (v *View) saveCurrentSequence(viewName string, event *models.Event) error { - return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.Sequence, event.CreationDate) + return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate) } -func (v *View) latestSequence(viewName string) (*repository.CurrentSequence, error) { - return repository.LatestSequence(v.Db, sequencesTable, viewName) +func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) { + return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID) +} + +func (v *View) latestSequences(viewName string) ([]*repository.CurrentSequence, error) { + return repository.LatestSequences(v.Db, sequencesTable, viewName) } func (v *View) updateSpoolerRunSequence(viewName string) error { - currentSequence, err := repository.LatestSequence(v.Db, sequencesTable, viewName) + currentSequences, err := repository.LatestSequences(v.Db, sequencesTable, viewName) if err != nil { return err } - if currentSequence.ViewName == "" { - currentSequence.ViewName = viewName + for _, currentSequence := range currentSequences { + if currentSequence.ViewName == "" { + currentSequence.ViewName = viewName + } + currentSequence.LastSuccessfulSpoolerRun = time.Now() } - currentSequence.LastSuccessfulSpoolerRun = time.Now() - return repository.UpdateCurrentSequence(v.Db, sequencesTable, currentSequence) + return repository.UpdateCurrentSequences(v.Db, sequencesTable, currentSequences) } diff --git a/internal/authz/repository/eventsourcing/view/token.go b/internal/authz/repository/eventsourcing/view/token.go index ee43edfad2..98e5446ea3 100644 --- a/internal/authz/repository/eventsourcing/view/token.go +++ b/internal/authz/repository/eventsourcing/view/token.go @@ -12,8 +12,8 @@ const ( tokenTable = "auth.tokens" ) -func (v *View) TokenByID(tokenID string) (*usr_view_model.TokenView, error) { - return usr_view.TokenByID(v.Db, tokenTable, tokenID) +func (v *View) TokenByID(tokenID, instanceID string) (*usr_view_model.TokenView, error) { + return usr_view.TokenByID(v.Db, tokenTable, tokenID, instanceID) } func (v *View) PutToken(token *usr_view_model.TokenView, event *models.Event) error { @@ -24,24 +24,24 @@ func (v *View) PutToken(token *usr_view_model.TokenView, event *models.Event) er return v.ProcessedTokenSequence(event) } -func (v *View) DeleteToken(tokenID string, event *models.Event) error { - err := usr_view.DeleteToken(v.Db, tokenTable, tokenID) +func (v *View) DeleteToken(tokenID, instanceID string, event *models.Event) error { + err := usr_view.DeleteToken(v.Db, tokenTable, tokenID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedTokenSequence(event) } -func (v *View) DeleteSessionTokens(agentID, userID string, event *models.Event) error { - err := usr_view.DeleteSessionTokens(v.Db, tokenTable, agentID, userID) +func (v *View) DeleteSessionTokens(agentID, userID, instanceID string, event *models.Event) error { + err := usr_view.DeleteSessionTokens(v.Db, tokenTable, agentID, userID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedTokenSequence(event) } -func (v *View) GetLatestTokenSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(tokenTable) +func (v *View) GetLatestTokenSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(tokenTable, instanceID) } func (v *View) ProcessedTokenSequence(event *models.Event) error { @@ -52,8 +52,8 @@ func (v *View) UpdateTokenSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(tokenTable) } -func (v *View) GetLatestTokenFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(tokenTable, sequence) +func (v *View) GetLatestTokenFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(tokenTable, instanceID, sequence) } func (v *View) ProcessedTokenFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/authz/repository/eventsourcing/view/user_membership.go b/internal/authz/repository/eventsourcing/view/user_membership.go index e570fb51ed..9fdc1668d2 100644 --- a/internal/authz/repository/eventsourcing/view/user_membership.go +++ b/internal/authz/repository/eventsourcing/view/user_membership.go @@ -13,16 +13,16 @@ const ( userMembershipTable = "authz.user_memberships" ) -func (v *View) UserMembershipByIDs(userID, aggregateID, objectID string, memberType usr_model.MemberType) (*model.UserMembershipView, error) { - return view.UserMembershipByIDs(v.Db, userMembershipTable, userID, aggregateID, objectID, memberType) +func (v *View) UserMembershipByIDs(userID, aggregateID, objectID, instanceID string, memberType usr_model.MemberType) (*model.UserMembershipView, error) { + return view.UserMembershipByIDs(v.Db, userMembershipTable, userID, aggregateID, objectID, instanceID, memberType) } -func (v *View) UserMembershipsByAggregateID(aggregateID string) ([]*model.UserMembershipView, error) { - return view.UserMembershipsByAggregateID(v.Db, userMembershipTable, aggregateID) +func (v *View) UserMembershipsByAggregateID(aggregateID, instanceID string) ([]*model.UserMembershipView, error) { + return view.UserMembershipsByAggregateID(v.Db, userMembershipTable, aggregateID, instanceID) } -func (v *View) UserMembershipsByResourceOwner(resourceOwner string) ([]*model.UserMembershipView, error) { - return view.UserMembershipsByResourceOwner(v.Db, userMembershipTable, resourceOwner) +func (v *View) UserMembershipsByResourceOwner(resourceOwner, instanceID string) ([]*model.UserMembershipView, error) { + return view.UserMembershipsByResourceOwner(v.Db, userMembershipTable, resourceOwner, instanceID) } func (v *View) SearchUserMemberships(request *usr_model.UserMembershipSearchRequest) ([]*model.UserMembershipView, uint64, error) { @@ -45,40 +45,44 @@ func (v *View) BulkPutUserMemberships(memberships []*model.UserMembershipView, e return v.ProcessedUserMembershipSequence(event) } -func (v *View) DeleteUserMembership(userID, aggregateID, objectID string, memberType usr_model.MemberType, event *models.Event) error { - err := view.DeleteUserMembership(v.Db, userMembershipTable, userID, aggregateID, objectID, memberType) +func (v *View) DeleteUserMembership(userID, aggregateID, objectID, instanceID string, memberType usr_model.MemberType, event *models.Event) error { + err := view.DeleteUserMembership(v.Db, userMembershipTable, userID, aggregateID, objectID, instanceID, memberType) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedUserMembershipSequence(event) } -func (v *View) DeleteUserMembershipsByUserID(userID string, event *models.Event) error { - err := view.DeleteUserMembershipsByUserID(v.Db, userMembershipTable, userID) +func (v *View) DeleteUserMembershipsByUserID(userID, instanceID string, event *models.Event) error { + err := view.DeleteUserMembershipsByUserID(v.Db, userMembershipTable, userID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedUserMembershipSequence(event) } -func (v *View) DeleteUserMembershipsByAggregateID(aggregateID string, event *models.Event) error { - err := view.DeleteUserMembershipsByAggregateID(v.Db, userMembershipTable, aggregateID) +func (v *View) DeleteUserMembershipsByAggregateID(aggregateID, instanceID string, event *models.Event) error { + err := view.DeleteUserMembershipsByAggregateID(v.Db, userMembershipTable, aggregateID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedUserMembershipSequence(event) } -func (v *View) DeleteUserMembershipsByAggregateIDAndObjectID(aggregateID, objectID string, event *models.Event) error { - err := view.DeleteUserMembershipsByAggregateIDAndObjectID(v.Db, userMembershipTable, aggregateID, objectID) +func (v *View) DeleteUserMembershipsByAggregateIDAndObjectID(aggregateID, objectID, instanceID string, event *models.Event) error { + err := view.DeleteUserMembershipsByAggregateIDAndObjectID(v.Db, userMembershipTable, aggregateID, objectID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedUserMembershipSequence(event) } -func (v *View) GetLatestUserMembershipSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(userMembershipTable) +func (v *View) GetLatestUserMembershipSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(userMembershipTable, instanceID) +} + +func (v *View) GetLatestUserMembershipSequences() ([]*repository.CurrentSequence, error) { + return v.latestSequences(userMembershipTable) } func (v *View) ProcessedUserMembershipSequence(event *models.Event) error { @@ -89,8 +93,8 @@ func (v *View) UpdateUserMembershipSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(userMembershipTable) } -func (v *View) GetLatestUserMembershipFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(userMembershipTable, sequence) +func (v *View) GetLatestUserMembershipFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(userMembershipTable, instanceID, sequence) } func (v *View) ProcessedUserMembershipFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/command/instance_policy_password_age_test.go b/internal/command/instance_policy_password_age_test.go index 8c7d949c1c..5a17ffb1e6 100644 --- a/internal/command/instance_policy_password_age_test.go +++ b/internal/command/instance_policy_password_age_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + "github.com/caos/zitadel/internal/api/authz" "github.com/caos/zitadel/internal/domain" caos_errs "github.com/caos/zitadel/internal/errors" @@ -12,7 +14,6 @@ import ( "github.com/caos/zitadel/internal/eventstore/v1/models" "github.com/caos/zitadel/internal/repository/instance" "github.com/caos/zitadel/internal/repository/policy" - "github.com/stretchr/testify/assert" ) func TestCommandSide_AddDefaultPasswordAgePolicy(t *testing.T) { diff --git a/internal/eventstore/handler/crdb/current_sequence.go b/internal/eventstore/handler/crdb/current_sequence.go index b1b0279126..0221f0d452 100644 --- a/internal/eventstore/handler/crdb/current_sequence.go +++ b/internal/eventstore/handler/crdb/current_sequence.go @@ -10,11 +10,16 @@ import ( ) const ( - currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type FROM %s WHERE projection_name = $1 FOR UPDATE` - updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, timestamp) VALUES ` + currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 FOR UPDATE` + updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES ` ) -type currentSequences map[eventstore.AggregateType]uint64 +type currentSequences map[eventstore.AggregateType][]*instanceSequence + +type instanceSequence struct { + instanceID string + sequence uint64 +} func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) { rows, err := query(h.currentSequenceStmt, h.ProjectionName) @@ -29,14 +34,18 @@ func (h *StatementHandler) currentSequences(query func(string, ...interface{}) ( var ( aggregateType eventstore.AggregateType sequence uint64 + instanceID string ) - err = rows.Scan(&sequence, &aggregateType) + err = rows.Scan(&sequence, &aggregateType, &instanceID) if err != nil { return nil, errors.ThrowInternal(err, "CRDB-dbatK", "scan failed") } - sequences[aggregateType] = sequence + sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{ + sequence: sequence, + instanceID: instanceID, + }) } if err = rows.Close(); err != nil { @@ -54,10 +63,12 @@ func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentS valueQueries := make([]string, 0, len(sequences)) valueCounter := 0 values := make([]interface{}, 0, len(sequences)*3) - for aggregate, sequence := range sequences { - valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", NOW())") - valueCounter += 3 - values = append(values, h.ProjectionName, aggregate, sequence) + for aggregate, instanceSequence := range sequences { + for _, sequence := range instanceSequence { + valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", $"+strconv.Itoa(valueCounter+4)+", NOW())") + valueCounter += 4 + values = append(values, h.ProjectionName, aggregate, sequence.sequence, sequence.instanceID) + } } res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", "), values...) diff --git a/internal/eventstore/handler/crdb/db_mock_test.go b/internal/eventstore/handler/crdb/db_mock_test.go index dbd83ca01b..b0067e3b56 100644 --- a/internal/eventstore/handler/crdb/db_mock_test.go +++ b/internal/eventstore/handler/crdb/db_mock_test.go @@ -3,7 +3,7 @@ package crdb import ( "database/sql" "database/sql/driver" - "log" + "sort" "strings" "time" @@ -123,22 +123,22 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) { } } -func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) { +func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { - m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`). + m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`). WithArgs( projection, ). WillReturnRows( - sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}). - AddRow(seq, aggregateType), + sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}). + AddRow(seq, aggregateType, instanceID), ) } } func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { - m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`). + m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`). WithArgs( projection, ). @@ -148,37 +148,38 @@ func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlm func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { - m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`). + m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`). WithArgs( projection, ). WillReturnRows( - sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}), + sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}), ) } } func expectCurrentSequenceScanErr(tableName, projection string) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { - m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`). + m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`). WithArgs( projection, ). WillReturnRows( - sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}). + sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}). RowError(0, sql.ErrTxDone). - AddRow(0, "agg"), + AddRow(0, "agg", "instanceID"), ) } } -func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) { +func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { - m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`). + m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`). WithArgs( projection, aggregateType, seq, + instanceID, ). WillReturnResult( sqlmock.NewResult(1, 1), @@ -187,16 +188,26 @@ func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggre } func expectUpdateTwoCurrentSequence(tableName, projection string, sequences currentSequences) func(sqlmock.Sqlmock) { + //sort them so the args will always have the same order + keys := make([]string, 0, len(sequences)) + for k := range sequences { + keys = append(keys, string(k)) + } + sort.Strings(keys) + args := make([]driver.Value, len(keys)*4) + for i, k := range keys { + aggregateType := eventstore.AggregateType(k) + for _, sequence := range sequences[aggregateType] { + args[i*4] = projection + args[i*4+1] = aggregateType + args[i*4+2] = sequence.sequence + args[i*4+3] = sequence.instanceID + } + } return func(m sqlmock.Sqlmock) { - matcher := ¤tSequenceMatcher{seq: sequences} - m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\), \(\$4, \$5, \$6, NOW\(\)\)`). + m.ExpectExec("UPSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\)`). WithArgs( - projection, - matcher, - matcher, - projection, - matcher, - matcher, + args..., ). WillReturnResult( sqlmock.NewResult(1, 1), @@ -204,51 +215,27 @@ func expectUpdateTwoCurrentSequence(tableName, projection string, sequences curr } } -type currentSequenceMatcher struct { - seq currentSequences - currentAggregate eventstore.AggregateType -} - -func (m *currentSequenceMatcher) Match(value driver.Value) bool { - switch v := value.(type) { - case string: - if m.currentAggregate != "" { - log.Printf("expected sequence of %s but got next aggregate type %s", m.currentAggregate, value) - return false - } - _, ok := m.seq[eventstore.AggregateType(v)] - if !ok { - return false - } - m.currentAggregate = eventstore.AggregateType(v) - return true - default: - seq := m.seq[m.currentAggregate] - m.currentAggregate = "" - delete(m.seq, m.currentAggregate) - return int64(seq) == value.(int64) - } -} - -func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType string) func(sqlmock.Sqlmock) { +func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { - m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`). + m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`). WithArgs( projection, aggregateType, seq, + instanceID, ). WillReturnError(err) } } -func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) { +func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { - m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`). + m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`). WithArgs( projection, aggregateType, seq, + instanceID, ). WillReturnResult( sqlmock.NewResult(0, 0), @@ -256,17 +243,18 @@ func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, } } -func expectLock(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlmock) { +func expectLock(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { m.ExpectExec(`INSERT INTO `+lockTable+ - ` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+ - ` ON CONFLICT \(projection_name\)`+ + ` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+ + ` ON CONFLICT \(projection_name, instance_id\)`+ ` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+ - ` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`). + ` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`). WithArgs( workerName, float64(d), projectionName, + instanceID, ). WillReturnResult( sqlmock.NewResult(1, 1), @@ -274,33 +262,35 @@ func expectLock(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlm } } -func expectLockNoRows(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlmock) { +func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { m.ExpectExec(`INSERT INTO `+lockTable+ - ` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+ - ` ON CONFLICT \(projection_name\)`+ + ` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+ + ` ON CONFLICT \(projection_name, instance_id\)`+ ` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+ - ` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`). + ` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`). WithArgs( workerName, float64(d), projectionName, + instanceID, ). WillReturnResult(driver.ResultNoRows) } } -func expectLockErr(lockTable, workerName string, d time.Duration, err error) func(sqlmock.Sqlmock) { +func expectLockErr(lockTable, workerName string, d time.Duration, instanceID string, err error) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { m.ExpectExec(`INSERT INTO `+lockTable+ - ` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+ - ` ON CONFLICT \(projection_name\)`+ + ` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+ + ` ON CONFLICT \(projection_name, instance_id\)`+ ` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+ - ` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`). + ` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`). WithArgs( workerName, float64(d), projectionName, + instanceID, ). WillReturnError(err) } diff --git a/internal/eventstore/handler/crdb/handler_stmt.go b/internal/eventstore/handler/crdb/handler_stmt.go index 4bcf1e1f1b..f330a11546 100644 --- a/internal/eventstore/handler/crdb/handler_stmt.go +++ b/internal/eventstore/handler/crdb/handler_stmt.go @@ -101,15 +101,34 @@ func (h *StatementHandler) SearchQuery() (*eventstore.SearchQueryBuilder, uint64 queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit) for _, aggregateType := range h.aggregates { + instances := make([]string, 0) + for _, sequence := range sequences[aggregateType] { + instances = appendToIgnoredInstances(instances, sequence.instanceID) + queryBuilder. + AddQuery(). + AggregateTypes(aggregateType). + SequenceGreater(sequence.sequence). + InstanceID(sequence.instanceID) + } queryBuilder. AddQuery(). AggregateTypes(aggregateType). - SequenceGreater(sequences[aggregateType]) + SequenceGreater(0). + ExcludedInstanceID(instances...) } return queryBuilder, h.bulkLimit, nil } +func appendToIgnoredInstances(instances []string, id string) []string { + for _, instance := range instances { + if instance == id { + return instances + } + } + return append(instances, id) +} + //Update implements handler.Update func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) { tx, err := h.client.BeginTx(ctx, nil) @@ -127,7 +146,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen // because there could be events between current sequence and a creation event // and we cannot check via stmt.PreviousSequence if stmts[0].PreviousSequence == 0 { - previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, sequences, reduce) + previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce) if err != nil { tx.Rollback() return stmts, err @@ -164,27 +183,25 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen return unexecutedStmts, nil } -func (h *StatementHandler) fetchPreviousStmts( - ctx context.Context, - stmtSeq uint64, - sequences currentSequences, - reduce handler.Reduce, -) (previousStmts []*handler.Statement, err error) { +func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) { query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent) queriesAdded := false for _, aggregateType := range h.aggregates { - if stmtSeq <= sequences[aggregateType] { - continue + for _, sequence := range sequences[aggregateType] { + if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID { + continue + } + + query. + AddQuery(). + AggregateTypes(aggregateType). + SequenceGreater(sequence.sequence). + SequenceLess(stmtSeq). + InstanceID(sequence.instanceID) + + queriesAdded = true } - - query. - AddQuery(). - AggregateTypes(aggregateType). - SequenceGreater(sequences[aggregateType]). - SequenceLess(stmtSeq) - - queriesAdded = true } if !queriesAdded { @@ -214,16 +231,19 @@ func (h *StatementHandler) executeStmts( lastSuccessfulIdx := -1 for i, stmt := range stmts { - if stmt.Sequence <= sequences[stmt.AggregateType] { - continue - } - if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequences[stmt.AggregateType] { - logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match") - break + for _, sequence := range sequences[stmt.AggregateType] { + if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID { + continue + } + if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID { + logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match") + break + } } err := h.executeStmt(tx, stmt) if err == nil { - sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i + updateSequences(sequences, stmt) + lastSuccessfulIdx = i continue } @@ -232,7 +252,9 @@ func (h *StatementHandler) executeStmts( break } - sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i + updateSequences(sequences, stmt) + lastSuccessfulIdx = i + continue } return lastSuccessfulIdx } @@ -261,3 +283,16 @@ func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) erro } return nil } + +func updateSequences(sequences currentSequences, stmt *handler.Statement) { + for _, sequence := range sequences[stmt.AggregateType] { + if sequence.instanceID == stmt.InstanceID { + sequence.sequence = stmt.Sequence + return + } + } + sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{ + instanceID: stmt.InstanceID, + sequence: stmt.Sequence, + }) +} diff --git a/internal/eventstore/handler/crdb/handler_stmt_test.go b/internal/eventstore/handler/crdb/handler_stmt_test.go index e6e1936651..1fc436483d 100644 --- a/internal/eventstore/handler/crdb/handler_stmt_test.go +++ b/internal/eventstore/handler/crdb/handler_stmt_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" "github.com/caos/zitadel/internal/eventstore" "github.com/caos/zitadel/internal/eventstore/handler" @@ -97,13 +98,18 @@ func TestProjectionHandler_SearchQuery(t *testing.T) { return err == nil }, expectations: []mockExpectation{ - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"), + expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"), }, SearchQueryBuilder: eventstore. NewSearchQueryBuilder(eventstore.ColumnsEvent). AddQuery(). AggregateTypes("testAgg"). SequenceGreater(5). + InstanceID("instanceID"). + Or(). + AggregateTypes("testAgg"). + SequenceGreater(0). + ExcludedInstanceID("instanceID"). Builder(). Limit(5), }, @@ -225,7 +231,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"), + expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"), expectRollback(), }, isErr: func(err error) bool { @@ -262,7 +268,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"), + expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"), expectCommit(), }, isErr: func(err error) bool { @@ -287,6 +293,7 @@ func TestStatementHandler_Update(t *testing.T) { aggregateType: "agg", sequence: 7, previousSequence: 5, + instanceID: "instanceID", }, []handler.Column{ { @@ -299,11 +306,11 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "agg"), + expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"), expectSavePoint(), expectCreate("my_projection", []string{"col"}, []string{"$1"}), expectSavePointRelease(), - expectUpdateCurrentSequenceNoRows("my_sequences", "my_projection", 7, "agg"), + expectUpdateCurrentSequenceNoRows("my_sequences", "my_projection", 7, "agg", "instanceID"), expectRollback(), }, isErr: func(err error) bool { @@ -328,6 +335,7 @@ func TestStatementHandler_Update(t *testing.T) { aggregateType: "agg", sequence: 7, previousSequence: 5, + instanceID: "instanceID", }, []handler.Column{ { @@ -340,11 +348,11 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "agg"), + expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"), expectSavePoint(), expectCreate("my_projection", []string{"col"}, []string{"$1"}), expectSavePointRelease(), - expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "agg"), + expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "agg", "instanceID"), expectCommitErr(sql.ErrConnDone), }, isErr: func(err error) bool { @@ -368,14 +376,15 @@ func TestStatementHandler_Update(t *testing.T) { aggregateType: "testAgg", sequence: 7, previousSequence: 5, + instanceID: "instanceID", }), }, }, want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"), - expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"), + expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"), + expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectCommit(), }, isErr: func(err error) bool { @@ -399,14 +408,15 @@ func TestStatementHandler_Update(t *testing.T) { aggregateType: "testAgg", sequence: 7, previousSequence: 0, + instanceID: "instanceID", }), }, }, want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"), - expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"), + expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"), + expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectCommit(), }, isErr: func(err error) bool { @@ -423,6 +433,7 @@ func TestStatementHandler_Update(t *testing.T) { AggregateType: "testAgg", Sequence: 6, PreviousAggregateSequence: 5, + InstanceID: "instanceID", }, ), ), @@ -435,6 +446,7 @@ func TestStatementHandler_Update(t *testing.T) { aggregateType: "testAgg", sequence: 7, previousSequence: 0, + instanceID: "instanceID", }), }, reduce: testReduce(), @@ -442,8 +454,8 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"), - expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"), + expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"), + expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectCommit(), }, isErr: func(err error) bool { @@ -537,7 +549,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) { ctx: context.Background(), reduce: testReduce(), sequences: currentSequences{ - "testAgg": 5, + "testAgg": []*instanceSequence{ + {sequence: 5}, + }, }, stmtSeq: 6, }, @@ -560,7 +574,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) { ctx: context.Background(), reduce: testReduce(), sequences: currentSequences{ - "testAgg": 5, + "testAgg": []*instanceSequence{ + {sequence: 5}, + }, }, stmtSeq: 6, }, @@ -582,7 +598,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) { ctx: context.Background(), reduce: testReduce(), sequences: currentSequences{ - "testAgg": 5, + "testAgg": []*instanceSequence{ + {sequence: 5}, + }, }, stmtSeq: 10, }, @@ -626,7 +644,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) { ctx: context.Background(), reduce: testReduceErr(errReduce), sequences: currentSequences{ - "testAgg": 5, + "testAgg": []*instanceSequence{ + {sequence: 5}, + }, }, stmtSeq: 10, }, @@ -667,7 +687,7 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) { }), aggregates: tt.fields.aggregates, } - stmts, err := h.fetchPreviousStmts(tt.args.ctx, tt.args.stmtSeq, tt.args.sequences, tt.args.reduce) + stmts, err := h.fetchPreviousStmts(tt.args.ctx, tt.args.stmtSeq, "", tt.args.sequences, tt.args.reduce) if !tt.want.isErr(err) { t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err) return @@ -720,7 +740,9 @@ func TestStatementHandler_executeStmts(t *testing.T) { }), }, sequences: currentSequences{ - "agg": 5, + "agg": []*instanceSequence{ + {sequence: 5}, + }, }, }, want: want{ @@ -762,7 +784,9 @@ func TestStatementHandler_executeStmts(t *testing.T) { }), }, sequences: currentSequences{ - "agg": 2, + "agg": []*instanceSequence{ + {sequence: 2}, + }, }, }, want: want{ @@ -824,7 +848,9 @@ func TestStatementHandler_executeStmts(t *testing.T) { }), }, sequences: currentSequences{ - "agg": 2, + "agg": []*instanceSequence{ + {sequence: 2}, + }, }, }, want: want{ @@ -891,7 +917,9 @@ func TestStatementHandler_executeStmts(t *testing.T) { }), }, sequences: currentSequences{ - "agg": 2, + "agg": []*instanceSequence{ + {sequence: 2}, + }, }, }, want: want{ @@ -979,7 +1007,9 @@ func TestStatementHandler_executeStmts(t *testing.T) { ), }, sequences: currentSequences{ - "agg": 2, + "agg": []*instanceSequence{ + {sequence: 2}, + }, }, }, want: want{ @@ -1309,9 +1339,7 @@ func TestStatementHandler_currentSequence(t *testing.T) { expectations: []mockExpectation{ expectCurrentSequenceNoRows("my_table", "my_projection"), }, - sequences: currentSequences{ - "agg": 0, - }, + sequences: currentSequences{}, }, }, { @@ -1331,9 +1359,7 @@ func TestStatementHandler_currentSequence(t *testing.T) { expectations: []mockExpectation{ expectCurrentSequenceScanErr("my_table", "my_projection"), }, - sequences: currentSequences{ - "agg": 0, - }, + sequences: currentSequences{}, }, }, { @@ -1351,10 +1377,15 @@ func TestStatementHandler_currentSequence(t *testing.T) { return errors.Is(err, nil) }, expectations: []mockExpectation{ - expectCurrentSequence("my_table", "my_projection", 5, "agg"), + expectCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"), }, sequences: currentSequences{ - "agg": 5, + "agg": []*instanceSequence{ + { + sequence: 5, + instanceID: "instanceID", + }, + }, }, }, }, @@ -1404,9 +1435,7 @@ func TestStatementHandler_currentSequence(t *testing.T) { } for _, aggregateType := range tt.fields.aggregates { - if seq[aggregateType] != tt.want.sequences[aggregateType] { - t.Errorf("unexpected sequence in aggregate type %s: want %d got %d", aggregateType, tt.want.sequences[aggregateType], seq[aggregateType]) - } + assert.Equal(t, tt.want.sequences[aggregateType], seq[aggregateType]) } }) } @@ -1440,7 +1469,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) { }, args: args{ sequences: currentSequences{ - "agg": 5, + "agg": []*instanceSequence{ + { + sequence: 5, + instanceID: "instanceID", + }, + }, }, }, want: want{ @@ -1448,7 +1482,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) { return errors.Is(err, sql.ErrConnDone) }, expectations: []mockExpectation{ - expectUpdateCurrentSequenceErr("my_table", "my_projection", 5, sql.ErrConnDone, "agg"), + expectUpdateCurrentSequenceErr("my_table", "my_projection", 5, sql.ErrConnDone, "agg", "instanceID"), }, }, }, @@ -1461,7 +1495,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) { }, args: args{ sequences: currentSequences{ - "agg": 5, + "agg": []*instanceSequence{ + { + sequence: 5, + instanceID: "instanceID", + }, + }, }, }, want: want{ @@ -1469,7 +1508,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) { return errors.As(err, &errSeqNotUpdated) }, expectations: []mockExpectation{ - expectUpdateCurrentSequenceNoRows("my_table", "my_projection", 5, "agg"), + expectUpdateCurrentSequenceNoRows("my_table", "my_projection", 5, "agg", "instanceID"), }, }, }, @@ -1482,7 +1521,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) { }, args: args{ sequences: currentSequences{ - "agg": 5, + "agg": []*instanceSequence{ + { + sequence: 5, + instanceID: "instanceID", + }, + }, }, }, want: want{ @@ -1490,7 +1534,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) { return err == nil }, expectations: []mockExpectation{ - expectUpdateCurrentSequence("my_table", "my_projection", 5, "agg"), + expectUpdateCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"), }, }, }, @@ -1503,8 +1547,18 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) { }, args: args{ sequences: currentSequences{ - "agg": 5, - "agg2": 6, + "agg": []*instanceSequence{ + { + sequence: 5, + instanceID: "instanceID", + }, + }, + "agg2": []*instanceSequence{ + { + sequence: 6, + instanceID: "instanceID", + }, + }, }, }, want: want{ @@ -1513,9 +1567,19 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) { }, expectations: []mockExpectation{ expectUpdateTwoCurrentSequence("my_table", "my_projection", currentSequences{ - "agg": 5, - "agg2": 6}, - ), + "agg": []*instanceSequence{ + { + sequence: 5, + instanceID: "instanceID", + }, + }, + "agg2": []*instanceSequence{ + { + sequence: 6, + instanceID: "instanceID", + }, + }, + }), }, }, }, diff --git a/internal/eventstore/handler/crdb/lock.go b/internal/eventstore/handler/crdb/lock.go index a344cc1099..3038b40979 100644 --- a/internal/eventstore/handler/crdb/lock.go +++ b/internal/eventstore/handler/crdb/lock.go @@ -15,15 +15,15 @@ import ( const ( lockStmtFormat = "INSERT INTO %[1]s" + - " (locker_id, locked_until, projection_name) VALUES ($1, now()+$2::INTERVAL, $3)" + - " ON CONFLICT (projection_name)" + + " (locker_id, locked_until, projection_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" + + " ON CONFLICT (projection_name, instance_id)" + " DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" + - " WHERE %[1]s.projection_name = $3 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())" + " WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = $4 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())" ) type Locker interface { - Lock(ctx context.Context, lockDuration time.Duration) <-chan error - Unlock() error + Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error + Unlock(instanceID string) error } type locker struct { @@ -47,18 +47,18 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker { } } -func (h *locker) Lock(ctx context.Context, lockDuration time.Duration) <-chan error { +func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error { errs := make(chan error) - go h.handleLock(ctx, errs, lockDuration) + go h.handleLock(ctx, errs, lockDuration, instanceID) return errs } -func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration) { +func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceID string) { renewLock := time.NewTimer(0) for { select { case <-renewLock.C: - errs <- h.renewLock(ctx, lockDuration) + errs <- h.renewLock(ctx, lockDuration, instanceID) //refresh the lock 500ms before it times out. 500ms should be enough for one transaction renewLock.Reset(lockDuration - (500 * time.Millisecond)) case <-ctx.Done(): @@ -69,9 +69,9 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t } } -func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) error { +func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error { //the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html). - res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName) + res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID) if err != nil { return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock") } @@ -83,8 +83,8 @@ func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) erro return nil } -func (h *locker) Unlock() error { - _, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName) +func (h *locker) Unlock(instanceID string) error { + _, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName, instanceID) if err != nil { return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed") } diff --git a/internal/eventstore/handler/crdb/lock_test.go b/internal/eventstore/handler/crdb/lock_test.go index ac7991d921..bab920c22b 100644 --- a/internal/eventstore/handler/crdb/lock_test.go +++ b/internal/eventstore/handler/crdb/lock_test.go @@ -32,6 +32,7 @@ func TestStatementHandler_handleLock(t *testing.T) { lockDuration time.Duration ctx context.Context errMock *errsMock + instanceID string } tests := []struct { name string @@ -42,9 +43,9 @@ func TestStatementHandler_handleLock(t *testing.T) { name: "lock fails", want: want{ expectations: []mockExpectation{ - expectLock(lockTable, workerName, 2), - expectLock(lockTable, workerName, 2), - expectLockErr(lockTable, workerName, 2, errLock), + expectLock(lockTable, workerName, 2, "instanceID"), + expectLock(lockTable, workerName, 2, "instanceID"), + expectLockErr(lockTable, workerName, 2, "instanceID", errLock), }, }, args: args{ @@ -55,14 +56,15 @@ func TestStatementHandler_handleLock(t *testing.T) { successfulIters: 2, shouldErr: true, }, + instanceID: "instanceID", }, }, { name: "success", want: want{ expectations: []mockExpectation{ - expectLock(lockTable, workerName, 2), - expectLock(lockTable, workerName, 2), + expectLock(lockTable, workerName, 2, "instanceID"), + expectLock(lockTable, workerName, 2, "instanceID"), }, }, args: args{ @@ -72,6 +74,7 @@ func TestStatementHandler_handleLock(t *testing.T) { errs: make(chan error), successfulIters: 2, }, + instanceID: "instanceID", }, }, } @@ -96,7 +99,7 @@ func TestStatementHandler_handleLock(t *testing.T) { go tt.args.errMock.handleErrs(t, cancel) - go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration) + go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceID) <-ctx.Done() @@ -115,6 +118,7 @@ func TestStatementHandler_renewLock(t *testing.T) { } type args struct { lockDuration time.Duration + instanceID string } tests := []struct { name string @@ -125,7 +129,7 @@ func TestStatementHandler_renewLock(t *testing.T) { name: "lock fails", want: want{ expectations: []mockExpectation{ - expectLockErr(lockTable, workerName, 1, sql.ErrTxDone), + expectLockErr(lockTable, workerName, 1, "instanceID", sql.ErrTxDone), }, isErr: func(err error) bool { return errors.Is(err, sql.ErrTxDone) @@ -133,13 +137,14 @@ func TestStatementHandler_renewLock(t *testing.T) { }, args: args{ lockDuration: 1 * time.Second, + instanceID: "instanceID", }, }, { name: "lock no rows", want: want{ expectations: []mockExpectation{ - expectLockNoRows(lockTable, workerName, 2), + expectLockNoRows(lockTable, workerName, 2, "instanceID"), }, isErr: func(err error) bool { return errors.As(err, &renewNoRowsAffectedErr) @@ -147,13 +152,14 @@ func TestStatementHandler_renewLock(t *testing.T) { }, args: args{ lockDuration: 2 * time.Second, + instanceID: "instanceID", }, }, { name: "success", want: want{ expectations: []mockExpectation{ - expectLock(lockTable, workerName, 3), + expectLock(lockTable, workerName, 3, "instanceID"), }, isErr: func(err error) bool { return errors.Is(err, nil) @@ -161,6 +167,7 @@ func TestStatementHandler_renewLock(t *testing.T) { }, args: args{ lockDuration: 3 * time.Second, + instanceID: "instanceID", }, }, } @@ -181,7 +188,7 @@ func TestStatementHandler_renewLock(t *testing.T) { expectation(mock) } - err = h.renewLock(context.Background(), tt.args.lockDuration) + err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceID) if !tt.want.isErr(err) { t.Errorf("unexpected error = %v", err) } @@ -199,15 +206,22 @@ func TestStatementHandler_Unlock(t *testing.T) { expectations []mockExpectation isErr func(err error) bool } + type args struct { + instanceID string + } tests := []struct { name string + args args want want }{ { name: "unlock fails", + args: args{ + instanceID: "instanceID", + }, want: want{ expectations: []mockExpectation{ - expectLockErr(lockTable, workerName, 0, sql.ErrTxDone), + expectLockErr(lockTable, workerName, 0, "instanceID", sql.ErrTxDone), }, isErr: func(err error) bool { return errors.Is(err, sql.ErrTxDone) @@ -216,9 +230,12 @@ func TestStatementHandler_Unlock(t *testing.T) { }, { name: "success", + args: args{ + instanceID: "instanceID", + }, want: want{ expectations: []mockExpectation{ - expectLock(lockTable, workerName, 0), + expectLock(lockTable, workerName, 0, "instanceID"), }, isErr: func(err error) bool { return errors.Is(err, nil) @@ -243,7 +260,7 @@ func TestStatementHandler_Unlock(t *testing.T) { expectation(mock) } - err = h.Unlock() + err = h.Unlock(tt.args.instanceID) if !tt.want.isErr(err) { t.Errorf("unexpected error = %v", err) } diff --git a/internal/eventstore/handler/handler_projection.go b/internal/eventstore/handler/handler_projection.go index 3127f8cab3..4cfecf2ada 100644 --- a/internal/eventstore/handler/handler_projection.go +++ b/internal/eventstore/handler/handler_projection.go @@ -12,6 +12,8 @@ import ( "github.com/caos/zitadel/internal/eventstore" ) +const systemID = "system" + type ProjectionHandlerConfig struct { HandlerConfig ProjectionName string @@ -27,10 +29,10 @@ type Update func(context.Context, []*Statement, Reduce) (unexecutedStmts []*Stat type Reduce func(eventstore.Event) (*Statement, error) //Lock is used for mutex handling if needed on the projection -type Lock func(context.Context, time.Duration) <-chan error +type Lock func(context.Context, time.Duration, string) <-chan error //Unlock releases the mutex of the projection -type Unlock func() error +type Unlock func(string) error //SearchQuery generates the search query to lookup for events type SearchQuery func() (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error) @@ -183,7 +185,7 @@ func (h *ProjectionHandler) bulk( ctx, cancel := context.WithCancel(ctx) defer cancel() - errs := lock(ctx, h.requeueAfter) + errs := lock(ctx, h.requeueAfter, systemID) //wait until projection is locked if err, ok := <-errs; err != nil || !ok { logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("initial lock failed") @@ -194,7 +196,7 @@ func (h *ProjectionHandler) bulk( execErr := executeBulk(ctx) logging.WithFields("projection", h.ProjectionName).OnError(execErr).Warn("unable to execute") - unlockErr := unlock() + unlockErr := unlock(systemID) logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock") if execErr != nil { diff --git a/internal/eventstore/handler/handler_projection_test.go b/internal/eventstore/handler/handler_projection_test.go index ec9cbffae4..198a915d97 100644 --- a/internal/eventstore/handler/handler_projection_test.go +++ b/internal/eventstore/handler/handler_projection_test.go @@ -912,7 +912,7 @@ type lockMock struct { } func (m *lockMock) lock() Lock { - return func(ctx context.Context, _ time.Duration) <-chan error { + return func(ctx context.Context, _ time.Duration, _ string) <-chan error { m.callCount++ errs := make(chan error) go func() { @@ -955,7 +955,7 @@ type unlockMock struct { } func (m *unlockMock) unlock() Unlock { - return func() error { + return func(instanceID string) error { m.callCount++ return m.err } diff --git a/internal/eventstore/repository/search_query.go b/internal/eventstore/repository/search_query.go index 18531eb2cc..2d8bd6734c 100644 --- a/internal/eventstore/repository/search_query.go +++ b/internal/eventstore/repository/search_query.go @@ -50,6 +50,8 @@ const ( OperationIn //OperationJSONContains checks if a stored value matches the given json OperationJSONContains + //OperationNotIn checks if a stored value does not match one of the passed value list + OperationNotIn operationCount ) diff --git a/internal/eventstore/repository/sql/crdb.go b/internal/eventstore/repository/sql/crdb.go index 9f5b6c318a..17d4900706 100644 --- a/internal/eventstore/repository/sql/crdb.go +++ b/internal/eventstore/repository/sql/crdb.go @@ -288,8 +288,11 @@ func (db *CRDB) columnName(col repository.Field) string { } func (db *CRDB) conditionFormat(operation repository.Operation) string { - if operation == repository.OperationIn { + switch operation { + case repository.OperationIn: return "%s %s ANY(?)" + case repository.OperationNotIn: + return "%s %s ALL(?)" } return "%s %s ?" } @@ -304,6 +307,8 @@ func (db *CRDB) operation(operation repository.Operation) string { return "<" case repository.OperationJSONContains: return "@>" + case repository.OperationNotIn: + return "<>" } return "" } diff --git a/internal/eventstore/search_query.go b/internal/eventstore/search_query.go index cedf685485..3c29870140 100644 --- a/internal/eventstore/search_query.go +++ b/internal/eventstore/search_query.go @@ -20,6 +20,8 @@ type SearchQuery struct { builder *SearchQueryBuilder aggregateTypes []AggregateType aggregateIDs []string + instanceID string + excludedInstanceIDs []string eventSequenceGreater uint64 eventSequenceLess uint64 eventTypes []EventType @@ -91,9 +93,9 @@ func (builder *SearchQueryBuilder) ResourceOwner(resourceOwner string) *SearchQu } //InstanceID defines the instanceID (system) of the events -func (factory *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder { - factory.instanceID = instanceID - return factory +func (builder *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder { + builder.instanceID = instanceID + return builder } //OrderDesc changes the sorting order of the returned events to descending @@ -149,6 +151,18 @@ func (query *SearchQuery) AggregateIDs(ids ...string) *SearchQuery { return query } +//InstanceID filters for events with the given instanceID +func (query *SearchQuery) InstanceID(instanceID string) *SearchQuery { + query.instanceID = instanceID + return query +} + +//ExcludedInstanceID filters for events not having the given instanceIDs +func (query *SearchQuery) ExcludedInstanceID(instanceIDs ...string) *SearchQuery { + query.excludedInstanceIDs = instanceIDs + return query +} + //EventTypes filters for events with the given event types func (query *SearchQuery) EventTypes(types ...EventType) *SearchQuery { query.eventTypes = types @@ -180,6 +194,9 @@ func (query *SearchQuery) matches(event Event) bool { if ok := isAggregateIDs(event.Aggregate(), query.aggregateIDs...); len(query.aggregateIDs) > 0 && !ok { return false } + if event.Aggregate().InstanceID != "" && query.instanceID != "" && event.Aggregate().InstanceID != query.instanceID { + return false + } if ok := isEventTypes(event, query.eventTypes...); len(query.eventTypes) > 0 && !ok { return false } @@ -203,6 +220,8 @@ func (builder *SearchQueryBuilder) build(instanceID string) (*repository.SearchQ query.eventDataFilter, query.eventSequenceGreaterFilter, query.eventSequenceLessFilter, + query.instanceIDFilter, + query.excludedInstanceIDFilter, query.builder.resourceOwnerFilter, query.builder.instanceIDFilter, } { @@ -281,6 +300,20 @@ func (query *SearchQuery) eventSequenceLessFilter() *repository.Filter { return repository.NewFilter(repository.FieldSequence, query.eventSequenceLess, sortOrder) } +func (query *SearchQuery) instanceIDFilter() *repository.Filter { + if query.instanceID == "" { + return nil + } + return repository.NewFilter(repository.FieldInstanceID, query.instanceID, repository.OperationEquals) +} + +func (query *SearchQuery) excludedInstanceIDFilter() *repository.Filter { + if len(query.excludedInstanceIDs) == 0 { + return nil + } + return repository.NewFilter(repository.FieldInstanceID, query.excludedInstanceIDs, repository.OperationNotIn) +} + func (builder *SearchQueryBuilder) resourceOwnerFilter() *repository.Filter { if builder.resourceOwner == "" { return nil diff --git a/internal/eventstore/v1/eventstore.go b/internal/eventstore/v1/eventstore.go index 0a1fefa249..e457d2fedf 100644 --- a/internal/eventstore/v1/eventstore.go +++ b/internal/eventstore/v1/eventstore.go @@ -12,7 +12,6 @@ import ( type Eventstore interface { Health(ctx context.Context) error FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (events []*models.Event, err error) - LatestSequence(ctx context.Context, searchQuery *models.SearchQueryFactory) (uint64, error) Subscribe(aggregates ...models.AggregateType) *Subscription } @@ -35,13 +34,6 @@ func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.Sear return es.repo.Filter(ctx, models.FactoryFromSearchQuery(searchQuery)) } -func (es *eventstore) LatestSequence(ctx context.Context, queryFactory *models.SearchQueryFactory) (uint64, error) { - sequenceFactory := *queryFactory - sequenceFactory = *(&sequenceFactory).Columns(models.Columns_Max_Sequence) - sequenceFactory = *(&sequenceFactory).SequenceGreater(0) - return es.repo.LatestSequence(ctx, &sequenceFactory) -} - func (es *eventstore) Health(ctx context.Context) error { return es.repo.Health(ctx) } diff --git a/internal/eventstore/v1/internal/repository/sql/db_mock_test.go b/internal/eventstore/v1/internal/repository/sql/db_mock_test.go index 749e5b06c7..a299ae1a56 100644 --- a/internal/eventstore/v1/internal/repository/sql/db_mock_test.go +++ b/internal/eventstore/v1/internal/repository/sql/db_mock_test.go @@ -12,16 +12,16 @@ import ( ) const ( - selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE aggregate_type = \$1` + selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE \( aggregate_type = \$1` ) var ( eventColumns = []string{"creation_date", "event_type", "event_sequence", "previous_aggregate_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "instance_id", "aggregate_type", "aggregate_id", "aggregate_version"} - expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence LIMIT \$2`).String() - expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence DESC`).String() - expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String() - expectedFilterEventsAggregateIDTypeLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String() - expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence`).String() + expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence LIMIT \$2`).String() + expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence DESC`).String() + expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 \) ORDER BY event_sequence LIMIT \$3`).String() + expectedFilterEventsAggregateIDTypeLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 \) ORDER BY event_sequence LIMIT \$3`).String() + expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence`).String() expectedInsertStatement = regexp.MustCompile(`INSERT INTO eventstore\.events ` + `\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, instance_id, previous_aggregate_sequence, previous_aggregate_type_sequence\) ` + @@ -172,14 +172,14 @@ func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock { } func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock { - db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE aggregate_type = \$1`). + db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE \( aggregate_type = \$1 \)`). WithArgs(aggregateType). WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence)) return db } func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock { - db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE aggregate_type = \$1`). + db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE \( aggregate_type = \$1 \)`). WithArgs(aggregateType).WillReturnError(err) return db } diff --git a/internal/eventstore/v1/internal/repository/sql/filter_test.go b/internal/eventstore/v1/internal/repository/sql/filter_test.go index 382a832f25..e1993de2c8 100644 --- a/internal/eventstore/v1/internal/repository/sql/filter_test.go +++ b/internal/eventstore/v1/internal/repository/sql/filter_test.go @@ -41,7 +41,7 @@ func TestSQL_Filter(t *testing.T) { }, args: args{ events: &mockEvents{t: t}, - searchQuery: es_models.NewSearchQueryFactory("user").Limit(34), + searchQuery: es_models.NewSearchQueryFactory().Limit(34).AddQuery().AggregateTypes("user").Factory(), }, res: res{ eventsLen: 3, @@ -55,7 +55,7 @@ func TestSQL_Filter(t *testing.T) { }, args: args{ events: &mockEvents{t: t}, - searchQuery: es_models.NewSearchQueryFactory("user").OrderDesc(), + searchQuery: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(), }, res: res{ eventsLen: 34, @@ -69,7 +69,7 @@ func TestSQL_Filter(t *testing.T) { }, args: args{ events: &mockEvents{t: t}, - searchQuery: es_models.NewSearchQueryFactory("nonAggregate"), + searchQuery: es_models.NewSearchQueryFactory().AddQuery().AggregateTypes("nonAggregate").Factory(), }, res: res{ wantErr: true, @@ -83,7 +83,7 @@ func TestSQL_Filter(t *testing.T) { }, args: args{ events: &mockEvents{t: t}, - searchQuery: es_models.NewSearchQueryFactory("user"), + searchQuery: es_models.NewSearchQueryFactory().AddQuery().AggregateTypes("user").Factory(), }, res: res{ wantErr: true, @@ -97,7 +97,7 @@ func TestSQL_Filter(t *testing.T) { }, args: args{ events: &mockEvents{t: t}, - searchQuery: es_models.NewSearchQueryFactory("user").Limit(5).AggregateIDs("hop"), + searchQuery: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").AggregateIDs("hop").Factory(), }, res: res{ wantErr: false, @@ -111,7 +111,7 @@ func TestSQL_Filter(t *testing.T) { }, args: args{ events: &mockEvents{t: t}, - searchQuery: es_models.NewSearchQueryFactory("user").Limit(5).AggregateIDs("hop"), + searchQuery: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").AggregateIDs("hop").Factory(), }, res: res{ wantErr: false, @@ -176,7 +176,7 @@ func TestSQL_LatestSequence(t *testing.T) { { name: "no events for aggregate", args: args{ - searchQuery: es_models.NewSearchQueryFactory("idiot").Columns(es_models.Columns_Max_Sequence), + searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("idiot").Factory(), }, fields: fields{ client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrNoRows), @@ -189,7 +189,7 @@ func TestSQL_LatestSequence(t *testing.T) { { name: "sql query error", args: args{ - searchQuery: es_models.NewSearchQueryFactory("idiot").Columns(es_models.Columns_Max_Sequence), + searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("idiot").Factory(), }, fields: fields{ client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrConnDone), @@ -203,7 +203,7 @@ func TestSQL_LatestSequence(t *testing.T) { { name: "events for aggregate found", args: args{ - searchQuery: es_models.NewSearchQueryFactory("user").Columns(es_models.Columns_Max_Sequence), + searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("user").Factory(), }, fields: fields{ client: mockDB(t).expectLatestSequenceFilter("user", math.MaxUint64), diff --git a/internal/eventstore/v1/internal/repository/sql/query.go b/internal/eventstore/v1/internal/repository/sql/query.go index f49db454c9..a092f10dc2 100644 --- a/internal/eventstore/v1/internal/repository/sql/query.go +++ b/internal/eventstore/v1/internal/repository/sql/query.go @@ -61,27 +61,31 @@ func buildQuery(queryFactory *es_models.SearchQueryFactory) (query string, limit return query, searchQuery.Limit, values, rowScanner } -func prepareCondition(filters []*es_models.Filter) (clause string, values []interface{}) { - values = make([]interface{}, len(filters)) +func prepareCondition(filters [][]*es_models.Filter) (clause string, values []interface{}) { + values = make([]interface{}, 0, len(filters)) clauses := make([]string, len(filters)) if len(filters) == 0 { return clause, values } for i, filter := range filters { - value := filter.GetValue() - switch value.(type) { - case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType: - value = pq.Array(value) - } + subClauses := make([]string, 0, len(filter)) + for _, f := range filter { + value := f.GetValue() + switch value.(type) { + case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType: + value = pq.Array(value) + } - clauses[i] = getCondition(filter) - if clauses[i] == "" { - return "", nil + subClauses = append(subClauses, getCondition(f)) + if subClauses[len(subClauses)-1] == "" { + return "", nil + } + values = append(values, value) } - values[i] = value + clauses[i] = "( " + strings.Join(subClauses, " AND ") + " )" } - return " WHERE " + strings.Join(clauses, " AND "), values + return " WHERE " + strings.Join(clauses, " OR "), values } type scan func(dest ...interface{}) error @@ -162,8 +166,11 @@ func getCondition(filter *es_models.Filter) (condition string) { } func getConditionFormat(operation es_models.Operation) string { - if operation == es_models.Operation_In { + switch operation { + case es_models.Operation_In: return "%s %s ANY(?)" + case es_models.Operation_NotIn: + return "%s %s ALL(?)" } return "%s %s ?" } @@ -200,6 +207,8 @@ func getOperation(operation es_models.Operation) string { return ">" case es_models.Operation_Less: return "<" + case es_models.Operation_NotIn: + return "<>" } return "" } diff --git a/internal/eventstore/v1/internal/repository/sql/query_test.go b/internal/eventstore/v1/internal/repository/sql/query_test.go index c5d4c561dd..6b9c990514 100644 --- a/internal/eventstore/v1/internal/repository/sql/query_test.go +++ b/internal/eventstore/v1/internal/repository/sql/query_test.go @@ -309,7 +309,7 @@ func prepareTestScan(err error, res []interface{}) scan { func Test_prepareCondition(t *testing.T) { type args struct { - filters []*es_models.Filter + filters [][]*es_models.Filter } type res struct { clause string @@ -333,7 +333,7 @@ func Test_prepareCondition(t *testing.T) { { name: "empty filters", args: args{ - filters: []*es_models.Filter{}, + filters: [][]*es_models.Filter{}, }, res: res{ clause: "", @@ -343,8 +343,10 @@ func Test_prepareCondition(t *testing.T) { { name: "invalid condition", args: args{ - filters: []*es_models.Filter{ - es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)), + filters: [][]*es_models.Filter{ + { + es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)), + }, }, }, res: res{ @@ -355,26 +357,30 @@ func Test_prepareCondition(t *testing.T) { { name: "array as condition value", args: args{ - filters: []*es_models.Filter{ - es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In), + filters: [][]*es_models.Filter{ + { + es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In), + }, }, }, res: res{ - clause: " WHERE aggregate_type = ANY(?)", + clause: " WHERE ( aggregate_type = ANY(?) )", values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})}, }, }, { name: "multiple filters", args: args{ - filters: []*es_models.Filter{ - es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In), - es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals), - es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In), + filters: [][]*es_models.Filter{ + { + es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In), + es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals), + es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In), + }, }, }, res: res{ - clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?)", + clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )", values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})}, }, }, @@ -428,10 +434,10 @@ func Test_buildQuery(t *testing.T) { { name: "with order by desc", args: args{ - queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(), + queryFactory: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(), }, res: res{ - query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC", + query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC", rowScanner: true, values: []interface{}{es_models.AggregateType("user")}, }, @@ -439,10 +445,10 @@ func Test_buildQuery(t *testing.T) { { name: "with limit", args: args{ - queryFactory: es_models.NewSearchQueryFactory("user").Limit(5), + queryFactory: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").Factory(), }, res: res{ - query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2", + query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence LIMIT $2", rowScanner: true, values: []interface{}{es_models.AggregateType("user"), uint64(5)}, limit: 5, @@ -451,10 +457,10 @@ func Test_buildQuery(t *testing.T) { { name: "with limit and order by desc", args: args{ - queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(), + queryFactory: es_models.NewSearchQueryFactory().Limit(5).OrderDesc().AddQuery().AggregateTypes("user").Factory(), }, res: res{ - query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2", + query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC LIMIT $2", rowScanner: true, values: []interface{}{es_models.AggregateType("user"), uint64(5)}, limit: 5, diff --git a/internal/eventstore/v1/locker/lock.go b/internal/eventstore/v1/locker/lock.go index a81dd8dd7d..8ae22b1103 100644 --- a/internal/eventstore/v1/locker/lock.go +++ b/internal/eventstore/v1/locker/lock.go @@ -7,16 +7,17 @@ import ( "time" "github.com/caos/logging" - caos_errs "github.com/caos/zitadel/internal/errors" "github.com/cockroachdb/cockroach-go/v2/crdb" + + caos_errs "github.com/caos/zitadel/internal/errors" ) const ( insertStmtFormat = "INSERT INTO %s" + - " (locker_id, locked_until, view_name) VALUES ($1, now()+$2::INTERVAL, $3)" + - " ON CONFLICT (view_name)" + - " DO UPDATE SET locker_id = $4, locked_until = now()+$5::INTERVAL" + - " WHERE locks.view_name = $6 AND (locks.locker_id = $7 OR locks.locked_until < now())" + " (locker_id, locked_until, view_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" + + " ON CONFLICT (view_name, instance_id)" + + " DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" + + " WHERE locks.view_name = $3 AND locks.instance_id = $4 AND (locks.locker_id = $1 OR locks.locked_until < now())" millisecondsAsSeconds = int64(time.Second / time.Millisecond) ) @@ -26,13 +27,11 @@ type lock struct { ViewName string `gorm:"column:view_name;primary_key"` } -func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel string, waitTime time.Duration) error { +func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel, instanceID string, waitTime time.Duration) error { return crdb.ExecuteTx(context.Background(), dbClient, nil, func(tx *sql.Tx) error { insert := fmt.Sprintf(insertStmtFormat, lockTable) result, err := tx.Exec(insert, - lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel, - lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, - viewModel, lockerID) + lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel, instanceID) if err != nil { tx.Rollback() diff --git a/internal/eventstore/v1/locker/lock_test.go b/internal/eventstore/v1/locker/lock_test.go index cb5358e868..37afa65deb 100644 --- a/internal/eventstore/v1/locker/lock_test.go +++ b/internal/eventstore/v1/locker/lock_test.go @@ -55,10 +55,10 @@ func (db *dbMock) expectReleaseSavepoint() *dbMock { return db } -func (db *dbMock) expectRenew(lockerID, view string, affectedRows int64) *dbMock { +func (db *dbMock) expectRenew(lockerID, view, instanceID string, affectedRows int64) *dbMock { query := db.mock. - ExpectExec(`INSERT INTO table\.locks \(locker_id, locked_until, view_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\) ON CONFLICT \(view_name\) DO UPDATE SET locker_id = \$4, locked_until = now\(\)\+\$5::INTERVAL WHERE locks\.view_name = \$6 AND \(locks\.locker_id = \$7 OR locks\.locked_until < now\(\)\)`). - WithArgs(lockerID, sqlmock.AnyArg(), view, lockerID, sqlmock.AnyArg(), view, lockerID). + ExpectExec(`INSERT INTO table\.locks \(locker_id, locked_until, view_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\) ON CONFLICT \(view_name, instance_id\) DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL WHERE locks\.view_name = \$3 AND locks\.instance_id = \$4 AND \(locks\.locker_id = \$1 OR locks\.locked_until < now\(\)\)`). + WithArgs(lockerID, sqlmock.AnyArg(), view, instanceID). WillReturnResult(sqlmock.NewResult(1, 1)) if affectedRows == 0 { @@ -75,10 +75,11 @@ func Test_locker_Renew(t *testing.T) { db *dbMock } type args struct { - tableName string - lockerID string - viewModel string - waitTime time.Duration + tableName string + lockerID string + viewModel string + instanceID string + waitTime time.Duration } tests := []struct { name string @@ -92,11 +93,11 @@ func Test_locker_Renew(t *testing.T) { db: mockDB(t). expectBegin(). expectSavepoint(). - expectRenew("locker", "view", 1). + expectRenew("locker", "view", "instanceID", 1). expectReleaseSavepoint(). expectCommit(), }, - args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second}, + args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", instanceID: "instanceID", waitTime: 1 * time.Second}, wantErr: false, }, { @@ -105,16 +106,16 @@ func Test_locker_Renew(t *testing.T) { db: mockDB(t). expectBegin(). expectSavepoint(). - expectRenew("locker", "view", 0). + expectRenew("locker", "view", "instanceID", 0). expectRollback(), }, - args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second}, + args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", instanceID: "instanceID", waitTime: 1 * time.Second}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := Renew(tt.fields.db.db, tt.args.tableName, tt.args.lockerID, tt.args.viewModel, tt.args.waitTime); (err != nil) != tt.wantErr { + if err := Renew(tt.fields.db.db, tt.args.tableName, tt.args.lockerID, tt.args.viewModel, tt.args.instanceID, tt.args.waitTime); (err != nil) != tt.wantErr { t.Errorf("locker.Renew() error = %v, wantErr %v", err, tt.wantErr) } if err := tt.fields.db.mock.ExpectationsWereMet(); err != nil { diff --git a/internal/eventstore/v1/mock/eventstore.mock.go b/internal/eventstore/v1/mock/eventstore.mock.go index fc2bbe1719..3e0d6468f2 100644 --- a/internal/eventstore/v1/mock/eventstore.mock.go +++ b/internal/eventstore/v1/mock/eventstore.mock.go @@ -1,56 +1,42 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/caos/zitadel/internal/eventstore (interfaces: Eventstore) +// Source: github.com/caos/zitadel/internal/eventstore/v1 (interfaces: Eventstore) // Package mock is a generated GoMock package. package mock import ( context "context" - "github.com/caos/zitadel/internal/eventstore" - "github.com/caos/zitadel/internal/eventstore/v1" + reflect "reflect" + + v1 "github.com/caos/zitadel/internal/eventstore/v1" models "github.com/caos/zitadel/internal/eventstore/v1/models" gomock "github.com/golang/mock/gomock" - reflect "reflect" ) -// MockEventstore is a mock of Eventstore interface +// MockEventstore is a mock of Eventstore interface. type MockEventstore struct { ctrl *gomock.Controller recorder *MockEventstoreMockRecorder } -// MockEventstoreMockRecorder is the mock recorder for MockEventstore +// MockEventstoreMockRecorder is the mock recorder for MockEventstore. type MockEventstoreMockRecorder struct { mock *MockEventstore } -// NewMockEventstore creates a new mock instance +// NewMockEventstore creates a new mock instance. func NewMockEventstore(ctrl *gomock.Controller) *MockEventstore { mock := &MockEventstore{ctrl: ctrl} mock.recorder = &MockEventstoreMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockEventstore) EXPECT() *MockEventstoreMockRecorder { return m.recorder } -// AggregateCreator mocks base method -func (m *MockEventstore) AggregateCreator() *models.AggregateCreator { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AggregateCreator") - ret0, _ := ret[0].(*models.AggregateCreator) - return ret0 -} - -// AggregateCreator indicates an expected call of AggregateCreator -func (mr *MockEventstoreMockRecorder) AggregateCreator() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AggregateCreator", reflect.TypeOf((*MockEventstore)(nil).AggregateCreator)) -} - -// FilterEvents mocks base method +// FilterEvents mocks base method. func (m *MockEventstore) FilterEvents(arg0 context.Context, arg1 *models.SearchQuery) ([]*models.Event, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FilterEvents", arg0, arg1) @@ -59,13 +45,13 @@ func (m *MockEventstore) FilterEvents(arg0 context.Context, arg1 *models.SearchQ return ret0, ret1 } -// FilterEvents indicates an expected call of FilterEvents +// FilterEvents indicates an expected call of FilterEvents. func (mr *MockEventstoreMockRecorder) FilterEvents(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterEvents", reflect.TypeOf((*MockEventstore)(nil).FilterEvents), arg0, arg1) } -// Health mocks base method +// Health mocks base method. func (m *MockEventstore) Health(arg0 context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Health", arg0) @@ -73,47 +59,13 @@ func (m *MockEventstore) Health(arg0 context.Context) error { return ret0 } -// Health indicates an expected call of Health +// Health indicates an expected call of Health. func (mr *MockEventstoreMockRecorder) Health(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockEventstore)(nil).Health), arg0) } -// LatestSequence mocks base method -func (m *MockEventstore) LatestSequence(arg0 context.Context, arg1 *models.SearchQueryFactory) (uint64, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LatestSequence", arg0, arg1) - ret0, _ := ret[0].(uint64) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LatestSequence indicates an expected call of LatestSequence -func (mr *MockEventstoreMockRecorder) LatestSequence(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LatestSequence", reflect.TypeOf((*MockEventstore)(nil).LatestSequence), arg0, arg1) -} - -// PushAggregates mocks base method -func (m *MockEventstore) PushAggregates(arg0 context.Context, arg1 ...*models.Aggregate) error { - m.ctrl.T.Helper() - varargs := []interface{}{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "PushAggregates", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// PushAggregates indicates an expected call of PushAggregates -func (mr *MockEventstoreMockRecorder) PushAggregates(arg0 interface{}, arg1 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushAggregates", reflect.TypeOf((*MockEventstore)(nil).PushAggregates), varargs...) -} - -// Subscribe mocks base method +// Subscribe mocks base method. func (m *MockEventstore) Subscribe(arg0 ...models.AggregateType) *v1.Subscription { m.ctrl.T.Helper() varargs := []interface{}{} @@ -125,22 +77,8 @@ func (m *MockEventstore) Subscribe(arg0 ...models.AggregateType) *v1.Subscriptio return ret0 } -// Subscribe indicates an expected call of Subscribe +// Subscribe indicates an expected call of Subscribe. func (mr *MockEventstoreMockRecorder) Subscribe(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockEventstore)(nil).Subscribe), arg0...) } - -// V2 mocks base method -func (m *MockEventstore) V2() *eventstore.Eventstore { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "V2") - ret0, _ := ret[0].(*eventstore.Eventstore) - return ret0 -} - -// V2 indicates an expected call of V2 -func (mr *MockEventstoreMockRecorder) V2() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "V2", reflect.TypeOf((*MockEventstore)(nil).V2)) -} diff --git a/internal/eventstore/v1/models/aggregate_test.go b/internal/eventstore/v1/models/aggregate_test.go index 2e17580fd1..a01cff3373 100644 --- a/internal/eventstore/v1/models/aggregate_test.go +++ b/internal/eventstore/v1/models/aggregate_test.go @@ -190,7 +190,7 @@ func TestAggregate_Validate(t *testing.T) { resourceOwner: "org", PreviousSequence: 5, Precondition: &precondition{ - Query: NewSearchQuery().AggregateIDFilter("hodor"), + Query: NewSearchQuery().AddQuery().AggregateIDFilter("hodor").SearchQuery(), }, Events: []*Event{ { @@ -240,7 +240,7 @@ func TestAggregate_Validate(t *testing.T) { PreviousSequence: 5, Precondition: &precondition{ Validation: func(...*Event) error { return nil }, - Query: NewSearchQuery().AggregateIDFilter("hodor"), + Query: NewSearchQuery().AddQuery().AggregateIDFilter("hodor").SearchQuery(), }, Events: []*Event{ { diff --git a/internal/eventstore/v1/models/operation.go b/internal/eventstore/v1/models/operation.go index 8f73c2e310..dc3bb06967 100644 --- a/internal/eventstore/v1/models/operation.go +++ b/internal/eventstore/v1/models/operation.go @@ -7,4 +7,5 @@ const ( Operation_Greater Operation_Less Operation_In + Operation_NotIn ) diff --git a/internal/eventstore/v1/models/search_query.go b/internal/eventstore/v1/models/search_query.go index a4370c5918..b878f10cc2 100644 --- a/internal/eventstore/v1/models/search_query.go +++ b/internal/eventstore/v1/models/search_query.go @@ -9,24 +9,31 @@ import ( ) type SearchQueryFactory struct { - columns Columns - limit uint64 - desc bool - aggregateTypes []AggregateType - aggregateIDs []string - sequenceFrom uint64 - sequenceTo uint64 - eventTypes []EventType - resourceOwner string - instanceID string - creationDate time.Time + columns Columns + limit uint64 + desc bool + queries []*query +} + +type query struct { + desc bool + aggregateTypes []AggregateType + aggregateIDs []string + sequenceFrom uint64 + sequenceTo uint64 + eventTypes []EventType + resourceOwner string + instanceID string + ignoredInstanceIDs []string + creationDate time.Time + factory *SearchQueryFactory } type searchQuery struct { Columns Columns Limit uint64 Desc bool - Filters []*Filter + Filters [][]*Filter } type Columns int32 @@ -39,49 +46,55 @@ const ( ) //FactoryFromSearchQuery is deprecated because it's for migration purposes. use NewSearchQueryFactory -func FactoryFromSearchQuery(query *SearchQuery) *SearchQueryFactory { +func FactoryFromSearchQuery(q *SearchQuery) *SearchQueryFactory { factory := &SearchQueryFactory{ columns: Columns_Event, - desc: query.Desc, - limit: query.Limit, + desc: q.Desc, + limit: q.Limit, + queries: make([]*query, len(q.Queries)), } - for _, filter := range query.Filters { - switch filter.field { - case Field_AggregateType: - factory = factory.aggregateTypesMig(filter.value.([]AggregateType)...) - case Field_AggregateID: - if aggregateID, ok := filter.value.(string); ok { - factory = factory.AggregateIDs(aggregateID) - } else if aggregateIDs, ok := filter.value.([]string); ok { - factory = factory.AggregateIDs(aggregateIDs...) + for i, qq := range q.Queries { + factory.queries[i] = &query{factory: factory} + for _, filter := range qq.Filters { + switch filter.field { + case Field_AggregateType: + factory.queries[i] = factory.queries[i].aggregateTypesMig(filter.value.([]AggregateType)...) + case Field_AggregateID: + if aggregateID, ok := filter.value.(string); ok { + factory.queries[i] = factory.queries[i].AggregateIDs(aggregateID) + } else if aggregateIDs, ok := filter.value.([]string); ok { + factory.queries[i] = factory.queries[i].AggregateIDs(aggregateIDs...) + } + case Field_LatestSequence: + if filter.operation == Operation_Greater { + factory.queries[i] = factory.queries[i].SequenceGreater(filter.value.(uint64)) + } else { + factory.queries[i] = factory.queries[i].SequenceLess(filter.value.(uint64)) + } + case Field_ResourceOwner: + factory.queries[i] = factory.queries[i].ResourceOwner(filter.value.(string)) + case Field_InstanceID: + if filter.operation == Operation_Equals { + factory.queries[i] = factory.queries[i].InstanceID(filter.value.(string)) + } else if filter.operation == Operation_NotIn { + factory.queries[i] = factory.queries[i].IgnoredInstanceIDs(filter.value.([]string)...) + } + case Field_EventType: + factory.queries[i] = factory.queries[i].EventTypes(filter.value.([]EventType)...) + case Field_EditorService, Field_EditorUser: + logging.WithFields("value", filter.value).Panic("field not converted to factory") + case Field_CreationDate: + factory.queries[i] = factory.queries[i].CreationDateNewer(filter.value.(time.Time)) } - case Field_LatestSequence: - if filter.operation == Operation_Greater { - factory = factory.SequenceGreater(filter.value.(uint64)) - } else { - factory = factory.SequenceLess(filter.value.(uint64)) - } - case Field_ResourceOwner: - factory = factory.ResourceOwner(filter.value.(string)) - case Field_InstanceID: - factory = factory.InstanceID(filter.value.(string)) - case Field_EventType: - factory = factory.EventTypes(filter.value.([]EventType)...) - case Field_EditorService, Field_EditorUser: - logging.Log("MODEL-Mr0VN").WithField("value", filter.value).Panic("field not converted to factory") - case Field_CreationDate: - factory = factory.CreationDateNewer(filter.value.(time.Time)) } } return factory } -func NewSearchQueryFactory(aggregateTypes ...AggregateType) *SearchQueryFactory { - return &SearchQueryFactory{ - aggregateTypes: aggregateTypes, - } +func NewSearchQueryFactory() *SearchQueryFactory { + return &SearchQueryFactory{} } func (factory *SearchQueryFactory) Columns(columns Columns) *SearchQueryFactory { @@ -94,46 +107,6 @@ func (factory *SearchQueryFactory) Limit(limit uint64) *SearchQueryFactory { return factory } -func (factory *SearchQueryFactory) SequenceGreater(sequence uint64) *SearchQueryFactory { - factory.sequenceFrom = sequence - return factory -} - -func (factory *SearchQueryFactory) SequenceLess(sequence uint64) *SearchQueryFactory { - factory.sequenceTo = sequence - return factory -} - -func (factory *SearchQueryFactory) AggregateIDs(ids ...string) *SearchQueryFactory { - factory.aggregateIDs = ids - return factory -} - -func (factory *SearchQueryFactory) aggregateTypesMig(types ...AggregateType) *SearchQueryFactory { - factory.aggregateTypes = types - return factory -} - -func (factory *SearchQueryFactory) EventTypes(types ...EventType) *SearchQueryFactory { - factory.eventTypes = types - return factory -} - -func (factory *SearchQueryFactory) ResourceOwner(resourceOwner string) *SearchQueryFactory { - factory.resourceOwner = resourceOwner - return factory -} - -func (factory *SearchQueryFactory) InstanceID(instanceID string) *SearchQueryFactory { - factory.instanceID = instanceID - return factory -} - -func (factory *SearchQueryFactory) CreationDateNewer(time time.Time) *SearchQueryFactory { - factory.creationDate = time - return factory -} - func (factory *SearchQueryFactory) OrderDesc() *SearchQueryFactory { factory.desc = true return factory @@ -144,27 +117,89 @@ func (factory *SearchQueryFactory) OrderAsc() *SearchQueryFactory { return factory } +func (factory *SearchQueryFactory) AddQuery() *query { + q := &query{factory: factory} + factory.queries = append(factory.queries, q) + return q +} + +func (q *query) Factory() *SearchQueryFactory { + return q.factory +} + +func (q *query) SequenceGreater(sequence uint64) *query { + q.sequenceFrom = sequence + return q +} + +func (q *query) SequenceLess(sequence uint64) *query { + q.sequenceTo = sequence + return q +} + +func (q *query) AggregateTypes(types ...AggregateType) *query { + q.aggregateTypes = types + return q +} + +func (q *query) AggregateIDs(ids ...string) *query { + q.aggregateIDs = ids + return q +} + +func (q *query) aggregateTypesMig(types ...AggregateType) *query { + q.aggregateTypes = types + return q +} + +func (q *query) EventTypes(types ...EventType) *query { + q.eventTypes = types + return q +} + +func (q *query) ResourceOwner(resourceOwner string) *query { + q.resourceOwner = resourceOwner + return q +} + +func (q *query) InstanceID(instanceID string) *query { + q.instanceID = instanceID + return q +} + +func (q *query) IgnoredInstanceIDs(instanceIDs ...string) *query { + q.ignoredInstanceIDs = instanceIDs + return q +} + +func (q *query) CreationDateNewer(time time.Time) *query { + q.creationDate = time + return q +} + func (factory *SearchQueryFactory) Build() (*searchQuery, error) { if factory == nil || - len(factory.aggregateTypes) < 1 || + len(factory.queries) < 1 || (factory.columns < 0 || factory.columns >= columnsCount) { return nil, errors.ThrowPreconditionFailed(nil, "MODEL-tGAD3", "factory invalid") } - filters := []*Filter{ - factory.aggregateTypeFilter(), - } + filters := make([][]*Filter, len(factory.queries)) - for _, f := range []func() *Filter{ - factory.aggregateIDFilter, - factory.sequenceFromFilter, - factory.sequenceToFilter, - factory.eventTypeFilter, - factory.resourceOwnerFilter, - factory.instanceIDFilter, - factory.creationDateNewerFilter, - } { - if filter := f(); filter != nil { - filters = append(filters, filter) + for i, query := range factory.queries { + for _, f := range []func() *Filter{ + query.aggregateTypeFilter, + query.aggregateIDFilter, + query.sequenceFromFilter, + query.sequenceToFilter, + query.eventTypeFilter, + query.resourceOwnerFilter, + query.instanceIDFilter, + query.ignoredInstanceIDsFilter, + query.creationDateNewerFilter, + } { + if filter := f(); filter != nil { + filters[i] = append(filters[i], filter) + } } } @@ -176,72 +211,79 @@ func (factory *SearchQueryFactory) Build() (*searchQuery, error) { }, nil } -func (factory *SearchQueryFactory) aggregateIDFilter() *Filter { - if len(factory.aggregateIDs) < 1 { +func (q *query) aggregateIDFilter() *Filter { + if len(q.aggregateIDs) < 1 { return nil } - if len(factory.aggregateIDs) == 1 { - return NewFilter(Field_AggregateID, factory.aggregateIDs[0], Operation_Equals) + if len(q.aggregateIDs) == 1 { + return NewFilter(Field_AggregateID, q.aggregateIDs[0], Operation_Equals) } - return NewFilter(Field_AggregateID, factory.aggregateIDs, Operation_In) + return NewFilter(Field_AggregateID, q.aggregateIDs, Operation_In) } -func (factory *SearchQueryFactory) eventTypeFilter() *Filter { - if len(factory.eventTypes) < 1 { +func (q *query) eventTypeFilter() *Filter { + if len(q.eventTypes) < 1 { return nil } - if len(factory.eventTypes) == 1 { - return NewFilter(Field_EventType, factory.eventTypes[0], Operation_Equals) + if len(q.eventTypes) == 1 { + return NewFilter(Field_EventType, q.eventTypes[0], Operation_Equals) } - return NewFilter(Field_EventType, factory.eventTypes, Operation_In) + return NewFilter(Field_EventType, q.eventTypes, Operation_In) } -func (factory *SearchQueryFactory) aggregateTypeFilter() *Filter { - if len(factory.aggregateTypes) == 1 { - return NewFilter(Field_AggregateType, factory.aggregateTypes[0], Operation_Equals) +func (q *query) aggregateTypeFilter() *Filter { + if len(q.aggregateTypes) == 1 { + return NewFilter(Field_AggregateType, q.aggregateTypes[0], Operation_Equals) } - return NewFilter(Field_AggregateType, factory.aggregateTypes, Operation_In) + return NewFilter(Field_AggregateType, q.aggregateTypes, Operation_In) } -func (factory *SearchQueryFactory) sequenceFromFilter() *Filter { - if factory.sequenceFrom == 0 { +func (q *query) sequenceFromFilter() *Filter { + if q.sequenceFrom == 0 { return nil } sortOrder := Operation_Greater - if factory.desc { + if q.factory.desc { sortOrder = Operation_Less } - return NewFilter(Field_LatestSequence, factory.sequenceFrom, sortOrder) + return NewFilter(Field_LatestSequence, q.sequenceFrom, sortOrder) } -func (factory *SearchQueryFactory) sequenceToFilter() *Filter { - if factory.sequenceTo == 0 { +func (q *query) sequenceToFilter() *Filter { + if q.sequenceTo == 0 { return nil } sortOrder := Operation_Less - if factory.desc { + if q.factory.desc { sortOrder = Operation_Greater } - return NewFilter(Field_LatestSequence, factory.sequenceTo, sortOrder) + return NewFilter(Field_LatestSequence, q.sequenceTo, sortOrder) } -func (factory *SearchQueryFactory) resourceOwnerFilter() *Filter { - if factory.resourceOwner == "" { +func (q *query) resourceOwnerFilter() *Filter { + if q.resourceOwner == "" { return nil } - return NewFilter(Field_ResourceOwner, factory.resourceOwner, Operation_Equals) + return NewFilter(Field_ResourceOwner, q.resourceOwner, Operation_Equals) } -func (factory *SearchQueryFactory) instanceIDFilter() *Filter { - if factory.instanceID == "" { +func (q *query) instanceIDFilter() *Filter { + if q.instanceID == "" { return nil } - return NewFilter(Field_InstanceID, factory.instanceID, Operation_Equals) + return NewFilter(Field_InstanceID, q.instanceID, Operation_Equals) } -func (factory *SearchQueryFactory) creationDateNewerFilter() *Filter { - if factory.creationDate.IsZero() { +func (q *query) ignoredInstanceIDsFilter() *Filter { + if len(q.ignoredInstanceIDs) == 0 { return nil } - return NewFilter(Field_CreationDate, factory.creationDate, Operation_Greater) + return NewFilter(Field_InstanceID, q.ignoredInstanceIDs, Operation_NotIn) +} + +func (q *query) creationDateNewerFilter() *Filter { + if q.creationDate.IsZero() { + return nil + } + return NewFilter(Field_CreationDate, q.creationDate, Operation_Greater) } diff --git a/internal/eventstore/v1/models/search_query_old.go b/internal/eventstore/v1/models/search_query_old.go index 75f4021a07..09754c190a 100644 --- a/internal/eventstore/v1/models/search_query_old.go +++ b/internal/eventstore/v1/models/search_query_old.go @@ -11,15 +11,46 @@ type SearchQuery struct { Limit uint64 Desc bool Filters []*Filter + Queries []*Query +} + +type Query struct { + searchQuery *SearchQuery + Filters []*Filter } //NewSearchQuery is deprecated. Use SearchQueryFactory func NewSearchQuery() *SearchQuery { return &SearchQuery{ Filters: make([]*Filter, 0, 4), + Queries: make([]*Query, 0), } } +func (q *SearchQuery) AddQuery() *Query { + query := &Query{ + searchQuery: q, + } + q.Queries = append(q.Queries, query) + + return query +} + +//SearchQuery returns the SearchQuery of the sub query +func (q *Query) SearchQuery() *SearchQuery { + return q.searchQuery +} +func (q *Query) setFilter(filter *Filter) *Query { + for i, f := range q.Filters { + if f.field == filter.field && f.field != Field_LatestSequence { + q.Filters[i] = filter + return q + } + } + q.Filters = append(q.Filters, filter) + return q +} + func (q *SearchQuery) SetLimit(limit uint64) *SearchQuery { q.Limit = limit return q @@ -35,23 +66,23 @@ func (q *SearchQuery) OrderAsc() *SearchQuery { return q } -func (q *SearchQuery) AggregateIDFilter(id string) *SearchQuery { +func (q *Query) AggregateIDFilter(id string) *Query { return q.setFilter(NewFilter(Field_AggregateID, id, Operation_Equals)) } -func (q *SearchQuery) AggregateIDsFilter(ids ...string) *SearchQuery { +func (q *Query) AggregateIDsFilter(ids ...string) *Query { return q.setFilter(NewFilter(Field_AggregateID, ids, Operation_In)) } -func (q *SearchQuery) AggregateTypeFilter(types ...AggregateType) *SearchQuery { +func (q *Query) AggregateTypeFilter(types ...AggregateType) *Query { return q.setFilter(NewFilter(Field_AggregateType, types, Operation_In)) } -func (q *SearchQuery) EventTypesFilter(types ...EventType) *SearchQuery { +func (q *Query) EventTypesFilter(types ...EventType) *Query { return q.setFilter(NewFilter(Field_EventType, types, Operation_In)) } -func (q *SearchQuery) LatestSequenceFilter(sequence uint64) *SearchQuery { +func (q *Query) LatestSequenceFilter(sequence uint64) *Query { if sequence == 0 { return q } @@ -59,21 +90,25 @@ func (q *SearchQuery) LatestSequenceFilter(sequence uint64) *SearchQuery { return q.setFilter(NewFilter(Field_LatestSequence, sequence, sortOrder)) } -func (q *SearchQuery) SequenceBetween(from, to uint64) *SearchQuery { +func (q *Query) SequenceBetween(from, to uint64) *Query { q.setFilter(NewFilter(Field_LatestSequence, from, Operation_Greater)) q.setFilter(NewFilter(Field_LatestSequence, to, Operation_Less)) return q } -func (q *SearchQuery) ResourceOwnerFilter(resourceOwner string) *SearchQuery { +func (q *Query) ResourceOwnerFilter(resourceOwner string) *Query { return q.setFilter(NewFilter(Field_ResourceOwner, resourceOwner, Operation_Equals)) } -func (q *SearchQuery) InstanceIDFilter(instanceID string) *SearchQuery { +func (q *Query) InstanceIDFilter(instanceID string) *Query { return q.setFilter(NewFilter(Field_InstanceID, instanceID, Operation_Equals)) } -func (q *SearchQuery) CreationDateNewerFilter(time time.Time) *SearchQuery { +func (q *Query) ExcludedInstanceIDsFilter(instanceIDs ...string) *Query { + return q.setFilter(NewFilter(Field_InstanceID, instanceIDs, Operation_NotIn)) +} + +func (q *Query) CreationDateNewerFilter(time time.Time) *Query { return q.setFilter(NewFilter(Field_CreationDate, time, Operation_Greater)) } @@ -92,12 +127,14 @@ func (q *SearchQuery) Validate() error { if q == nil { return errors.ThrowPreconditionFailed(nil, "MODEL-J5xQi", "search query is nil") } - if len(q.Filters) == 0 { + if len(q.Queries) == 0 { return errors.ThrowPreconditionFailed(nil, "MODEL-pF3DR", "no filters set") } - for _, filter := range q.Filters { - if err := filter.Validate(); err != nil { - return err + for _, query := range q.Queries { + for _, filter := range query.Filters { + if err := filter.Validate(); err != nil { + return err + } } } diff --git a/internal/eventstore/v1/models/search_query_test.go b/internal/eventstore/v1/models/search_query_test.go index b52f6490cf..03c719c85b 100644 --- a/internal/eventstore/v1/models/search_query_test.go +++ b/internal/eventstore/v1/models/search_query_test.go @@ -21,31 +21,48 @@ func testSetLimit(limit uint64) func(factory *SearchQueryFactory) *SearchQueryFa } } -func testSetSequence(sequence uint64) func(factory *SearchQueryFactory) *SearchQueryFactory { - return func(factory *SearchQueryFactory) *SearchQueryFactory { - factory = factory.SequenceGreater(sequence) - return factory +func testAddQuery(queryFuncs ...func(*query) *query) func(*SearchQueryFactory) *SearchQueryFactory { + return func(builder *SearchQueryFactory) *SearchQueryFactory { + query := builder.AddQuery() + for _, queryFunc := range queryFuncs { + queryFunc(query) + } + return query.Factory() } } -func testSetAggregateIDs(aggregateIDs ...string) func(factory *SearchQueryFactory) *SearchQueryFactory { - return func(factory *SearchQueryFactory) *SearchQueryFactory { - factory = factory.AggregateIDs(aggregateIDs...) - return factory +func testSetSequence(sequence uint64) func(*query) *query { + return func(q *query) *query { + q.SequenceGreater(sequence) + return q } } -func testSetEventTypes(eventTypes ...EventType) func(factory *SearchQueryFactory) *SearchQueryFactory { - return func(factory *SearchQueryFactory) *SearchQueryFactory { - factory = factory.EventTypes(eventTypes...) - return factory +func testSetAggregateIDs(aggregateIDs ...string) func(*query) *query { + return func(q *query) *query { + q.AggregateIDs(aggregateIDs...) + return q } } -func testSetResourceOwner(resourceOwner string) func(factory *SearchQueryFactory) *SearchQueryFactory { - return func(factory *SearchQueryFactory) *SearchQueryFactory { - factory = factory.ResourceOwner(resourceOwner) - return factory +func testSetAggregateTypes(aggregateTypes ...AggregateType) func(*query) *query { + return func(q *query) *query { + q.AggregateTypes(aggregateTypes...) + return q + } +} + +func testSetEventTypes(eventTypes ...EventType) func(*query) *query { + return func(q *query) *query { + q.EventTypes(eventTypes...) + return q + } +} + +func testSetResourceOwner(resourceOwner string) func(*query) *query { + return func(q *query) *query { + q.ResourceOwner(resourceOwner) + return q } } @@ -60,10 +77,50 @@ func testSetSortOrder(asc bool) func(factory *SearchQueryFactory) *SearchQueryFa } } +func assertFactory(t *testing.T, want, got *SearchQueryFactory) { + t.Helper() + + if got.columns != want.columns { + t.Errorf("wrong column: got: %v want: %v", got.columns, want.columns) + } + if got.desc != want.desc { + t.Errorf("wrong desc: got: %v want: %v", got.desc, want.desc) + } + if got.limit != want.limit { + t.Errorf("wrong limit: got: %v want: %v", got.limit, want.limit) + } + if len(got.queries) != len(want.queries) { + t.Errorf("wrong length of queries: got: %v want: %v", len(got.queries), len(want.queries)) + } + + for i, query := range got.queries { + assertQuery(t, i, want.queries[i], query) + } +} + +func assertQuery(t *testing.T, i int, want, got *query) { + t.Helper() + + if !reflect.DeepEqual(got.aggregateIDs, want.aggregateIDs) { + t.Errorf("wrong aggregateIDs in query %d : got: %v want: %v", i, got.aggregateIDs, want.aggregateIDs) + } + if !reflect.DeepEqual(got.aggregateTypes, want.aggregateTypes) { + t.Errorf("wrong aggregateTypes in query %d : got: %v want: %v", i, got.aggregateTypes, want.aggregateTypes) + } + if got.sequenceFrom != want.sequenceFrom { + t.Errorf("wrong sequenceFrom in query %d : got: %v want: %v", i, got.sequenceFrom, want.sequenceFrom) + } + if got.sequenceTo != want.sequenceTo { + t.Errorf("wrong sequenceTo in query %d : got: %v want: %v", i, got.sequenceTo, want.sequenceTo) + } + if !reflect.DeepEqual(got.eventTypes, want.eventTypes) { + t.Errorf("wrong eventTypes in query %d : got: %v want: %v", i, got.eventTypes, want.eventTypes) + } +} + func TestSearchQueryFactorySetters(t *testing.T) { type args struct { - aggregateTypes []AggregateType - setters []func(*SearchQueryFactory) *SearchQueryFactory + setters []func(*SearchQueryFactory) *SearchQueryFactory } tests := []struct { name string @@ -73,11 +130,9 @@ func TestSearchQueryFactorySetters(t *testing.T) { { name: "New factory", args: args{ - aggregateTypes: []AggregateType{"user", "org"}, - }, - res: &SearchQueryFactory{ - aggregateTypes: []AggregateType{"user", "org"}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{}, }, + res: &SearchQueryFactory{}, }, { name: "set columns", @@ -100,69 +155,98 @@ func TestSearchQueryFactorySetters(t *testing.T) { { name: "set sequence", args: args{ - setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetSequence(90)}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetSequence(90))}, }, res: &SearchQueryFactory{ - sequenceFrom: 90, + queries: []*query{ + { + sequenceFrom: 90, + }, + }, + }, + }, + { + name: "set aggregateTypes", + args: args{ + setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateTypes("user", "org"))}, + }, + res: &SearchQueryFactory{ + queries: []*query{ + { + aggregateTypes: []AggregateType{"user", "org"}, + }, + }, }, }, { name: "set aggregateIDs", args: args{ - setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetAggregateIDs("1235", "09824")}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateIDs("1235", "09824"))}, }, res: &SearchQueryFactory{ - aggregateIDs: []string{"1235", "09824"}, + queries: []*query{ + { + aggregateIDs: []string{"1235", "09824"}, + }, + }, }, }, { name: "set eventTypes", args: args{ - setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetEventTypes("user.created", "user.updated")}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetEventTypes("user.created", "user.updated"))}, }, res: &SearchQueryFactory{ - eventTypes: []EventType{"user.created", "user.updated"}, + queries: []*query{ + { + eventTypes: []EventType{"user.created", "user.updated"}, + }, + }, }, }, { name: "set resource owner", args: args{ - setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetResourceOwner("hodor")}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetResourceOwner("hodor"))}, }, res: &SearchQueryFactory{ - resourceOwner: "hodor", + queries: []*query{ + { + resourceOwner: "hodor", + }, + }, }, }, { name: "default search query", args: args{ - aggregateTypes: []AggregateType{"user"}, - setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetAggregateIDs("1235", "024"), testSetSortOrder(false)}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateTypes("user"), testSetAggregateIDs("1235", "024")), testSetSortOrder(false)}, }, res: &SearchQueryFactory{ - aggregateTypes: []AggregateType{"user"}, - aggregateIDs: []string{"1235", "024"}, - desc: true, + desc: true, + queries: []*query{ + { + aggregateTypes: []AggregateType{"user"}, + aggregateIDs: []string{"1235", "024"}, + }, + }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - factory := NewSearchQueryFactory(tt.args.aggregateTypes...) + factory := NewSearchQueryFactory() for _, setter := range tt.args.setters { factory = setter(factory) } - if !reflect.DeepEqual(factory, tt.res) { - t.Errorf("NewSearchQueryFactory() = %v, want %v", factory, tt.res) - } + assertFactory(t, tt.res, factory) }) } } func TestSearchQueryFactoryBuild(t *testing.T) { type args struct { - aggregateTypes []AggregateType - setters []func(*SearchQueryFactory) *SearchQueryFactory + setters []func(*SearchQueryFactory) *SearchQueryFactory } type res struct { isErr func(err error) bool @@ -176,8 +260,7 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "no aggregate types", args: args{ - aggregateTypes: []AggregateType{}, - setters: []func(*SearchQueryFactory) *SearchQueryFactory{}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{}, }, res: res{ isErr: errors.IsPreconditionFailed, @@ -187,9 +270,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "invalid column (too low)", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ testSetColumns(Columns(-1)), + testAddQuery(testSetAggregateTypes("user")), }, }, res: res{ @@ -199,9 +282,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "invalid column (too high)", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ testSetColumns(columnsCount), + testAddQuery(testSetAggregateTypes("user")), }, }, res: res{ @@ -211,8 +294,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type", args: args{ - aggregateTypes: []AggregateType{"user"}, - setters: []func(*SearchQueryFactory) *SearchQueryFactory{}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{ + testAddQuery(testSetAggregateTypes("user")), + }, }, res: res{ isErr: nil, @@ -220,8 +304,10 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 0, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + }, }, }, }, @@ -229,8 +315,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate types", args: args{ - aggregateTypes: []AggregateType{"user", "org"}, - setters: []func(*SearchQueryFactory) *SearchQueryFactory{}, + setters: []func(*SearchQueryFactory) *SearchQueryFactory{ + testAddQuery(testSetAggregateTypes("user", "org")), + }, }, res: res{ isErr: nil, @@ -238,8 +325,10 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 0, - Filters: []*Filter{ - NewFilter(Field_AggregateType, []AggregateType{"user", "org"}, Operation_In), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, []AggregateType{"user", "org"}, Operation_In), + }, }, }, }, @@ -247,11 +336,13 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type, limit, desc", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ testSetLimit(5), testSetSortOrder(false), - testSetSequence(100), + testAddQuery( + testSetAggregateTypes("user"), + testSetSequence(100), + ), }, }, res: res{ @@ -260,9 +351,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: true, Limit: 5, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_LatestSequence, uint64(100), Operation_Less), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_LatestSequence, uint64(100), Operation_Less), + }, }, }, }, @@ -270,11 +363,13 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type, limit, asc", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ testSetLimit(5), testSetSortOrder(true), - testSetSequence(100), + testAddQuery( + testSetSequence(100), + testSetAggregateTypes("user"), + ), }, }, res: res{ @@ -283,9 +378,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 5, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_LatestSequence, uint64(100), Operation_Greater), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_LatestSequence, uint64(100), Operation_Greater), + }, }, }, }, @@ -293,12 +390,14 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type, limit, desc, max event sequence cols", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ testSetLimit(5), testSetSortOrder(false), - testSetSequence(100), testSetColumns(Columns_Max_Sequence), + testAddQuery( + testSetSequence(100), + testSetAggregateTypes("user"), + ), }, }, res: res{ @@ -307,9 +406,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: Columns_Max_Sequence, Desc: true, Limit: 5, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_LatestSequence, uint64(100), Operation_Less), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_LatestSequence, uint64(100), Operation_Less), + }, }, }, }, @@ -317,9 +418,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type and aggregate id", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ - testSetAggregateIDs("1234"), + testAddQuery( + testSetAggregateIDs("1234"), + testSetAggregateTypes("user"), + ), }, }, res: res{ @@ -328,9 +431,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 0, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_AggregateID, "1234", Operation_Equals), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_AggregateID, "1234", Operation_Equals), + }, }, }, }, @@ -338,9 +443,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type and aggregate ids", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ - testSetAggregateIDs("1234", "0815"), + testAddQuery( + testSetAggregateIDs("1234", "0815"), + testSetAggregateTypes("user"), + ), }, }, res: res{ @@ -349,9 +456,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 0, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In), + }, }, }, }, @@ -359,9 +468,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type and sequence greater", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ - testSetSequence(8), + testAddQuery( + testSetSequence(8), + testSetAggregateTypes("user"), + ), }, }, res: res{ @@ -370,9 +481,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 0, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_LatestSequence, uint64(8), Operation_Greater), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_LatestSequence, uint64(8), Operation_Greater), + }, }, }, }, @@ -380,9 +493,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type and event type", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ - testSetEventTypes("user.created"), + testAddQuery( + testSetAggregateTypes("user"), + testSetEventTypes("user.created"), + ), }, }, res: res{ @@ -391,9 +506,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 0, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_EventType, EventType("user.created"), Operation_Equals), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_EventType, EventType("user.created"), Operation_Equals), + }, }, }, }, @@ -401,9 +518,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type and event types", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ - testSetEventTypes("user.created", "user.changed"), + testAddQuery( + testSetAggregateTypes("user"), + testSetEventTypes("user.created", "user.changed"), + ), }, }, res: res{ @@ -412,9 +531,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 0, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_EventType, []EventType{"user.created", "user.changed"}, Operation_In), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_EventType, []EventType{"user.created", "user.changed"}, Operation_In), + }, }, }, }, @@ -422,9 +543,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { { name: "filter aggregate type resource owner", args: args{ - aggregateTypes: []AggregateType{"user"}, setters: []func(*SearchQueryFactory) *SearchQueryFactory{ - testSetResourceOwner("hodor"), + testAddQuery( + testSetAggregateTypes("user"), + testSetResourceOwner("hodor"), + ), }, }, res: res{ @@ -433,9 +556,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) { Columns: 0, Desc: false, Limit: 0, - Filters: []*Filter{ - NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), - NewFilter(Field_ResourceOwner, "hodor", Operation_Equals), + Filters: [][]*Filter{ + { + NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals), + NewFilter(Field_ResourceOwner, "hodor", Operation_Equals), + }, }, }, }, @@ -443,7 +568,7 @@ func TestSearchQueryFactoryBuild(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - factory := NewSearchQueryFactory(tt.args.aggregateTypes...) + factory := NewSearchQueryFactory() for _, f := range tt.args.setters { factory = f(factory) } diff --git a/internal/eventstore/v1/query/handler.go b/internal/eventstore/v1/query/handler.go index 00bc5e4829..ea8b5b6578 100755 --- a/internal/eventstore/v1/query/handler.go +++ b/internal/eventstore/v1/query/handler.go @@ -26,7 +26,7 @@ type Handler interface { QueryLimit() uint64 AggregateTypes() []models.AggregateType - CurrentSequence() (uint64, error) + CurrentSequence(instanceID string) (uint64, error) Eventstore() v1.Eventstore Subscription() *v1.Subscription @@ -41,15 +41,18 @@ func ReduceEvent(handler Handler, event *models.Event) { handler.Subscription().Unsubscribe() } }() - currentSequence, err := handler.CurrentSequence() + currentSequence, err := handler.CurrentSequence(event.InstanceID) if err != nil { logging.New().WithError(err).Warn("unable to get current sequence") return } searchQuery := models.NewSearchQuery(). + AddQuery(). AggregateTypeFilter(handler.AggregateTypes()...). SequenceBetween(currentSequence, event.Sequence). + InstanceIDFilter(event.InstanceID). + SearchQuery(). SetLimit(eventLimit) unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery) @@ -59,7 +62,7 @@ func ReduceEvent(handler Handler, event *models.Event) { } for _, unprocessedEvent := range unprocessedEvents { - currentSequence, err := handler.CurrentSequence() + currentSequence, err := handler.CurrentSequence(unprocessedEvent.InstanceID) if err != nil { logging.Log("HANDL-BmpkC").WithError(err).Warn("unable to get current sequence") return diff --git a/internal/eventstore/v1/spooler/mock/spooler.go b/internal/eventstore/v1/spooler/mock/spooler.go index 25b2b076a5..c0e3cb224d 100644 --- a/internal/eventstore/v1/spooler/mock/spooler.go +++ b/internal/eventstore/v1/spooler/mock/spooler.go @@ -5,44 +5,45 @@ package mock import ( - gomock "github.com/golang/mock/gomock" reflect "reflect" time "time" + + gomock "github.com/golang/mock/gomock" ) -// MockLocker is a mock of Locker interface +// MockLocker is a mock of Locker interface. type MockLocker struct { ctrl *gomock.Controller recorder *MockLockerMockRecorder } -// MockLockerMockRecorder is the mock recorder for MockLocker +// MockLockerMockRecorder is the mock recorder for MockLocker. type MockLockerMockRecorder struct { mock *MockLocker } -// NewMockLocker creates a new mock instance +// NewMockLocker creates a new mock instance. func NewMockLocker(ctrl *gomock.Controller) *MockLocker { mock := &MockLocker{ctrl: ctrl} mock.recorder = &MockLockerMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockLocker) EXPECT() *MockLockerMockRecorder { return m.recorder } -// Renew mocks base method -func (m *MockLocker) Renew(lockerID, viewModel string, waitTime time.Duration) error { +// Renew mocks base method. +func (m *MockLocker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Renew", lockerID, viewModel, waitTime) + ret := m.ctrl.Call(m, "Renew", lockerID, viewModel, instanceID, waitTime) ret0, _ := ret[0].(error) return ret0 } -// Renew indicates an expected call of Renew -func (mr *MockLockerMockRecorder) Renew(lockerID, viewModel, waitTime interface{}) *gomock.Call { +// Renew indicates an expected call of Renew. +func (mr *MockLockerMockRecorder) Renew(lockerID, viewModel, instanceID, waitTime interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Renew", reflect.TypeOf((*MockLocker)(nil).Renew), lockerID, viewModel, waitTime) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Renew", reflect.TypeOf((*MockLocker)(nil).Renew), lockerID, viewModel, instanceID, waitTime) } diff --git a/internal/eventstore/v1/spooler/spooler.go b/internal/eventstore/v1/spooler/spooler.go index 0a28808878..0884b11bb0 100644 --- a/internal/eventstore/v1/spooler/spooler.go +++ b/internal/eventstore/v1/spooler/spooler.go @@ -16,6 +16,8 @@ import ( "github.com/caos/zitadel/internal/view/repository" ) +const systemID = "system" + type Spooler struct { handlers []query.Handler locker Locker @@ -26,7 +28,7 @@ type Spooler struct { } type Locker interface { - Renew(lockerID, viewModel string, waitTime time.Duration) error + Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error } type spooledHandler struct { @@ -138,19 +140,6 @@ func (s *spooledHandler) query(ctx context.Context) ([]*models.Event, error) { if err != nil { return nil, err } - factory := models.FactoryFromSearchQuery(query) - sequence, err := s.eventstore.LatestSequence(ctx, factory) - logging.OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Debug("unable to query latest sequence") - var processedSequence uint64 - for _, filter := range query.Filters { - if filter.GetField() == models.Field_LatestSequence { - processedSequence = filter.GetValue().(uint64) - } - } - if sequence != 0 && processedSequence == sequence { - return nil, nil - } - query.Limit = s.QueryLimit() return s.eventstore.FilterEvents(ctx, query) } @@ -169,7 +158,7 @@ func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID s case <-ctx.Done(): return case <-renewTimer: - err := s.locker.Renew(workerID, s.ViewModel(), s.LockDuration()) + err := s.locker.Renew(workerID, s.ViewModel(), systemID, s.LockDuration()) firstLock.Do(func() { locked <- err == nil }) @@ -190,16 +179,17 @@ func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID s } func HandleError(event *models.Event, failedErr error, - latestFailedEvent func(sequence uint64) (*repository.FailedEvent, error), + latestFailedEvent func(sequence uint64, instanceID string) (*repository.FailedEvent, error), processFailedEvent func(*repository.FailedEvent) error, processSequence func(*models.Event) error, errorCountUntilSkip uint64) error { - failedEvent, err := latestFailedEvent(event.Sequence) + failedEvent, err := latestFailedEvent(event.Sequence, event.InstanceID) if err != nil { return err } failedEvent.FailureCount++ failedEvent.ErrMsg = failedErr.Error() + failedEvent.InstanceID = event.InstanceID err = processFailedEvent(failedEvent) if err != nil { return err diff --git a/internal/eventstore/v1/spooler/spooler_test.go b/internal/eventstore/v1/spooler/spooler_test.go index 6b327fa7bd..1549c5f4fe 100644 --- a/internal/eventstore/v1/spooler/spooler_test.go +++ b/internal/eventstore/v1/spooler/spooler_test.go @@ -3,17 +3,18 @@ package spooler import ( "context" "fmt" - "github.com/caos/zitadel/internal/eventstore" - "github.com/caos/zitadel/internal/eventstore/v1" "testing" "time" + "github.com/golang/mock/gomock" + "github.com/caos/zitadel/internal/errors" + "github.com/caos/zitadel/internal/eventstore" + v1 "github.com/caos/zitadel/internal/eventstore/v1" "github.com/caos/zitadel/internal/eventstore/v1/models" "github.com/caos/zitadel/internal/eventstore/v1/query" "github.com/caos/zitadel/internal/eventstore/v1/spooler/mock" "github.com/caos/zitadel/internal/view/repository" - "github.com/golang/mock/gomock" ) type testHandler struct { @@ -30,7 +31,7 @@ func (h *testHandler) AggregateTypes() []models.AggregateType { return nil } -func (h *testHandler) CurrentSequence() (uint64, error) { +func (h *testHandler) CurrentSequence(instanceID string) (uint64, error) { return 0, nil } @@ -376,8 +377,8 @@ func newTestLocker(t *testing.T, lockerID, viewName string) *testLocker { func (l *testLocker) expectRenew(t *testing.T, err error, waitTime time.Duration) *testLocker { t.Helper() - l.mock.EXPECT().Renew(gomock.Any(), l.viewName, gomock.Any()).DoAndReturn( - func(_, _ string, gotten time.Duration) error { + l.mock.EXPECT().Renew(gomock.Any(), l.viewName, gomock.Any(), gomock.Any()).DoAndReturn( + func(_, _, _ string, gotten time.Duration) error { t.Helper() if waitTime-gotten != 0 { t.Errorf("expected waittime %v got %v", waitTime, gotten) @@ -396,7 +397,7 @@ func TestHandleError(t *testing.T) { type args struct { event *models.Event failedErr error - latestFailedEvent func(sequence uint64) (*repository.FailedEvent, error) + latestFailedEvent func(sequence uint64, instanceID string) (*repository.FailedEvent, error) errorCountUntilSkip uint64 } type res struct { @@ -413,12 +414,13 @@ func TestHandleError(t *testing.T) { args: args{ event: &models.Event{Sequence: 30000000}, failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"), - latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) { + latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) { return &repository.FailedEvent{ ErrMsg: "blub", FailedSequence: s - 1, FailureCount: 6, ViewName: "super.table", + InstanceID: instanceID, }, nil }, errorCountUntilSkip: 5, @@ -432,12 +434,13 @@ func TestHandleError(t *testing.T) { args: args{ event: &models.Event{Sequence: 30000000}, failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"), - latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) { + latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) { return &repository.FailedEvent{ ErrMsg: "blub", FailedSequence: s - 1, FailureCount: 5, ViewName: "super.table", + InstanceID: instanceID, }, nil }, errorCountUntilSkip: 6, @@ -451,12 +454,13 @@ func TestHandleError(t *testing.T) { args: args{ event: &models.Event{Sequence: 30000000}, failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"), - latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) { + latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) { return &repository.FailedEvent{ ErrMsg: "blub", FailedSequence: s - 1, FailureCount: 3, ViewName: "super.table", + InstanceID: instanceID, }, nil }, errorCountUntilSkip: 5, diff --git a/internal/iam/model/idp_provider_view.go b/internal/iam/model/idp_provider_view.go index b40760a6a3..4508ea3a08 100644 --- a/internal/iam/model/idp_provider_view.go +++ b/internal/iam/model/idp_provider_view.go @@ -36,6 +36,7 @@ const ( IDPProviderSearchKeyAggregateID IDPProviderSearchKeyIdpConfigID IDPProviderSearchKeyState + IDPProviderSearchKeyInstanceID ) type IDPProviderSearchQuery struct { diff --git a/internal/iam/repository/view/idp_provider_view.go b/internal/iam/repository/view/idp_provider_view.go index 54f59a8226..c600f33b2a 100644 --- a/internal/iam/repository/view/idp_provider_view.go +++ b/internal/iam/repository/view/idp_provider_view.go @@ -1,19 +1,21 @@ package view import ( + "github.com/jinzhu/gorm" + "github.com/caos/zitadel/internal/domain" caos_errs "github.com/caos/zitadel/internal/errors" iam_model "github.com/caos/zitadel/internal/iam/model" "github.com/caos/zitadel/internal/iam/repository/view/model" "github.com/caos/zitadel/internal/view/repository" - "github.com/jinzhu/gorm" ) -func GetIDPProviderByAggregateIDAndConfigID(db *gorm.DB, table, aggregateID, idpConfigID string) (*model.IDPProviderView, error) { +func GetIDPProviderByAggregateIDAndConfigID(db *gorm.DB, table, aggregateID, idpConfigID, instanceID string) (*model.IDPProviderView, error) { policy := new(model.IDPProviderView) aggIDQuery := &model.IDPProviderSearchQuery{Key: iam_model.IDPProviderSearchKeyAggregateID, Value: aggregateID, Method: domain.SearchMethodEquals} idpConfigIDQuery := &model.IDPProviderSearchQuery{Key: iam_model.IDPProviderSearchKeyIdpConfigID, Value: idpConfigID, Method: domain.SearchMethodEquals} - query := repository.PrepareGetByQuery(table, aggIDQuery, idpConfigIDQuery) + instanceIDQuery := &model.IDPProviderSearchQuery{Key: iam_model.IDPProviderSearchKeyInstanceID, Value: instanceID, Method: domain.SearchMethodEquals} + query := repository.PrepareGetByQuery(table, aggIDQuery, idpConfigIDQuery, instanceIDQuery) err := query(db, policy) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-Skvi8", "Errors.IAM.LoginPolicy.IDP.NotExisting") @@ -21,7 +23,7 @@ func GetIDPProviderByAggregateIDAndConfigID(db *gorm.DB, table, aggregateID, idp return policy, err } -func IDPProvidersByIdpConfigID(db *gorm.DB, table string, idpConfigID string) ([]*model.IDPProviderView, error) { +func IDPProvidersByIdpConfigID(db *gorm.DB, table, idpConfigID, instanceID string) ([]*model.IDPProviderView, error) { providers := make([]*model.IDPProviderView, 0) queries := []*iam_model.IDPProviderSearchQuery{ { @@ -29,6 +31,11 @@ func IDPProvidersByIdpConfigID(db *gorm.DB, table string, idpConfigID string) ([ Value: idpConfigID, Method: domain.SearchMethodEquals, }, + { + Key: iam_model.IDPProviderSearchKeyInstanceID, + Value: instanceID, + Method: domain.SearchMethodEquals, + }, } query := repository.PrepareSearchQuery(table, model.IDPProviderSearchRequest{Queries: queries}) _, err := query(db, &providers) @@ -38,7 +45,7 @@ func IDPProvidersByIdpConfigID(db *gorm.DB, table string, idpConfigID string) ([ return providers, nil } -func IDPProvidersByAggregateIDAndState(db *gorm.DB, table string, aggregateID string, idpConfigState iam_model.IDPConfigState) ([]*model.IDPProviderView, error) { +func IDPProvidersByAggregateIDAndState(db *gorm.DB, table string, aggregateID, instanceID string, idpConfigState iam_model.IDPConfigState) ([]*model.IDPProviderView, error) { providers := make([]*model.IDPProviderView, 0) queries := []*iam_model.IDPProviderSearchQuery{ { @@ -51,6 +58,11 @@ func IDPProvidersByAggregateIDAndState(db *gorm.DB, table string, aggregateID st Value: int(idpConfigState), Method: domain.SearchMethodEquals, }, + { + Key: iam_model.IDPProviderSearchKeyInstanceID, + Value: instanceID, + Method: domain.SearchMethodEquals, + }, } query := repository.PrepareSearchQuery(table, model.IDPProviderSearchRequest{Queries: queries}) _, err := query(db, &providers) @@ -84,17 +96,19 @@ func PutIDPProviders(db *gorm.DB, table string, providers ...*model.IDPProviderV return save(db, p...) } -func DeleteIDPProvider(db *gorm.DB, table, aggregateID, idpConfigID string) error { +func DeleteIDPProvider(db *gorm.DB, table, aggregateID, idpConfigID, instanceID string) error { delete := repository.PrepareDeleteByKeys(table, repository.Key{Key: model.IDPProviderSearchKey(iam_model.IDPProviderSearchKeyAggregateID), Value: aggregateID}, repository.Key{Key: model.IDPProviderSearchKey(iam_model.IDPProviderSearchKeyIdpConfigID), Value: idpConfigID}, + repository.Key{Key: model.IDPProviderSearchKey(iam_model.IDPProviderSearchKeyInstanceID), Value: instanceID}, ) return delete(db) } -func DeleteIDPProvidersByAggregateID(db *gorm.DB, table, aggregateID string) error { +func DeleteIDPProvidersByAggregateID(db *gorm.DB, table, aggregateID, instanceID string) error { delete := repository.PrepareDeleteByKeys(table, repository.Key{Key: model.IDPProviderSearchKey(iam_model.IDPProviderSearchKeyAggregateID), Value: aggregateID}, + repository.Key{Key: model.IDPProviderSearchKey(iam_model.IDPProviderSearchKeyInstanceID), Value: instanceID}, ) return delete(db) } diff --git a/internal/iam/repository/view/idp_view.go b/internal/iam/repository/view/idp_view.go index 3687d4e59e..85f123a7ea 100644 --- a/internal/iam/repository/view/idp_view.go +++ b/internal/iam/repository/view/idp_view.go @@ -1,18 +1,20 @@ package view import ( + "github.com/jinzhu/gorm" + "github.com/caos/zitadel/internal/domain" caos_errs "github.com/caos/zitadel/internal/errors" iam_model "github.com/caos/zitadel/internal/iam/model" "github.com/caos/zitadel/internal/iam/repository/view/model" "github.com/caos/zitadel/internal/view/repository" - "github.com/jinzhu/gorm" ) -func IDPByID(db *gorm.DB, table, idpID string) (*model.IDPConfigView, error) { +func IDPByID(db *gorm.DB, table, idpID, instanceID string) (*model.IDPConfigView, error) { idp := new(model.IDPConfigView) idpIDQuery := &model.IDPConfigSearchQuery{Key: iam_model.IDPConfigSearchKeyIdpConfigID, Value: idpID, Method: domain.SearchMethodEquals} - query := repository.PrepareGetByQuery(table, idpIDQuery) + instanceIDQuery := &model.IDPConfigSearchQuery{Key: iam_model.IDPConfigSearchKeyInstanceID, Value: instanceID, Method: domain.SearchMethodEquals} + query := repository.PrepareGetByQuery(table, idpIDQuery, instanceIDQuery) err := query(db, idp) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-Ahq2s", "Errors.IDP.NotExisting") @@ -20,13 +22,17 @@ func IDPByID(db *gorm.DB, table, idpID string) (*model.IDPConfigView, error) { return idp, err } -func GetIDPConfigsByAggregateID(db *gorm.DB, table string, aggregateID string) ([]*model.IDPConfigView, error) { +func GetIDPConfigsByAggregateID(db *gorm.DB, table string, aggregateID, instanceID string) ([]*model.IDPConfigView, error) { idps := make([]*model.IDPConfigView, 0) queries := []*iam_model.IDPConfigSearchQuery{ { Key: iam_model.IDPConfigSearchKeyAggregateID, Value: aggregateID, Method: domain.SearchMethodEquals, + }, { + Key: iam_model.IDPConfigSearchKeyInstanceID, + Value: instanceID, + Method: domain.SearchMethodEquals, }, } query := repository.PrepareSearchQuery(table, model.IDPConfigSearchRequest{Queries: queries}) @@ -52,8 +58,11 @@ func PutIDP(db *gorm.DB, table string, idp *model.IDPConfigView) error { return save(db, idp) } -func DeleteIDP(db *gorm.DB, table, idpID string) error { - delete := repository.PrepareDeleteByKey(table, model.IDPConfigSearchKey(iam_model.IDPConfigSearchKeyIdpConfigID), idpID) +func DeleteIDP(db *gorm.DB, table, idpID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.IDPConfigSearchKey(iam_model.IDPConfigSearchKeyIdpConfigID), idpID}, + repository.Key{model.IDPConfigSearchKey(iam_model.IDPConfigSearchKeyInstanceID), instanceID}, + ) return delete(db) } diff --git a/internal/iam/repository/view/model/idp_config.go b/internal/iam/repository/view/model/idp_config.go index e0231bbc9a..2d8a6bb16a 100644 --- a/internal/iam/repository/view/model/idp_config.go +++ b/internal/iam/repository/view/model/idp_config.go @@ -50,7 +50,7 @@ type IDPConfigView struct { JWTHeaderName string `json:"headerName" gorm:"jwt_header_name"` Sequence uint64 `json:"-" gorm:"column:sequence"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } func IDPConfigViewToModel(idp *IDPConfigView) *model.IDPConfigView { diff --git a/internal/iam/repository/view/model/idp_provider.go b/internal/iam/repository/view/model/idp_provider.go index 40a2a770c2..3d086263f5 100644 --- a/internal/iam/repository/view/model/idp_provider.go +++ b/internal/iam/repository/view/model/idp_provider.go @@ -18,6 +18,7 @@ const ( IDPProviderKeyAggregateID = "aggregate_id" IDPProviderKeyIdpConfigID = "idp_config_id" IDPProviderKeyState = "idp_state" + IDPProviderKeyInstanceID = "instance_id" ) type IDPProviderView struct { @@ -34,7 +35,7 @@ type IDPProviderView struct { IDPState int32 `json:"-" gorm:"column:idp_state"` Sequence uint64 `json:"-" gorm:"column:sequence"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } func IDPProviderViewToModel(provider *IDPProviderView) *model.IDPProviderView { diff --git a/internal/iam/repository/view/model/idp_provider_query.go b/internal/iam/repository/view/model/idp_provider_query.go index 9ff537dc44..29198849f0 100644 --- a/internal/iam/repository/view/model/idp_provider_query.go +++ b/internal/iam/repository/view/model/idp_provider_query.go @@ -57,6 +57,8 @@ func (key IDPProviderSearchKey) ToColumnName() string { return IDPProviderKeyIdpConfigID case iam_model.IDPProviderSearchKeyState: return IDPProviderKeyState + case iam_model.IDPProviderSearchKeyInstanceID: + return IDPProviderKeyInstanceID default: return "" } diff --git a/internal/iam/repository/view/model/label_policy.go b/internal/iam/repository/view/model/label_policy.go index 174c9a0995..9c9ea26c27 100644 --- a/internal/iam/repository/view/model/label_policy.go +++ b/internal/iam/repository/view/model/label_policy.go @@ -45,7 +45,7 @@ type LabelPolicyView struct { Default bool `json:"-" gorm:"-"` Sequence uint64 `json:"-" gorm:"column:sequence"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } type AssetView struct { diff --git a/internal/iam/repository/view/styling.go b/internal/iam/repository/view/styling.go index b685733b3f..146e557315 100644 --- a/internal/iam/repository/view/styling.go +++ b/internal/iam/repository/view/styling.go @@ -10,11 +10,12 @@ import ( "github.com/jinzhu/gorm" ) -func GetStylingByAggregateIDAndState(db *gorm.DB, table, aggregateID string, state int32) (*model.LabelPolicyView, error) { +func GetStylingByAggregateIDAndState(db *gorm.DB, table, aggregateID, instanceID string, state int32) (*model.LabelPolicyView, error) { policy := new(model.LabelPolicyView) aggregateIDQuery := &model.LabelPolicySearchQuery{Key: iam_model.LabelPolicySearchKeyAggregateID, Value: aggregateID, Method: domain.SearchMethodEquals} stateQuery := &model.LabelPolicySearchQuery{Key: iam_model.LabelPolicySearchKeyState, Value: state, Method: domain.SearchMethodEquals} - query := repository.PrepareGetByQuery(table, aggregateIDQuery, stateQuery) + instanceIDQuery := &model.LabelPolicySearchQuery{Key: iam_model.LabelPolicySearchKeyInstanceID, Value: instanceID, Method: domain.SearchMethodEquals} + query := repository.PrepareGetByQuery(table, aggregateIDQuery, stateQuery, instanceIDQuery) err := query(db, policy) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-68G11", "Errors.IAM.LabelPolicy.NotExisting") diff --git a/internal/notification/repository/eventsourcing/handler/notification.go b/internal/notification/repository/eventsourcing/handler/notification.go index 2f404259e5..7d59b9cee6 100644 --- a/internal/notification/repository/eventsourcing/handler/notification.go +++ b/internal/notification/repository/eventsourcing/handler/notification.go @@ -99,8 +99,8 @@ func (_ *Notification) AggregateTypes() []models.AggregateType { return []models.AggregateType{user_repo.AggregateType} } -func (n *Notification) CurrentSequence() (uint64, error) { - sequence, err := n.view.GetLatestNotificationSequence() +func (n *Notification) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := n.view.GetLatestNotificationSequence(instanceID) if err != nil { return 0, err } @@ -108,11 +108,29 @@ func (n *Notification) CurrentSequence() (uint64, error) { } func (n *Notification) EventQuery() (*models.SearchQuery, error) { - sequence, err := n.view.GetLatestNotificationSequence() + sequences, err := n.view.GetLatestNotificationSequences() if err != nil { return nil, err } - return view.UserQuery(sequence.CurrentSequence), nil + query := models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(n.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). + AggregateTypeFilter(n.AggregateTypes()...). + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (n *Notification) Reduce(event *models.Event) (err error) { @@ -162,7 +180,7 @@ func (n *Notification) handleInitUserCode(event *models.Event) (err error) { return err } - user, err := n.getUserByID(event.AggregateID) + user, err := n.getUserByID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -201,7 +219,7 @@ func (n *Notification) handlePasswordCode(event *models.Event) (err error) { return err } - user, err := n.getUserByID(event.AggregateID) + user, err := n.getUserByID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -239,7 +257,7 @@ func (n *Notification) handleEmailVerificationCode(event *models.Event) (err err return err } - user, err := n.getUserByID(event.AggregateID) + user, err := n.getUserByID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -268,7 +286,7 @@ func (n *Notification) handlePhoneVerificationCode(event *models.Event) (err err if err != nil || alreadyHandled { return nil } - user, err := n.getUserByID(event.AggregateID) + user, err := n.getUserByID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -285,7 +303,7 @@ func (n *Notification) handlePhoneVerificationCode(event *models.Event) (err err func (n *Notification) handleDomainClaimed(event *models.Event) (err error) { ctx := getSetNotifyContextData(event.InstanceID, event.ResourceOwner) - alreadyHandled, err := n.checkIfAlreadyHandled(ctx, event.AggregateID, event.Sequence, user_repo.UserDomainClaimedType, user_repo.UserDomainClaimedSentType) + alreadyHandled, err := n.checkIfAlreadyHandled(ctx, event.AggregateID, event.InstanceID, event.Sequence, user_repo.UserDomainClaimedType, user_repo.UserDomainClaimedSentType) if err != nil || alreadyHandled { return nil } @@ -294,7 +312,7 @@ func (n *Notification) handleDomainClaimed(event *models.Event) (err error) { logging.Log("HANDLE-Gghq2").WithError(err).Error("could not unmarshal event data") return errors.ThrowInternal(err, "HANDLE-7hgj3", "could not unmarshal event") } - user, err := n.getUserByID(event.AggregateID) + user, err := n.getUserByID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -329,7 +347,7 @@ func (n *Notification) handlePasswordlessRegistrationLink(event *models.Event) ( return err } ctx := getSetNotifyContextData(event.InstanceID, event.ResourceOwner) - events, err := n.getUserEvents(ctx, event.AggregateID, event.Sequence) + events, err := n.getUserEvents(ctx, event.AggregateID, event.InstanceID, event.Sequence) if err != nil { return err } @@ -344,7 +362,7 @@ func (n *Notification) handlePasswordlessRegistrationLink(event *models.Event) ( } } } - user, err := n.getUserByID(event.AggregateID) + user, err := n.getUserByID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -374,11 +392,11 @@ func (n *Notification) checkIfCodeAlreadyHandledOrExpired(ctx context.Context, e if event.CreationDate.Add(expiry).Before(time.Now().UTC()) { return true, nil } - return n.checkIfAlreadyHandled(ctx, event.AggregateID, event.Sequence, eventTypes...) + return n.checkIfAlreadyHandled(ctx, event.AggregateID, event.InstanceID, event.Sequence, eventTypes...) } -func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID string, sequence uint64, eventTypes ...eventstore.EventType) (bool, error) { - events, err := n.getUserEvents(ctx, userID, sequence) +func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID, instanceID string, sequence uint64, eventTypes ...eventstore.EventType) (bool, error) { + events, err := n.getUserEvents(ctx, userID, instanceID, sequence) if err != nil { return false, err } @@ -392,8 +410,8 @@ func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID string, return false, nil } -func (n *Notification) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) { - query, err := view.UserByIDQuery(userID, sequence) +func (n *Notification) getUserEvents(ctx context.Context, userID, instanceID string, sequence uint64) ([]*models.Event, error) { + query, err := view.UserByIDQuery(userID, instanceID, sequence) if err != nil { return nil, err } @@ -514,6 +532,6 @@ func (n *Notification) getTranslatorWithOrgTexts(ctx context.Context, orgID, tex return translator, nil } -func (n *Notification) getUserByID(userID string) (*model.NotifyUser, error) { - return n.view.NotifyUserByID(userID) +func (n *Notification) getUserByID(userID, instanceID string) (*model.NotifyUser, error) { + return n.view.NotifyUserByID(userID, instanceID) } diff --git a/internal/notification/repository/eventsourcing/handler/notify_user.go b/internal/notification/repository/eventsourcing/handler/notify_user.go index 2f791bfe14..f8da094692 100644 --- a/internal/notification/repository/eventsourcing/handler/notify_user.go +++ b/internal/notification/repository/eventsourcing/handler/notify_user.go @@ -67,8 +67,8 @@ func (_ *NotifyUser) AggregateTypes() []es_models.AggregateType { return []es_models.AggregateType{user.AggregateType, org.AggregateType} } -func (p *NotifyUser) CurrentSequence() (uint64, error) { - sequence, err := p.view.GetLatestNotifyUserSequence() +func (p *NotifyUser) CurrentSequence(instanceID string) (uint64, error) { + sequence, err := p.view.GetLatestNotifyUserSequence(instanceID) if err != nil { return 0, err } @@ -76,13 +76,29 @@ func (p *NotifyUser) CurrentSequence() (uint64, error) { } func (p *NotifyUser) EventQuery() (*es_models.SearchQuery, error) { - sequence, err := p.view.GetLatestNotifyUserSequence() + sequences, err := p.view.GetLatestNotifyUserSequences() if err != nil { return nil, err } - return es_models.NewSearchQuery(). + query := es_models.NewSearchQuery() + instances := make([]string, 0) + for _, sequence := range sequences { + for _, instance := range instances { + if sequence.InstanceID == instance { + break + } + } + instances = append(instances, sequence.InstanceID) + query.AddQuery(). + AggregateTypeFilter(p.AggregateTypes()...). + LatestSequenceFilter(sequence.CurrentSequence). + InstanceIDFilter(sequence.InstanceID) + } + return query.AddQuery(). AggregateTypeFilter(p.AggregateTypes()...). - LatestSequenceFilter(sequence.CurrentSequence), nil + LatestSequenceFilter(0). + ExcludedInstanceIDsFilter(instances...). + SearchQuery(), nil } func (u *NotifyUser) Reduce(event *es_models.Event) (err error) { @@ -122,14 +138,14 @@ func (u *NotifyUser) ProcessUser(event *es_models.Event) (err error) { user.HumanPhoneVerifiedType, user.HumanPhoneRemovedType, user.MachineChangedEventType: - notifyUser, err = u.view.NotifyUserByID(event.AggregateID) + notifyUser, err = u.view.NotifyUserByID(event.AggregateID, event.InstanceID) if err != nil { return err } err = notifyUser.AppendEvent(event) case user.UserDomainClaimedType, user.UserUserNameChangedType: - notifyUser, err = u.view.NotifyUserByID(event.AggregateID) + notifyUser, err = u.view.NotifyUserByID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -139,7 +155,7 @@ func (u *NotifyUser) ProcessUser(event *es_models.Event) (err error) { } err = u.fillLoginNames(notifyUser) case user.UserRemovedType: - return u.view.DeleteNotifyUser(event.AggregateID, event) + return u.view.DeleteNotifyUser(event.AggregateID, event.InstanceID, event) default: return u.view.ProcessedNotifyUserSequence(event) } @@ -169,7 +185,7 @@ func (u *NotifyUser) fillLoginNamesOnOrgUsers(event *es_models.Event) error { if err != nil { return err } - users, err := u.view.NotifyUsersByOrgID(event.AggregateID) + users, err := u.view.NotifyUsersByOrgID(event.AggregateID, event.InstanceID) if err != nil { return err } @@ -191,7 +207,7 @@ func (u *NotifyUser) fillPreferredLoginNamesOnOrgUsers(event *es_models.Event) e if !userLoginMustBeDomain { return nil } - users, err := u.view.NotifyUsersByOrgID(event.AggregateID) + users, err := u.view.NotifyUsersByOrgID(event.AggregateID, event.InstanceID) if err != nil { return err } diff --git a/internal/notification/repository/eventsourcing/spooler/lock.go b/internal/notification/repository/eventsourcing/spooler/lock.go index c368737ca9..d027f106e8 100644 --- a/internal/notification/repository/eventsourcing/spooler/lock.go +++ b/internal/notification/repository/eventsourcing/spooler/lock.go @@ -2,8 +2,9 @@ package spooler import ( "database/sql" - es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker" "time" + + es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker" ) const ( @@ -14,6 +15,6 @@ type locker struct { dbClient *sql.DB } -func (l *locker) Renew(lockerID, viewModel string, waitTime time.Duration) error { - return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, waitTime) +func (l *locker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error { + return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, instanceID, waitTime) } diff --git a/internal/notification/repository/eventsourcing/view/error_event.go b/internal/notification/repository/eventsourcing/view/error_event.go index 8979826576..f5d44d264b 100644 --- a/internal/notification/repository/eventsourcing/view/error_event.go +++ b/internal/notification/repository/eventsourcing/view/error_event.go @@ -12,6 +12,6 @@ func (v *View) saveFailedEvent(failedEvent *repository.FailedEvent) error { return repository.SaveFailedEvent(v.Db, errTable, failedEvent) } -func (v *View) latestFailedEvent(viewName string, sequence uint64) (*repository.FailedEvent, error) { - return repository.LatestFailedEvent(v.Db, errTable, viewName, sequence) +func (v *View) latestFailedEvent(viewName, instanceID string, sequence uint64) (*repository.FailedEvent, error) { + return repository.LatestFailedEvent(v.Db, errTable, viewName, instanceID, sequence) } diff --git a/internal/notification/repository/eventsourcing/view/notification.go b/internal/notification/repository/eventsourcing/view/notification.go index ebfef66bfe..dd4a6ca30a 100644 --- a/internal/notification/repository/eventsourcing/view/notification.go +++ b/internal/notification/repository/eventsourcing/view/notification.go @@ -9,8 +9,12 @@ const ( notificationTable = "notification.notifications" ) -func (v *View) GetLatestNotificationSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(notificationTable) +func (v *View) GetLatestNotificationSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(notificationTable, instanceID) +} + +func (v *View) GetLatestNotificationSequences() ([]*repository.CurrentSequence, error) { + return v.latestSequences(notificationTable) } func (v *View) ProcessedNotificationSequence(event *models.Event) error { @@ -21,8 +25,8 @@ func (v *View) UpdateNotificationSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(notificationTable) } -func (v *View) GetLatestNotificationFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(notificationTable, sequence) +func (v *View) GetLatestNotificationFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(notificationTable, instanceID, sequence) } func (v *View) ProcessedNotificationFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/notification/repository/eventsourcing/view/notify_user.go b/internal/notification/repository/eventsourcing/view/notify_user.go index 99d9516a66..58ae1bb371 100644 --- a/internal/notification/repository/eventsourcing/view/notify_user.go +++ b/internal/notification/repository/eventsourcing/view/notify_user.go @@ -12,8 +12,8 @@ const ( notifyUserTable = "notification.notify_users" ) -func (v *View) NotifyUserByID(userID string) (*model.NotifyUser, error) { - return view.NotifyUserByID(v.Db, notifyUserTable, userID) +func (v *View) NotifyUserByID(userID, instanceID string) (*model.NotifyUser, error) { + return view.NotifyUserByID(v.Db, notifyUserTable, userID, instanceID) } func (v *View) PutNotifyUser(user *model.NotifyUser, event *models.Event) error { @@ -24,20 +24,24 @@ func (v *View) PutNotifyUser(user *model.NotifyUser, event *models.Event) error return v.ProcessedNotifyUserSequence(event) } -func (v *View) NotifyUsersByOrgID(orgID string) ([]*model.NotifyUser, error) { - return view.NotifyUsersByOrgID(v.Db, notifyUserTable, orgID) +func (v *View) NotifyUsersByOrgID(orgID, instanceID string) ([]*model.NotifyUser, error) { + return view.NotifyUsersByOrgID(v.Db, notifyUserTable, orgID, instanceID) } -func (v *View) DeleteNotifyUser(userID string, event *models.Event) error { - err := view.DeleteNotifyUser(v.Db, notifyUserTable, userID) +func (v *View) DeleteNotifyUser(userID, instanceID string, event *models.Event) error { + err := view.DeleteNotifyUser(v.Db, notifyUserTable, userID, instanceID) if err != nil && !errors.IsNotFound(err) { return err } return v.ProcessedNotifyUserSequence(event) } -func (v *View) GetLatestNotifyUserSequence() (*repository.CurrentSequence, error) { - return v.latestSequence(notifyUserTable) +func (v *View) GetLatestNotifyUserSequence(instanceID string) (*repository.CurrentSequence, error) { + return v.latestSequence(notifyUserTable, instanceID) +} + +func (v *View) GetLatestNotifyUserSequences() ([]*repository.CurrentSequence, error) { + return v.latestSequences(notifyUserTable) } func (v *View) ProcessedNotifyUserSequence(event *models.Event) error { @@ -48,8 +52,8 @@ func (v *View) UpdateNotifyUserSpoolerRunTimestamp() error { return v.updateSpoolerRunSequence(notifyUserTable) } -func (v *View) GetLatestNotifyUserFailedEvent(sequence uint64) (*repository.FailedEvent, error) { - return v.latestFailedEvent(notifyUserTable, sequence) +func (v *View) GetLatestNotifyUserFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) { + return v.latestFailedEvent(notifyUserTable, instanceID, sequence) } func (v *View) ProcessedNotifyUserFailedEvent(failedEvent *repository.FailedEvent) error { diff --git a/internal/notification/repository/eventsourcing/view/sequence.go b/internal/notification/repository/eventsourcing/view/sequence.go index ad1723cc54..8be8166cba 100644 --- a/internal/notification/repository/eventsourcing/view/sequence.go +++ b/internal/notification/repository/eventsourcing/view/sequence.go @@ -12,21 +12,27 @@ const ( ) func (v *View) saveCurrentSequence(viewName string, event *models.Event) error { - return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.Sequence, event.CreationDate) + return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate) } -func (v *View) latestSequence(viewName string) (*repository.CurrentSequence, error) { - return repository.LatestSequence(v.Db, sequencesTable, viewName) +func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) { + return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID) +} + +func (v *View) latestSequences(viewName string) ([]*repository.CurrentSequence, error) { + return repository.LatestSequences(v.Db, sequencesTable, viewName) } func (v *View) updateSpoolerRunSequence(viewName string) error { - currentSequence, err := repository.LatestSequence(v.Db, sequencesTable, viewName) + currentSequences, err := repository.LatestSequences(v.Db, sequencesTable, viewName) if err != nil { return err } - if currentSequence.ViewName == "" { - currentSequence.ViewName = viewName + for _, currentSequence := range currentSequences { + if currentSequence.ViewName == "" { + currentSequence.ViewName = viewName + } + currentSequence.LastSuccessfulSpoolerRun = time.Now() } - currentSequence.LastSuccessfulSpoolerRun = time.Now() - return repository.UpdateCurrentSequence(v.Db, sequencesTable, currentSequence) + return repository.UpdateCurrentSequences(v.Db, sequencesTable, currentSequences) } diff --git a/internal/org/repository/view/query.go b/internal/org/repository/view/query.go index 46ed781575..3643e2d01d 100644 --- a/internal/org/repository/view/query.go +++ b/internal/org/repository/view/query.go @@ -11,11 +11,15 @@ func OrgByIDQuery(id string, latestSequence uint64) (*es_models.SearchQuery, err return nil, errors.ThrowPreconditionFailed(nil, "EVENT-dke74", "id should be filled") } return OrgQuery(latestSequence). - AggregateIDFilter(id), nil + AddQuery(). + AggregateIDFilter(id). + SearchQuery(), nil } func OrgQuery(latestSequence uint64) *es_models.SearchQuery { return es_models.NewSearchQuery(). + AddQuery(). AggregateTypeFilter(org.AggregateType). - LatestSequenceFilter(latestSequence) + LatestSequenceFilter(latestSequence). + SearchQuery() } diff --git a/internal/project/repository/view/org_project_mapping_view.go b/internal/project/repository/view/org_project_mapping_view.go index f980dcc36d..826623ba52 100644 --- a/internal/project/repository/view/org_project_mapping_view.go +++ b/internal/project/repository/view/org_project_mapping_view.go @@ -10,12 +10,13 @@ import ( "github.com/caos/zitadel/internal/view/repository" ) -func OrgProjectMappingByIDs(db *gorm.DB, table, orgID, projectID string) (*model.OrgProjectMapping, error) { +func OrgProjectMappingByIDs(db *gorm.DB, table, orgID, projectID, instanceID string) (*model.OrgProjectMapping, error) { orgProjectMapping := new(model.OrgProjectMapping) projectIDQuery := model.OrgProjectMappingSearchQuery{Key: proj_model.OrgProjectMappingSearchKeyProjectID, Value: projectID, Method: domain.SearchMethodEquals} orgIDQuery := model.OrgProjectMappingSearchQuery{Key: proj_model.OrgProjectMappingSearchKeyOrgID, Value: orgID, Method: domain.SearchMethodEquals} - query := repository.PrepareGetByQuery(table, projectIDQuery, orgIDQuery) + instanceIDQuery := model.OrgProjectMappingSearchQuery{Key: proj_model.OrgProjectMappingSearchKeyInstanceID, Value: instanceID, Method: domain.SearchMethodEquals} + query := repository.PrepareGetByQuery(table, projectIDQuery, orgIDQuery, instanceIDQuery) err := query(db, orgProjectMapping) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-fn9fs", "Errors.OrgProjectMapping.NotExisting") @@ -28,19 +29,26 @@ func PutOrgProjectMapping(db *gorm.DB, table string, grant *model.OrgProjectMapp return save(db, grant) } -func DeleteOrgProjectMapping(db *gorm.DB, table, orgID, projectID string) error { +func DeleteOrgProjectMapping(db *gorm.DB, table, orgID, projectID, instanceID string) error { projectIDSearch := repository.Key{Key: model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyProjectID), Value: projectID} orgIDSearch := repository.Key{Key: model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyOrgID), Value: orgID} - delete := repository.PrepareDeleteByKeys(table, projectIDSearch, orgIDSearch) + instanceIDSearch := repository.Key{Key: model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyInstanceID), Value: instanceID} + delete := repository.PrepareDeleteByKeys(table, projectIDSearch, orgIDSearch, instanceIDSearch) return delete(db) } -func DeleteOrgProjectMappingsByProjectID(db *gorm.DB, table, projectID string) error { - delete := repository.PrepareDeleteByKey(table, model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyProjectID), projectID) +func DeleteOrgProjectMappingsByProjectID(db *gorm.DB, table, projectID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyProjectID), projectID}, + repository.Key{model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyInstanceID), instanceID}, + ) return delete(db) } -func DeleteOrgProjectMappingsByProjectGrantID(db *gorm.DB, table, projectGrantID string) error { - delete := repository.PrepareDeleteByKey(table, model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyProjectGrantID), projectGrantID) +func DeleteOrgProjectMappingsByProjectGrantID(db *gorm.DB, table, projectGrantID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyProjectGrantID), projectGrantID}, + repository.Key{model.OrgProjectMappingSearchKey(proj_model.OrgProjectMappingSearchKeyInstanceID), instanceID}, + ) return delete(db) } diff --git a/internal/project/repository/view/query.go b/internal/project/repository/view/query.go index 2969857bd9..2b5e36a400 100644 --- a/internal/project/repository/view/query.go +++ b/internal/project/repository/view/query.go @@ -6,16 +6,21 @@ import ( "github.com/caos/zitadel/internal/repository/project" ) -func ProjectByIDQuery(id string, latestSequence uint64) (*es_models.SearchQuery, error) { +func ProjectByIDQuery(id, instanceID string, latestSequence uint64) (*es_models.SearchQuery, error) { if id == "" { return nil, errors.ThrowPreconditionFailed(nil, "EVENT-dke74", "Errors.Project.ProjectIDMissing") } return ProjectQuery(latestSequence). - AggregateIDFilter(id), nil + AddQuery(). + AggregateIDFilter(id). + InstanceIDFilter(instanceID). + SearchQuery(), nil } func ProjectQuery(latestSequence uint64) *es_models.SearchQuery { return es_models.NewSearchQuery(). + AddQuery(). AggregateTypeFilter(project.AggregateType). - LatestSequenceFilter(latestSequence) + LatestSequenceFilter(latestSequence). + SearchQuery() } diff --git a/internal/query/current_sequence.go b/internal/query/current_sequence.go index ee585b7c19..c97d44630c 100644 --- a/internal/query/current_sequence.go +++ b/internal/query/current_sequence.go @@ -9,6 +9,7 @@ import ( sq "github.com/Masterminds/squirrel" + "github.com/caos/zitadel/internal/api/authz" "github.com/caos/zitadel/internal/errors" "github.com/caos/zitadel/internal/query/projection" ) @@ -72,6 +73,7 @@ func (q *Queries) latestSequence(ctx context.Context, projections ...table) (*La } stmt, args, err := query. Where(or). + Where(sq.Eq{CurrentSequenceColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}). OrderBy(CurrentSequenceColCurrentSequence.identifier()). ToSql() if err != nil { @@ -269,6 +271,10 @@ var ( name: "projection_name", table: currentSequencesTable, } + CurrentSequenceColInstanceID = Column{ + name: "instance_id", + table: currentSequencesTable, + } ) var ( diff --git a/internal/query/projection/login_name.go b/internal/query/projection/login_name.go index 308ec9a813..fc4427fab2 100644 --- a/internal/query/projection/login_name.go +++ b/internal/query/projection/login_name.go @@ -65,7 +65,7 @@ var ( " , IFNULL(policy_custom.%[1]s, policy_default.%[1]s) AS %[1]s"+ " FROM %[7]s users"+ " LEFT JOIN %[8]s policy_custom on policy_custom.%[9]s = users.%[5]s AND policy_custom.%[10]s = users.%[4]s"+ - " LEFT JOIN %[8]s policy_default on policy_default.%[11]s = true) policy_users"+ + " LEFT JOIN %[8]s policy_default on policy_default.%[11]s = true AND policy_default.%[10]s = users.%[4]s) policy_users"+ " LEFT JOIN %[12]s domains ON policy_users.%[1]s AND policy_users.%[5]s = domains.%[13]s AND policy_users.%[10]s = domains.%[14]s"+ ");", LoginNamePoliciesMustBeDomainCol, diff --git a/internal/user/model/external_idp_view.go b/internal/user/model/external_idp_view.go index 68d0b5038f..d6b7fc02d0 100644 --- a/internal/user/model/external_idp_view.go +++ b/internal/user/model/external_idp_view.go @@ -35,6 +35,7 @@ const ( ExternalIDPSearchKeyUserID ExternalIDPSearchKeyIdpConfigID ExternalIDPSearchKeyResourceOwner + ExternalIDPSearchKeyInstanceID ) type ExternalIDPSearchQuery struct { diff --git a/internal/user/model/notify_user.go b/internal/user/model/notify_user.go index 4f0b79c0ff..102c4592a3 100644 --- a/internal/user/model/notify_user.go +++ b/internal/user/model/notify_user.go @@ -1,8 +1,9 @@ package model import ( - "github.com/caos/zitadel/internal/domain" "time" + + "github.com/caos/zitadel/internal/domain" ) type NotifyUser struct { @@ -41,6 +42,7 @@ const ( NotifyUserSearchKeyUnspecified NotifyUserSearchKey = iota NotifyUserSearchKeyUserID NotifyUserSearchKeyResourceOwner + NotifyUserSearchKeyInstanceID ) type NotifyUserSearchQuery struct { diff --git a/internal/user/model/user_session_view.go b/internal/user/model/user_session_view.go index c26334ccc2..b2bbf968c1 100644 --- a/internal/user/model/user_session_view.go +++ b/internal/user/model/user_session_view.go @@ -46,6 +46,7 @@ const ( UserSessionSearchKeyUserID UserSessionSearchKeyState UserSessionSearchKeyResourceOwner + UserSessionSearchKeyInstanceID ) type UserSessionSearchQuery struct { diff --git a/internal/user/repository/view/external_idp_view.go b/internal/user/repository/view/external_idp_view.go index e29b8c3b7d..64e7c90819 100644 --- a/internal/user/repository/view/external_idp_view.go +++ b/internal/user/repository/view/external_idp_view.go @@ -11,7 +11,7 @@ import ( "github.com/caos/zitadel/internal/user/repository/view/model" ) -func ExternalIDPByExternalUserIDAndIDPConfigID(db *gorm.DB, table, externalUserID, idpConfigID string) (*model.ExternalIDPView, error) { +func ExternalIDPByExternalUserIDAndIDPConfigID(db *gorm.DB, table, externalUserID, idpConfigID, instanceID string) (*model.ExternalIDPView, error) { user := new(model.ExternalIDPView) userIDQuery := &model.ExternalIDPSearchQuery{ Key: usr_model.ExternalIDPSearchKeyExternalUserID, @@ -23,7 +23,12 @@ func ExternalIDPByExternalUserIDAndIDPConfigID(db *gorm.DB, table, externalUserI Method: domain.SearchMethodEquals, Value: idpConfigID, } - query := repository.PrepareGetByQuery(table, userIDQuery, idpConfigIDQuery) + instanceIDQuery := &model.ExternalIDPSearchQuery{ + Key: usr_model.ExternalIDPSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } + query := repository.PrepareGetByQuery(table, userIDQuery, idpConfigIDQuery, instanceIDQuery) err := query(db, user) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-Mso9f", "Errors.ExternalIDP.NotFound") @@ -31,7 +36,7 @@ func ExternalIDPByExternalUserIDAndIDPConfigID(db *gorm.DB, table, externalUserI return user, err } -func ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(db *gorm.DB, table, externalUserID, idpConfigID, resourceOwner string) (*model.ExternalIDPView, error) { +func ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(db *gorm.DB, table, externalUserID, idpConfigID, resourceOwner, instanceID string) (*model.ExternalIDPView, error) { user := new(model.ExternalIDPView) userIDQuery := &model.ExternalIDPSearchQuery{ Key: usr_model.ExternalIDPSearchKeyExternalUserID, @@ -48,7 +53,12 @@ func ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(db *gorm.DB, tabl Method: domain.SearchMethodEquals, Value: resourceOwner, } - query := repository.PrepareGetByQuery(table, userIDQuery, idpConfigIDQuery, resourceOwnerQuery) + instanceIDQuery := &model.ExternalIDPSearchQuery{ + Key: usr_model.ExternalIDPSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } + query := repository.PrepareGetByQuery(table, userIDQuery, idpConfigIDQuery, resourceOwnerQuery, instanceIDQuery) err := query(db, user) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-Sf8sd", "Errors.ExternalIDP.NotFound") @@ -56,15 +66,20 @@ func ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(db *gorm.DB, tabl return user, err } -func ExternalIDPsByIDPConfigID(db *gorm.DB, table, idpConfigID string) ([]*model.ExternalIDPView, error) { +func ExternalIDPsByIDPConfigID(db *gorm.DB, table, idpConfigID, instanceID string) ([]*model.ExternalIDPView, error) { externalIDPs := make([]*model.ExternalIDPView, 0) orgIDQuery := &usr_model.ExternalIDPSearchQuery{ Key: usr_model.ExternalIDPSearchKeyIdpConfigID, Method: domain.SearchMethodEquals, Value: idpConfigID, } + instanceIDQuery := &usr_model.ExternalIDPSearchQuery{ + Key: usr_model.ExternalIDPSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } query := repository.PrepareSearchQuery(table, model.ExternalIDPSearchRequest{ - Queries: []*usr_model.ExternalIDPSearchQuery{orgIDQuery}, + Queries: []*usr_model.ExternalIDPSearchQuery{orgIDQuery, instanceIDQuery}, }) _, err := query(db, &externalIDPs) return externalIDPs, err @@ -84,15 +99,19 @@ func PutExternalIDP(db *gorm.DB, table string, idp *model.ExternalIDPView) error return save(db, idp) } -func DeleteExternalIDP(db *gorm.DB, table, externalUserID, idpConfigID string) error { +func DeleteExternalIDP(db *gorm.DB, table, externalUserID, idpConfigID, instanceID string) error { delete := repository.PrepareDeleteByKeys(table, repository.Key{Key: model.ExternalIDPSearchKey(usr_model.ExternalIDPSearchKeyExternalUserID), Value: externalUserID}, repository.Key{Key: model.ExternalIDPSearchKey(usr_model.ExternalIDPSearchKeyIdpConfigID), Value: idpConfigID}, + repository.Key{Key: model.ExternalIDPSearchKey(usr_model.ExternalIDPSearchKeyInstanceID), Value: instanceID}, ) return delete(db) } -func DeleteExternalIDPsByUserID(db *gorm.DB, table, userID string) error { - delete := repository.PrepareDeleteByKey(table, model.ExternalIDPSearchKey(usr_model.ExternalIDPSearchKeyUserID), userID) +func DeleteExternalIDPsByUserID(db *gorm.DB, table, userID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.ExternalIDPSearchKey(usr_model.ExternalIDPSearchKeyUserID), userID}, + repository.Key{model.ExternalIDPSearchKey(usr_model.ExternalIDPSearchKeyInstanceID), instanceID}, + ) return delete(db) } diff --git a/internal/user/repository/view/model/external_idp_query.go b/internal/user/repository/view/model/external_idp_query.go index 08e637f355..7c9a774f10 100644 --- a/internal/user/repository/view/model/external_idp_query.go +++ b/internal/user/repository/view/model/external_idp_query.go @@ -59,6 +59,8 @@ func (key ExternalIDPSearchKey) ToColumnName() string { return ExternalIDPKeyIDPConfigID case usr_model.ExternalIDPSearchKeyResourceOwner: return ExternalIDPKeyResourceOwner + case usr_model.ExternalIDPSearchKeyInstanceID: + return ExternalIDPKeyInstanceID default: return "" } diff --git a/internal/user/repository/view/model/external_idps.go b/internal/user/repository/view/model/external_idps.go index 1e68ceb522..b2508dfc66 100644 --- a/internal/user/repository/view/model/external_idps.go +++ b/internal/user/repository/view/model/external_idps.go @@ -17,6 +17,7 @@ const ( ExternalIDPKeyUserID = "user_id" ExternalIDPKeyIDPConfigID = "idp_config_id" ExternalIDPKeyResourceOwner = "resource_owner" + ExternalIDPKeyInstanceID = "instance_id" ) type ExternalIDPView struct { @@ -29,7 +30,7 @@ type ExternalIDPView struct { ChangeDate time.Time `json:"-" gorm:"column:change_date"` ResourceOwner string `json:"-" gorm:"column:resource_owner"` Sequence uint64 `json:"-" gorm:"column:sequence"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } func (i *ExternalIDPView) AppendEvent(event *models.Event) (err error) { diff --git a/internal/user/repository/view/model/notify_user.go b/internal/user/repository/view/model/notify_user.go index b0ee382c12..bde6a81c66 100644 --- a/internal/user/repository/view/model/notify_user.go +++ b/internal/user/repository/view/model/notify_user.go @@ -18,6 +18,7 @@ import ( const ( NotifyUserKeyUserID = "id" NotifyUserKeyResourceOwner = "resource_owner" + NotifyUserKeyInstanceID = "instance_id" ) type NotifyUser struct { @@ -41,7 +42,7 @@ type NotifyUser struct { PasswordSet bool `json:"-" gorm:"column:password_set"` Sequence uint64 `json:"-" gorm:"column:sequence"` State int32 `json:"-" gorm:"-"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } func (u *NotifyUser) GenerateLoginName(domain string, appendDomain bool) string { diff --git a/internal/user/repository/view/model/notify_user_query.go b/internal/user/repository/view/model/notify_user_query.go index 0d437e5925..3141f77f2b 100644 --- a/internal/user/repository/view/model/notify_user_query.go +++ b/internal/user/repository/view/model/notify_user_query.go @@ -55,6 +55,8 @@ func (key NotifyUserSearchKey) ToColumnName() string { return NotifyUserKeyUserID case usr_model.NotifyUserSearchKeyResourceOwner: return NotifyUserKeyResourceOwner + case usr_model.NotifyUserSearchKeyInstanceID: + return NotifyUserKeyInstanceID default: return "" } diff --git a/internal/user/repository/view/model/refresh_token.go b/internal/user/repository/view/model/refresh_token.go index 58be9c5926..33b6210a51 100644 --- a/internal/user/repository/view/model/refresh_token.go +++ b/internal/user/repository/view/model/refresh_token.go @@ -40,7 +40,7 @@ type RefreshTokenView struct { IdleExpiration time.Time `json:"-" gorm:"column:idle_expiration"` Expiration time.Time `json:"-" gorm:"column:expiration"` Sequence uint64 `json:"-" gorm:"column:sequence"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } func RefreshTokenViewsToModel(tokens []*RefreshTokenView) []*usr_model.RefreshTokenView { diff --git a/internal/user/repository/view/model/token.go b/internal/user/repository/view/model/token.go index 18d14afecf..ec592c5125 100644 --- a/internal/user/repository/view/model/token.go +++ b/internal/user/repository/view/model/token.go @@ -41,7 +41,7 @@ type TokenView struct { RefreshTokenID string `json:"refreshTokenID,omitempty" gorm:"refresh_token_id"` IsPAT bool `json:"-" gorm:"is_pat"` Deactivated bool `json:"-" gorm:"-"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } func TokenViewToModel(token *TokenView) *usr_model.TokenView { diff --git a/internal/user/repository/view/model/user.go b/internal/user/repository/view/model/user.go index 4ddb59e7c5..a9bc9ec8a5 100644 --- a/internal/user/repository/view/model/user.go +++ b/internal/user/repository/view/model/user.go @@ -53,7 +53,7 @@ type UserView struct { Sequence uint64 `json:"-" gorm:"column:sequence"` Type userType `json:"-" gorm:"column:user_type"` UserName string `json:"userName" gorm:"column:user_name"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` *MachineView *HumanView } diff --git a/internal/user/repository/view/model/user_membership.go b/internal/user/repository/view/model/user_membership.go index ae0487dcf1..d6a03d36a9 100644 --- a/internal/user/repository/view/model/user_membership.go +++ b/internal/user/repository/view/model/user_membership.go @@ -41,7 +41,7 @@ type UserMembershipView struct { ResourceOwner string `json:"-" gorm:"column:resource_owner"` ResourceOwnerName string `json:"-" gorm:"column:resource_owner_name"` Sequence uint64 `json:"-" gorm:"column:sequence"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } func (u *UserMembershipView) AppendEvent(event *models.Event) (err error) { diff --git a/internal/user/repository/view/model/user_session.go b/internal/user/repository/view/model/user_session.go index cb2a54b3e3..6394def794 100644 --- a/internal/user/repository/view/model/user_session.go +++ b/internal/user/repository/view/model/user_session.go @@ -20,6 +20,7 @@ const ( UserSessionKeyUserID = "user_id" UserSessionKeyState = "state" UserSessionKeyResourceOwner = "resource_owner" + UserSessionKeyInstanceID = "instance_id" ) type UserSessionView struct { @@ -42,7 +43,7 @@ type UserSessionView struct { MultiFactorVerification time.Time `json:"-" gorm:"column:multi_factor_verification"` MultiFactorVerificationType int32 `json:"-" gorm:"column:multi_factor_verification_type"` Sequence uint64 `json:"-" gorm:"column:sequence"` - InstanceID string `json:"instanceID" gorm:"column:instance_id"` + InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"` } func UserSessionFromEvent(event *models.Event) (*UserSessionView, error) { diff --git a/internal/user/repository/view/model/user_session_query.go b/internal/user/repository/view/model/user_session_query.go index 49e2b47595..3ba32fbf8a 100644 --- a/internal/user/repository/view/model/user_session_query.go +++ b/internal/user/repository/view/model/user_session_query.go @@ -59,6 +59,8 @@ func (key UserSessionSearchKey) ToColumnName() string { return UserSessionKeyState case usr_model.UserSessionSearchKeyResourceOwner: return UserSessionKeyResourceOwner + case usr_model.UserSessionSearchKeyInstanceID: + return UserSessionKeyInstanceID default: return "" } diff --git a/internal/user/repository/view/notify_user.go b/internal/user/repository/view/notify_user.go index 374d85bdc4..aa03f037c9 100644 --- a/internal/user/repository/view/notify_user.go +++ b/internal/user/repository/view/notify_user.go @@ -1,17 +1,21 @@ package view import ( + "github.com/jinzhu/gorm" + "github.com/caos/zitadel/internal/domain" caos_errs "github.com/caos/zitadel/internal/errors" usr_model "github.com/caos/zitadel/internal/user/model" "github.com/caos/zitadel/internal/user/repository/view/model" "github.com/caos/zitadel/internal/view/repository" - "github.com/jinzhu/gorm" ) -func NotifyUserByID(db *gorm.DB, table, userID string) (*model.NotifyUser, error) { +func NotifyUserByID(db *gorm.DB, table, userID, instanceID string) (*model.NotifyUser, error) { user := new(model.NotifyUser) - query := repository.PrepareGetByKey(table, model.NotifyUserSearchKey(usr_model.NotifyUserSearchKeyUserID), userID) + query := repository.PrepareGetByQuery(table, + model.NotifyUserSearchQuery{Key: usr_model.NotifyUserSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID}, + model.NotifyUserSearchQuery{Key: usr_model.NotifyUserSearchKeyInstanceID, Method: domain.SearchMethodEquals, Value: instanceID}, + ) err := query(db, user) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-Gad31", "Errors.User.NotFound") @@ -19,15 +23,20 @@ func NotifyUserByID(db *gorm.DB, table, userID string) (*model.NotifyUser, error return user, err } -func NotifyUsersByOrgID(db *gorm.DB, table, orgID string) ([]*model.NotifyUser, error) { +func NotifyUsersByOrgID(db *gorm.DB, table, orgID, instanceID string) ([]*model.NotifyUser, error) { users := make([]*model.NotifyUser, 0) orgIDQuery := &usr_model.NotifyUserSearchQuery{ Key: usr_model.NotifyUserSearchKeyResourceOwner, Method: domain.SearchMethodEquals, Value: orgID, } + instanceIDQuery := &usr_model.NotifyUserSearchQuery{ + Key: usr_model.NotifyUserSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } query := repository.PrepareSearchQuery(table, model.NotifyUserSearchRequest{ - Queries: []*usr_model.NotifyUserSearchQuery{orgIDQuery}, + Queries: []*usr_model.NotifyUserSearchQuery{orgIDQuery, instanceIDQuery}, }) _, err := query(db, &users) return users, err @@ -38,7 +47,10 @@ func PutNotifyUser(db *gorm.DB, table string, project *model.NotifyUser) error { return save(db, project) } -func DeleteNotifyUser(db *gorm.DB, table, userID string) error { - delete := repository.PrepareDeleteByKey(table, model.UserSearchKey(usr_model.NotifyUserSearchKeyUserID), userID) +func DeleteNotifyUser(db *gorm.DB, table, userID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.UserSearchKey(usr_model.NotifyUserSearchKeyUserID), userID}, + repository.Key{model.UserSearchKey(usr_model.NotifyUserSearchKeyInstanceID), instanceID}, + ) return delete(db) } diff --git a/internal/user/repository/view/query.go b/internal/user/repository/view/query.go index 7133d87ce5..bb5971f2c4 100644 --- a/internal/user/repository/view/query.go +++ b/internal/user/repository/view/query.go @@ -6,16 +6,15 @@ import ( "github.com/caos/zitadel/internal/repository/user" ) -func UserByIDQuery(id string, latestSequence uint64) (*es_models.SearchQuery, error) { +func UserByIDQuery(id, instanceID string, latestSequence uint64) (*es_models.SearchQuery, error) { if id == "" { return nil, errors.ThrowPreconditionFailed(nil, "EVENT-d8isw", "Errors.User.UserIDMissing") } - return UserQuery(latestSequence). - AggregateIDFilter(id), nil -} - -func UserQuery(latestSequence uint64) *es_models.SearchQuery { return es_models.NewSearchQuery(). + AddQuery(). AggregateTypeFilter(user.AggregateType). - LatestSequenceFilter(latestSequence) + AggregateIDFilter(id). + LatestSequenceFilter(latestSequence). + InstanceIDFilter(instanceID). + SearchQuery(), nil } diff --git a/internal/user/repository/view/refresh_token_view.go b/internal/user/repository/view/refresh_token_view.go index 149219fd54..eee15bbe31 100644 --- a/internal/user/repository/view/refresh_token_view.go +++ b/internal/user/repository/view/refresh_token_view.go @@ -11,9 +11,12 @@ import ( "github.com/caos/zitadel/internal/view/repository" ) -func RefreshTokenByID(db *gorm.DB, table, tokenID string) (*usr_model.RefreshTokenView, error) { +func RefreshTokenByID(db *gorm.DB, table, tokenID, instanceID string) (*usr_model.RefreshTokenView, error) { token := new(usr_model.RefreshTokenView) - query := repository.PrepareGetByKey(table, usr_model.RefreshTokenSearchKey(model.RefreshTokenSearchKeyRefreshTokenID), tokenID) + query := repository.PrepareGetByQuery(table, + &usr_model.RefreshTokenSearchQuery{Key: model.RefreshTokenSearchKeyRefreshTokenID, Method: domain.SearchMethodEquals, Value: tokenID}, + &usr_model.RefreshTokenSearchQuery{Key: model.RefreshTokenSearchKeyInstanceID, Method: domain.SearchMethodEquals, Value: instanceID}, + ) err := query(db, token) if errors.IsNotFound(err) { return nil, errors.ThrowNotFound(nil, "VIEW-6ub3p", "Errors.RefreshToken.NotFound") @@ -21,15 +24,20 @@ func RefreshTokenByID(db *gorm.DB, table, tokenID string) (*usr_model.RefreshTok return token, err } -func RefreshTokensByUserID(db *gorm.DB, table, userID string) ([]*usr_model.RefreshTokenView, error) { +func RefreshTokensByUserID(db *gorm.DB, table, userID, instanceID string) ([]*usr_model.RefreshTokenView, error) { tokens := make([]*usr_model.RefreshTokenView, 0) userIDQuery := &model.RefreshTokenSearchQuery{ Key: model.RefreshTokenSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID, } + instanceIDQuery := &model.RefreshTokenSearchQuery{ + Key: model.RefreshTokenSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } query := repository.PrepareSearchQuery(table, usr_model.RefreshTokenSearchRequest{ - Queries: []*model.RefreshTokenSearchQuery{userIDQuery}, + Queries: []*model.RefreshTokenSearchQuery{userIDQuery, instanceIDQuery}, }) _, err := query(db, &tokens) return tokens, err @@ -59,8 +67,11 @@ func SearchRefreshTokens(db *gorm.DB, table string, req *model.RefreshTokenSearc return tokens, count, err } -func DeleteRefreshToken(db *gorm.DB, table, tokenID string) error { - delete := repository.PrepareDeleteByKey(table, usr_model.RefreshTokenSearchKey(model.RefreshTokenSearchKeyRefreshTokenID), tokenID) +func DeleteRefreshToken(db *gorm.DB, table, tokenID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{usr_model.RefreshTokenSearchKey(model.RefreshTokenSearchKeyRefreshTokenID), tokenID}, + repository.Key{usr_model.RefreshTokenSearchKey(model.RefreshTokenSearchKeyInstanceID), instanceID}, + ) return delete(db) } @@ -72,8 +83,11 @@ func DeleteSessionRefreshTokens(db *gorm.DB, table, agentID, userID string) erro return delete(db) } -func DeleteUserRefreshTokens(db *gorm.DB, table, userID string) error { - delete := repository.PrepareDeleteByKey(table, usr_model.RefreshTokenSearchKey(model.RefreshTokenSearchKeyUserID), userID) +func DeleteUserRefreshTokens(db *gorm.DB, table, userID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{usr_model.RefreshTokenSearchKey(model.RefreshTokenSearchKeyUserID), userID}, + repository.Key{usr_model.RefreshTokenSearchKey(model.RefreshTokenSearchKeyInstanceID), instanceID}, + ) return delete(db) } diff --git a/internal/user/repository/view/token_view.go b/internal/user/repository/view/token_view.go index caf195b107..9815828290 100644 --- a/internal/user/repository/view/token_view.go +++ b/internal/user/repository/view/token_view.go @@ -11,9 +11,12 @@ import ( "github.com/caos/zitadel/internal/view/repository" ) -func TokenByID(db *gorm.DB, table, tokenID string) (*usr_model.TokenView, error) { +func TokenByID(db *gorm.DB, table, tokenID, instanceID string) (*usr_model.TokenView, error) { token := new(usr_model.TokenView) - query := repository.PrepareGetByKey(table, usr_model.TokenSearchKey(model.TokenSearchKeyTokenID), tokenID) + query := repository.PrepareGetByQuery(table, + &usr_model.TokenSearchQuery{Key: model.TokenSearchKeyTokenID, Method: domain.SearchMethodEquals, Value: tokenID}, + &usr_model.TokenSearchQuery{Key: model.TokenSearchKeyInstanceID, Method: domain.SearchMethodEquals, Value: instanceID}, + ) err := query(db, token) if errors.IsNotFound(err) { return nil, errors.ThrowNotFound(nil, "VIEW-6ub3p", "Errors.Token.NotFound") @@ -21,15 +24,20 @@ func TokenByID(db *gorm.DB, table, tokenID string) (*usr_model.TokenView, error) return token, err } -func TokensByUserID(db *gorm.DB, table, userID string) ([]*usr_model.TokenView, error) { +func TokensByUserID(db *gorm.DB, table, userID, instanceID string) ([]*usr_model.TokenView, error) { tokens := make([]*usr_model.TokenView, 0) userIDQuery := &model.TokenSearchQuery{ Key: model.TokenSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID, } + instanceIDQuery := &model.TokenSearchQuery{ + Key: model.TokenSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } query := repository.PrepareSearchQuery(table, usr_model.TokenSearchRequest{ - Queries: []*model.TokenSearchQuery{userIDQuery}, + Queries: []*model.TokenSearchQuery{userIDQuery, instanceIDQuery}, }) _, err := query(db, &tokens) return tokens, err @@ -49,30 +57,43 @@ func PutTokens(db *gorm.DB, table string, tokens ...*usr_model.TokenView) error return save(db, t...) } -func DeleteToken(db *gorm.DB, table, tokenID string) error { - delete := repository.PrepareDeleteByKey(table, usr_model.TokenSearchKey(model.TokenSearchKeyTokenID), tokenID) - return delete(db) -} - -func DeleteSessionTokens(db *gorm.DB, table, agentID, userID string) error { +func DeleteToken(db *gorm.DB, table, tokenID, instanceID string) error { delete := repository.PrepareDeleteByKeys(table, - repository.Key{Key: usr_model.TokenSearchKey(model.TokenSearchKeyUserAgentID), Value: agentID}, - repository.Key{Key: usr_model.TokenSearchKey(model.TokenSearchKeyUserID), Value: userID}, + repository.Key{usr_model.TokenSearchKey(model.TokenSearchKeyTokenID), tokenID}, + repository.Key{usr_model.TokenSearchKey(model.TokenSearchKeyInstanceID), instanceID}, ) return delete(db) } -func DeleteUserTokens(db *gorm.DB, table, userID string) error { - delete := repository.PrepareDeleteByKey(table, usr_model.TokenSearchKey(model.TokenSearchKeyUserID), userID) +func DeleteSessionTokens(db *gorm.DB, table, agentID, userID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{Key: usr_model.TokenSearchKey(model.TokenSearchKeyUserAgentID), Value: agentID}, + repository.Key{Key: usr_model.TokenSearchKey(model.TokenSearchKeyUserID), Value: userID}, + repository.Key{Key: usr_model.TokenSearchKey(model.TokenSearchKeyInstanceID), Value: instanceID}, + ) return delete(db) } -func DeleteTokensFromRefreshToken(db *gorm.DB, table, refreshTokenID string) error { - delete := repository.PrepareDeleteByKey(table, usr_model.TokenSearchKey(model.TokenSearchKeyRefreshTokenID), refreshTokenID) +func DeleteUserTokens(db *gorm.DB, table, userID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{usr_model.TokenSearchKey(model.TokenSearchKeyUserID), userID}, + repository.Key{usr_model.TokenSearchKey(model.TokenSearchKeyInstanceID), instanceID}, + ) return delete(db) } -func DeleteApplicationTokens(db *gorm.DB, table string, appIDs []string) error { - delete := repository.PrepareDeleteByKey(table, usr_model.TokenSearchKey(model.TokenSearchKeyApplicationID), pq.StringArray(appIDs)) +func DeleteTokensFromRefreshToken(db *gorm.DB, table, refreshTokenID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{usr_model.TokenSearchKey(model.TokenSearchKeyRefreshTokenID), refreshTokenID}, + repository.Key{usr_model.TokenSearchKey(model.TokenSearchKeyInstanceID), instanceID}, + ) + return delete(db) +} + +func DeleteApplicationTokens(db *gorm.DB, table, instanceID string, appIDs []string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{usr_model.TokenSearchKey(model.TokenSearchKeyApplicationID), pq.StringArray(appIDs)}, + repository.Key{usr_model.TokenSearchKey(model.TokenSearchKeyInstanceID), instanceID}, + ) return delete(db) } diff --git a/internal/user/repository/view/user_session_view.go b/internal/user/repository/view/user_session_view.go index 48e813ab99..32a72263e3 100644 --- a/internal/user/repository/view/user_session_view.go +++ b/internal/user/repository/view/user_session_view.go @@ -10,7 +10,7 @@ import ( "github.com/caos/zitadel/internal/view/repository" ) -func UserSessionByIDs(db *gorm.DB, table, agentID, userID string) (*model.UserSessionView, error) { +func UserSessionByIDs(db *gorm.DB, table, agentID, userID, instanceID string) (*model.UserSessionView, error) { userSession := new(model.UserSessionView) userAgentQuery := model.UserSessionSearchQuery{ Key: usr_model.UserSessionSearchKeyUserAgentID, @@ -22,7 +22,12 @@ func UserSessionByIDs(db *gorm.DB, table, agentID, userID string) (*model.UserSe Method: domain.SearchMethodEquals, Value: userID, } - query := repository.PrepareGetByQuery(table, userAgentQuery, userQuery) + instanceIDQuery := &model.UserSessionSearchQuery{ + Key: usr_model.UserSessionSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } + query := repository.PrepareGetByQuery(table, userAgentQuery, userQuery, instanceIDQuery) err := query(db, userSession) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-NGBs1", "Errors.UserSession.NotFound") @@ -30,29 +35,39 @@ func UserSessionByIDs(db *gorm.DB, table, agentID, userID string) (*model.UserSe return userSession, err } -func UserSessionsByUserID(db *gorm.DB, table, userID string) ([]*model.UserSessionView, error) { +func UserSessionsByUserID(db *gorm.DB, table, userID, instanceID string) ([]*model.UserSessionView, error) { userSessions := make([]*model.UserSessionView, 0) userAgentQuery := &usr_model.UserSessionSearchQuery{ Key: usr_model.UserSessionSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID, } + instanceIDQuery := &usr_model.UserSessionSearchQuery{ + Key: usr_model.UserSessionSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } query := repository.PrepareSearchQuery(table, model.UserSessionSearchRequest{ - Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery}, + Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery, instanceIDQuery}, }) _, err := query(db, &userSessions) return userSessions, err } -func UserSessionsByAgentID(db *gorm.DB, table, agentID string) ([]*model.UserSessionView, error) { +func UserSessionsByAgentID(db *gorm.DB, table, agentID, instanceID string) ([]*model.UserSessionView, error) { userSessions := make([]*model.UserSessionView, 0) userAgentQuery := &usr_model.UserSessionSearchQuery{ Key: usr_model.UserSessionSearchKeyUserAgentID, Method: domain.SearchMethodEquals, Value: agentID, } + instanceIDQuery := &usr_model.UserSessionSearchQuery{ + Key: usr_model.UserSessionSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } query := repository.PrepareSearchQuery(table, model.UserSessionSearchRequest{ - Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery}, + Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery, instanceIDQuery}, }) _, err := query(db, &userSessions) return userSessions, err @@ -84,7 +99,10 @@ func PutUserSessions(db *gorm.DB, table string, sessions ...*model.UserSessionVi return save(db, s...) } -func DeleteUserSessions(db *gorm.DB, table, userID string) error { - delete := repository.PrepareDeleteByKey(table, model.UserSessionSearchKey(usr_model.UserSessionSearchKeyUserID), userID) +func DeleteUserSessions(db *gorm.DB, table, userID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.UserSessionSearchKey(usr_model.UserSessionSearchKeyUserID), userID}, + repository.Key{model.UserSessionSearchKey(usr_model.UserSessionSearchKeyInstanceID), instanceID}, + ) return delete(db) } diff --git a/internal/user/repository/view/user_view.go b/internal/user/repository/view/user_view.go index 5a8db73570..0f3c4a2a67 100644 --- a/internal/user/repository/view/user_view.go +++ b/internal/user/repository/view/user_view.go @@ -11,9 +11,19 @@ import ( "github.com/caos/zitadel/internal/user/repository/view/model" ) -func UserByID(db *gorm.DB, table, userID string) (*model.UserView, error) { +func UserByID(db *gorm.DB, table, userID, instanceID string) (*model.UserView, error) { user := new(model.UserView) - query := repository.PrepareGetByKey(table, model.UserSearchKey(usr_model.UserSearchKeyUserID), userID) + userIDQuery := &model.UserSearchQuery{ + Key: usr_model.UserSearchKeyUserID, + Method: domain.SearchMethodEquals, + Value: userID, + } + instanceIDQuery := &model.UserSearchQuery{ + Key: usr_model.UserSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } + query := repository.PrepareGetByQuery(table, userIDQuery, instanceIDQuery) err := query(db, user) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-sj8Sw", "Errors.User.NotFound") @@ -22,9 +32,19 @@ func UserByID(db *gorm.DB, table, userID string) (*model.UserView, error) { return user, err } -func UserByUserName(db *gorm.DB, table, userName string) (*model.UserView, error) { +func UserByUserName(db *gorm.DB, table, userName, instanceID string) (*model.UserView, error) { user := new(model.UserView) - query := repository.PrepareGetByKey(table, model.UserSearchKey(usr_model.UserSearchKeyUserName), userName) + userNameQuery := &model.UserSearchQuery{ + Key: usr_model.UserSearchKeyUserName, + Method: domain.SearchMethodEquals, + Value: userName, + } + instanceIDQuery := &model.UserSearchQuery{ + Key: usr_model.UserSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } + query := repository.PrepareGetByQuery(table, userNameQuery, instanceIDQuery) err := query(db, user) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-Lso9s", "Errors.User.NotFound") @@ -33,14 +53,19 @@ func UserByUserName(db *gorm.DB, table, userName string) (*model.UserView, error return user, err } -func UserByLoginName(db *gorm.DB, table, loginName string) (*model.UserView, error) { +func UserByLoginName(db *gorm.DB, table, loginName, instanceID string) (*model.UserView, error) { user := new(model.UserView) loginNameQuery := &model.UserSearchQuery{ Key: usr_model.UserSearchKeyLoginNames, Method: domain.SearchMethodListContains, Value: loginName, } - query := repository.PrepareGetByQuery(table, loginNameQuery) + instanceIDQuery := &model.UserSearchQuery{ + Key: usr_model.UserSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } + query := repository.PrepareGetByQuery(table, loginNameQuery, instanceIDQuery) err := query(db, user) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-AD4qs", "Errors.User.NotFound") @@ -49,7 +74,7 @@ func UserByLoginName(db *gorm.DB, table, loginName string) (*model.UserView, err return user, err } -func UserByLoginNameAndResourceOwner(db *gorm.DB, table, loginName, resourceOwner string) (*model.UserView, error) { +func UserByLoginNameAndResourceOwner(db *gorm.DB, table, loginName, resourceOwner, instanceID string) (*model.UserView, error) { user := new(model.UserView) loginNameQuery := &model.UserSearchQuery{ Key: usr_model.UserSearchKeyLoginNames, @@ -61,7 +86,12 @@ func UserByLoginNameAndResourceOwner(db *gorm.DB, table, loginName, resourceOwne Method: domain.SearchMethodEquals, Value: resourceOwner, } - query := repository.PrepareGetByQuery(table, loginNameQuery, resourceOwnerQuery) + instanceIDQuery := &model.UserSearchQuery{ + Key: usr_model.UserSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } + query := repository.PrepareGetByQuery(table, loginNameQuery, resourceOwnerQuery, instanceIDQuery) err := query(db, user) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-AD4qs", "Errors.User.NotFoundOnOrg") @@ -70,21 +100,26 @@ func UserByLoginNameAndResourceOwner(db *gorm.DB, table, loginName, resourceOwne return user, err } -func UsersByOrgID(db *gorm.DB, table, orgID string) ([]*model.UserView, error) { +func UsersByOrgID(db *gorm.DB, table, orgID, instanceID string) ([]*model.UserView, error) { users := make([]*model.UserView, 0) orgIDQuery := &usr_model.UserSearchQuery{ Key: usr_model.UserSearchKeyResourceOwner, Method: domain.SearchMethodEquals, Value: orgID, } + instanceIDQuery := &usr_model.UserSearchQuery{ + Key: usr_model.UserSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } query := repository.PrepareSearchQuery(table, model.UserSearchRequest{ - Queries: []*usr_model.UserSearchQuery{orgIDQuery}, + Queries: []*usr_model.UserSearchQuery{orgIDQuery, instanceIDQuery}, }) _, err := query(db, &users) return users, err } -func UserIDsByDomain(db *gorm.DB, table, orgDomain string) ([]string, error) { +func UserIDsByDomain(db *gorm.DB, table, orgDomain, instanceID string) ([]string, error) { type id struct { Id string } @@ -94,8 +129,13 @@ func UserIDsByDomain(db *gorm.DB, table, orgDomain string) ([]string, error) { Method: domain.SearchMethodEndsWithIgnoreCase, Value: "%" + orgDomain, } + instanceIDQuery := &usr_model.UserSearchQuery{ + Key: usr_model.UserSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } query := repository.PrepareSearchQuery(table, model.UserSearchRequest{ - Queries: []*usr_model.UserSearchQuery{orgIDQuery}, + Queries: []*usr_model.UserSearchQuery{orgIDQuery, instanceIDQuery}, }) _, err := query(db, &ids) if err != nil { @@ -118,9 +158,12 @@ func SearchUsers(db *gorm.DB, table string, req *usr_model.UserSearchRequest) ([ return users, count, nil } -func GetGlobalUserByLoginName(db *gorm.DB, table, loginName string) (*model.UserView, error) { +func GetGlobalUserByLoginName(db *gorm.DB, table, loginName, instanceID string) (*model.UserView, error) { user := new(model.UserView) - query := repository.PrepareGetByQuery(table, &model.UserSearchQuery{Key: usr_model.UserSearchKeyLoginNames, Value: loginName, Method: domain.SearchMethodListContains}) + query := repository.PrepareGetByQuery(table, + &model.UserSearchQuery{Key: usr_model.UserSearchKeyLoginNames, Value: loginName, Method: domain.SearchMethodListContains}, + &model.UserSearchQuery{Key: usr_model.UserSearchKeyInstanceID, Value: instanceID, Method: domain.SearchMethodEquals}, + ) err := query(db, user) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-8uWer", "Errors.User.NotFound") @@ -129,8 +172,8 @@ func GetGlobalUserByLoginName(db *gorm.DB, table, loginName string) (*model.User return user, err } -func UserMFAs(db *gorm.DB, table, userID string) ([]*usr_model.MultiFactor, error) { - user, err := UserByID(db, table, userID) +func UserMFAs(db *gorm.DB, table, userID, instanceID string) ([]*usr_model.MultiFactor, error) { + user, err := UserByID(db, table, userID, instanceID) if err != nil { return nil, err } @@ -154,7 +197,10 @@ func PutUser(db *gorm.DB, table string, user *model.UserView) error { return save(db, user) } -func DeleteUser(db *gorm.DB, table, userID string) error { - delete := repository.PrepareDeleteByKey(table, model.UserSearchKey(usr_model.UserSearchKeyUserID), userID) +func DeleteUser(db *gorm.DB, table, userID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.UserSearchKey(usr_model.UserSearchKeyUserID), userID}, + repository.Key{model.UserSearchKey(usr_model.UserSearchKeyInstanceID), instanceID}, + ) return delete(db) } diff --git a/internal/user/repository/view/usermembership_view.go b/internal/user/repository/view/usermembership_view.go index 47e3e95d40..a6c69b5201 100644 --- a/internal/user/repository/view/usermembership_view.go +++ b/internal/user/repository/view/usermembership_view.go @@ -1,23 +1,24 @@ package view import ( - "github.com/caos/zitadel/internal/domain" - "github.com/caos/zitadel/internal/view/repository" "github.com/jinzhu/gorm" + "github.com/caos/zitadel/internal/domain" caos_errs "github.com/caos/zitadel/internal/errors" usr_model "github.com/caos/zitadel/internal/user/model" "github.com/caos/zitadel/internal/user/repository/view/model" + "github.com/caos/zitadel/internal/view/repository" ) -func UserMembershipByIDs(db *gorm.DB, table, userID, aggregateID, objectID string, membertype usr_model.MemberType) (*model.UserMembershipView, error) { +func UserMembershipByIDs(db *gorm.DB, table, userID, aggregateID, objectID, instanceID string, membertype usr_model.MemberType) (*model.UserMembershipView, error) { memberships := new(model.UserMembershipView) userIDQuery := &model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyUserID, Value: userID, Method: domain.SearchMethodEquals} aggregateIDQuery := &model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyAggregateID, Value: aggregateID, Method: domain.SearchMethodEquals} objectIDQuery := &model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyObjectID, Value: objectID, Method: domain.SearchMethodEquals} memberTypeQuery := &model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyMemberType, Value: int32(membertype), Method: domain.SearchMethodEquals} + instanceIDQuery := &model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyInstanceID, Value: instanceID, Method: domain.SearchMethodEquals} - query := repository.PrepareGetByQuery(table, userIDQuery, aggregateIDQuery, objectIDQuery, memberTypeQuery) + query := repository.PrepareGetByQuery(table, userIDQuery, aggregateIDQuery, objectIDQuery, memberTypeQuery, instanceIDQuery) err := query(db, memberships) if caos_errs.IsNotFound(err) { return nil, caos_errs.ThrowNotFound(nil, "VIEW-5Tsji", "Errors.UserMembership.NotFound") @@ -25,21 +26,23 @@ func UserMembershipByIDs(db *gorm.DB, table, userID, aggregateID, objectID strin return memberships, err } -func UserMembershipsByAggregateID(db *gorm.DB, table, aggregateID string) ([]*model.UserMembershipView, error) { +func UserMembershipsByAggregateID(db *gorm.DB, table, aggregateID, instanceID string) ([]*model.UserMembershipView, error) { memberships := make([]*model.UserMembershipView, 0) aggregateIDQuery := &usr_model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyAggregateID, Value: aggregateID, Method: domain.SearchMethodEquals} + instanceIDQuery := &usr_model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyInstanceID, Value: instanceID, Method: domain.SearchMethodEquals} query := repository.PrepareSearchQuery(table, model.UserMembershipSearchRequest{ - Queries: []*usr_model.UserMembershipSearchQuery{aggregateIDQuery}, + Queries: []*usr_model.UserMembershipSearchQuery{aggregateIDQuery, instanceIDQuery}, }) _, err := query(db, &memberships) return memberships, err } -func UserMembershipsByResourceOwner(db *gorm.DB, table, resourceOwner string) ([]*model.UserMembershipView, error) { +func UserMembershipsByResourceOwner(db *gorm.DB, table, resourceOwner, instanceID string) ([]*model.UserMembershipView, error) { memberships := make([]*model.UserMembershipView, 0) aggregateIDQuery := &usr_model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyResourceOwner, Value: resourceOwner, Method: domain.SearchMethodEquals} + instanceIDQuery := &usr_model.UserMembershipSearchQuery{Key: usr_model.UserMembershipSearchKeyInstanceID, Value: instanceID, Method: domain.SearchMethodEquals} query := repository.PrepareSearchQuery(table, model.UserMembershipSearchRequest{ - Queries: []*usr_model.UserMembershipSearchQuery{aggregateIDQuery}, + Queries: []*usr_model.UserMembershipSearchQuery{aggregateIDQuery, instanceIDQuery}, }) _, err := query(db, &memberships) return memberships, err @@ -69,30 +72,38 @@ func PutUserMembership(db *gorm.DB, table string, user *model.UserMembershipView return save(db, user) } -func DeleteUserMembership(db *gorm.DB, table, userID, aggregateID, objectID string, membertype usr_model.MemberType) error { +func DeleteUserMembership(db *gorm.DB, table, userID, aggregateID, objectID, instanceID string, membertype usr_model.MemberType) error { delete := repository.PrepareDeleteByKeys(table, repository.Key{Key: model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyUserID), Value: userID}, repository.Key{Key: model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyAggregateID), Value: aggregateID}, repository.Key{Key: model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyObjectID), Value: objectID}, repository.Key{Key: model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyMemberType), Value: membertype}, + repository.Key{Key: model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyInstanceID), Value: instanceID}, ) return delete(db) } -func DeleteUserMembershipsByUserID(db *gorm.DB, table, userID string) error { - delete := repository.PrepareDeleteByKey(table, model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyUserID), userID) +func DeleteUserMembershipsByUserID(db *gorm.DB, table, userID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyUserID), userID}, + repository.Key{model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyInstanceID), instanceID}, + ) return delete(db) } -func DeleteUserMembershipsByAggregateID(db *gorm.DB, table, aggregateID string) error { - delete := repository.PrepareDeleteByKey(table, model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyAggregateID), aggregateID) +func DeleteUserMembershipsByAggregateID(db *gorm.DB, table, aggregateID, instanceID string) error { + delete := repository.PrepareDeleteByKeys(table, + repository.Key{model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyAggregateID), aggregateID}, + repository.Key{model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyInstanceID), instanceID}, + ) return delete(db) } -func DeleteUserMembershipsByAggregateIDAndObjectID(db *gorm.DB, table, aggregateID, objectID string) error { +func DeleteUserMembershipsByAggregateIDAndObjectID(db *gorm.DB, table, aggregateID, objectID, instanceID string) error { delete := repository.PrepareDeleteByKeys(table, repository.Key{Key: model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyAggregateID), Value: aggregateID}, repository.Key{Key: model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyObjectID), Value: objectID}, + repository.Key{Key: model.UserMembershipSearchKey(usr_model.UserMembershipSearchKeyInstanceID), Value: instanceID}, ) return delete(db) } diff --git a/internal/view/repository/failed_events.go b/internal/view/repository/failed_events.go index 9c4ecbeeec..bad4abc549 100644 --- a/internal/view/repository/failed_events.go +++ b/internal/view/repository/failed_events.go @@ -1,17 +1,13 @@ package repository import ( - "github.com/caos/zitadel/internal/domain" "strings" + "github.com/jinzhu/gorm" + + "github.com/caos/zitadel/internal/domain" "github.com/caos/zitadel/internal/errors" view_model "github.com/caos/zitadel/internal/view/model" - "github.com/jinzhu/gorm" -) - -const ( - errViewNameKey = "view_name" - errFailedSeqKey = "failed_sequence" ) type FailedEvent struct { @@ -19,6 +15,7 @@ type FailedEvent struct { FailedSequence uint64 `gorm:"column:failed_sequence;primary_key"` FailureCount uint64 `gorm:"column:failure_count"` ErrMsg string `gorm:"column:err_msg"` + InstanceID string `gorm:"column:instance_id"` } type FailedEventSearchQuery struct { @@ -45,6 +42,7 @@ const ( FailedEventKeyUndefined FailedEventSearchKey = iota FailedEventKeyViewName FailedEventKeyFailedSequence + FailedEventKeyInstanceID ) type failedEventSearchKey FailedEventSearchKey @@ -55,6 +53,8 @@ func (key failedEventSearchKey) ToColumnName() string { return "view_name" case FailedEventKeyFailedSequence: return "failed_sequence" + case FailedEventKeyInstanceID: + return "instance_id" default: return "" } @@ -93,15 +93,17 @@ func RemoveFailedEvent(db *gorm.DB, table string, failedEvent *FailedEvent) erro delete := PrepareDeleteByKeys(table, Key{Key: failedEventSearchKey(FailedEventKeyViewName), Value: failedEvent.ViewName}, Key{Key: failedEventSearchKey(FailedEventKeyFailedSequence), Value: failedEvent.FailedSequence}, + Key{Key: failedEventSearchKey(FailedEventKeyInstanceID), Value: failedEvent.InstanceID}, ) return delete(db) } -func LatestFailedEvent(db *gorm.DB, table, viewName string, sequence uint64) (*FailedEvent, error) { +func LatestFailedEvent(db *gorm.DB, table, viewName, instanceID string, sequence uint64) (*FailedEvent, error) { failedEvent := new(FailedEvent) queries := []SearchQuery{ FailedEventSearchQuery{Key: FailedEventKeyViewName, Method: domain.SearchMethodEqualsIgnoreCase, Value: viewName}, FailedEventSearchQuery{Key: FailedEventKeyFailedSequence, Method: domain.SearchMethodEquals, Value: sequence}, + FailedEventSearchQuery{Key: FailedEventKeyInstanceID, Method: domain.SearchMethodEquals, Value: instanceID}, } query := PrepareGetByQuery(table, queries...) err := query(db, failedEvent) diff --git a/internal/view/repository/sequence.go b/internal/view/repository/sequence.go index 7835fcb973..99102ff49d 100644 --- a/internal/view/repository/sequence.go +++ b/internal/view/repository/sequence.go @@ -16,6 +16,7 @@ type CurrentSequence struct { CurrentSequence uint64 `gorm:"column:current_sequence"` EventTimestamp time.Time `gorm:"column:event_timestamp"` LastSuccessfulSpoolerRun time.Time `gorm:"column:last_successful_spooler_run"` + InstanceID string `gorm:"column:instance_id;primary_key"` } type currentSequenceViewWithSequence struct { @@ -35,6 +36,7 @@ const ( SequenceSearchKeyUndefined SequenceSearchKey = iota SequenceSearchKeyViewName SequenceSearchKeyAggregateType + SequenceSearchKeyInstanceID ) type sequenceSearchKey SequenceSearchKey @@ -45,6 +47,8 @@ func (key sequenceSearchKey) ToColumnName() string { return "view_name" case SequenceSearchKeyAggregateType: return "aggregate_type" + case SequenceSearchKeyInstanceID: + return "instance_id" default: return "" } @@ -67,6 +71,34 @@ func (q *sequenceSearchQuery) GetValue() interface{} { return q.value } +type sequenceSearchRequest struct { + queries []sequenceSearchQuery +} + +func (s *sequenceSearchRequest) GetLimit() uint64 { + return 0 +} + +func (s *sequenceSearchRequest) GetOffset() uint64 { + return 0 +} + +func (s *sequenceSearchRequest) GetSortingColumn() ColumnKey { + return nil +} + +func (s *sequenceSearchRequest) GetAsc() bool { + return false +} + +func (s *sequenceSearchRequest) GetQueries() []SearchQuery { + result := make([]SearchQuery, len(s.queries)) + for i, q := range s.queries { + result[i] = &sequenceSearchQuery{key: q.key, value: q.value} + } + return result +} + func CurrentSequenceToModel(sequence *CurrentSequence) *model.View { dbView := strings.Split(sequence.ViewName, ".") return &model.View{ @@ -78,8 +110,17 @@ func CurrentSequenceToModel(sequence *CurrentSequence) *model.View { } } -func SaveCurrentSequence(db *gorm.DB, table, viewName string, sequence uint64, eventTimestamp time.Time) error { - return UpdateCurrentSequence(db, table, &CurrentSequence{viewName, sequence, eventTimestamp, time.Now()}) +func SaveCurrentSequence(db *gorm.DB, table, viewName, instanceID string, sequence uint64, eventTimestamp time.Time) error { + return UpdateCurrentSequence(db, table, &CurrentSequence{viewName, sequence, eventTimestamp, time.Now(), instanceID}) +} + +func SaveCurrentSequences(db *gorm.DB, table, viewName string, sequence uint64, eventTimestamp time.Time) error { + err := db.Table(table).Where("view_name = ?", viewName). + Updates(map[string]interface{}{"current_sequence": sequence, "event_timestamp": eventTimestamp, "last_successful_spooler_run": time.Now()}).Error + if err != nil { + return caos_errs.ThrowInternal(err, "VIEW-Sfdqs", "unable to updated processed sequence") + } + return nil } func UpdateCurrentSequence(db *gorm.DB, table string, currentSequence *CurrentSequence) (err error) { @@ -91,9 +132,24 @@ func UpdateCurrentSequence(db *gorm.DB, table string, currentSequence *CurrentSe return nil } -func LatestSequence(db *gorm.DB, table, viewName string) (*CurrentSequence, error) { - searchQueries := make([]SearchQuery, 0, 2) - searchQueries = append(searchQueries, &sequenceSearchQuery{key: sequenceSearchKey(SequenceSearchKeyViewName), value: viewName}) +func UpdateCurrentSequences(db *gorm.DB, table string, currentSequences []*CurrentSequence) (err error) { + save := PrepareBulkSave(table) + s := make([]interface{}, len(currentSequences)) + for i, currentSequence := range currentSequences { + s[i] = currentSequence + } + err = save(db, s...) + if err != nil { + return caos_errs.ThrowInternal(err, "VIEW-5kOhP", "unable to updated processed sequence") + } + return nil +} + +func LatestSequence(db *gorm.DB, table, viewName, instanceID string) (*CurrentSequence, error) { + searchQueries := []SearchQuery{ + &sequenceSearchQuery{key: sequenceSearchKey(SequenceSearchKeyViewName), value: viewName}, + &sequenceSearchQuery{key: sequenceSearchKey(SequenceSearchKeyInstanceID), value: instanceID}, + } // ensure highest sequence of view db = db.Order("current_sequence DESC") @@ -112,6 +168,27 @@ func LatestSequence(db *gorm.DB, table, viewName string) (*CurrentSequence, erro return nil, caos_errs.ThrowInternalf(err, "VIEW-9LyCB", "unable to get latest sequence of %s", viewName) } +func LatestSequences(db *gorm.DB, table, viewName string) ([]*CurrentSequence, error) { + searchQueries := make([]SearchQuery, 0, 2) + searchQueries = append(searchQueries) + searchRequest := &sequenceSearchRequest{ + queries: []sequenceSearchQuery{ + {key: sequenceSearchKey(SequenceSearchKeyViewName), value: viewName}, + }, + } + + // ensure highest sequence of view + db = db.Order("current_sequence DESC") + + sequences := make([]*CurrentSequence, 0) + query := PrepareSearchQuery(table, searchRequest) + _, err := query(db, &sequences) + if err != nil { + return nil, err + } + return sequences, nil +} + func AllCurrentSequences(db *gorm.DB, table string) ([]*CurrentSequence, error) { sequences := make([]*CurrentSequence, 0) query := PrepareSearchQuery(table, GeneralSearchRequest{}) @@ -128,5 +205,5 @@ func ClearView(db *gorm.DB, truncateView, sequenceTable string) error { if err != nil { return err } - return SaveCurrentSequence(db, sequenceTable, truncateView, 0, time.Now()) + return SaveCurrentSequences(db, sequenceTable, truncateView, 0, time.Now()) }