diff --git a/cmd/setup/57.go b/cmd/setup/57.go new file mode 100644 index 0000000000..4c52018f1e --- /dev/null +++ b/cmd/setup/57.go @@ -0,0 +1,27 @@ +package setup + +import ( + "context" + _ "embed" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" +) + +var ( + //go:embed 57.sql + createResourceCounts string +) + +type CreateResourceCounts struct { + dbClient *database.DB +} + +func (mig *CreateResourceCounts) Execute(ctx context.Context, _ eventstore.Event) error { + _, err := mig.dbClient.ExecContext(ctx, createResourceCounts) + return err +} + +func (mig *CreateResourceCounts) String() string { + return "57_create_resource_counts" +} diff --git a/cmd/setup/57.sql b/cmd/setup/57.sql new file mode 100644 index 0000000000..f2f0a40202 --- /dev/null +++ b/cmd/setup/57.sql @@ -0,0 +1,106 @@ +CREATE TABLE IF NOT EXISTS projections.resource_counts +( + id SERIAL PRIMARY KEY, -- allows for easy pagination + instance_id TEXT NOT NULL, + table_name TEXT NOT NULL, -- needed for trigger matching, not in reports + parent_type TEXT NOT NULL, + parent_id TEXT NOT NULL, + resource_name TEXT NOT NULL, -- friendly name for reporting + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + amount INTEGER NOT NULL DEFAULT 1 CHECK (amount >= 0), + + UNIQUE (instance_id, parent_type, parent_id, table_name) +); + +-- count_resource is a trigger function which increases or decreases the count of a resource. +-- When creating the trigger the following required arguments (TG_ARGV) can be passed: +-- 1. The type of the parent +-- 2. The column name of the instance id +-- 3. The column name of the owner id +-- 4. The name of the resource +CREATE OR REPLACE FUNCTION projections.count_resource() + RETURNS trigger + LANGUAGE 'plpgsql' VOLATILE +AS $$ +DECLARE + -- trigger variables + tg_table_name TEXT := TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME; + tg_parent_type TEXT := TG_ARGV[0]; + tg_instance_id_column TEXT := TG_ARGV[1]; + tg_parent_id_column TEXT := TG_ARGV[2]; + tg_resource_name TEXT := TG_ARGV[3]; + + tg_instance_id TEXT; + tg_parent_id TEXT; + + select_ids TEXT := format('SELECT ($1).%I, ($1).%I', tg_instance_id_column, tg_parent_id_column); +BEGIN + IF (TG_OP = 'INSERT') THEN + EXECUTE select_ids INTO tg_instance_id, tg_parent_id USING NEW; + + INSERT INTO projections.resource_counts(instance_id, table_name, parent_type, parent_id, resource_name) + VALUES (tg_instance_id, tg_table_name, tg_parent_type, tg_parent_id, tg_resource_name) + ON CONFLICT (instance_id, table_name, parent_type, parent_id) DO + UPDATE SET updated_at = now(), amount = projections.resource_counts.amount + 1; + + RETURN NEW; + ELSEIF (TG_OP = 'DELETE') THEN + EXECUTE select_ids INTO tg_instance_id, tg_parent_id USING OLD; + + UPDATE projections.resource_counts + SET updated_at = now(), amount = amount - 1 + WHERE instance_id = tg_instance_id + AND table_name = tg_table_name + AND parent_type = tg_parent_type + AND parent_id = tg_parent_id + AND resource_name = tg_resource_name + AND amount > 0; -- prevent check failure on negative amount. + + RETURN OLD; + END IF; +END +$$; + +-- delete_table_counts removes all resource counts for a TRUNCATED table. +CREATE OR REPLACE FUNCTION projections.delete_table_counts() + RETURNS trigger + LANGUAGE 'plpgsql' +AS $$ +DECLARE + -- trigger variables + tg_table_name TEXT := TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME; +BEGIN + DELETE FROM projections.resource_counts + WHERE table_name = tg_table_name; +END +$$; + +-- delete_parent_counts removes all resource counts for a deleted parent. +-- 1. The type of the parent +-- 2. The column name of the instance id +-- 3. The column name of the owner id +CREATE OR REPLACE FUNCTION projections.delete_parent_counts() + RETURNS trigger + LANGUAGE 'plpgsql' +AS $$ +DECLARE + -- trigger variables + tg_parent_type TEXT := TG_ARGV[0]; + tg_instance_id_column TEXT := TG_ARGV[1]; + tg_parent_id_column TEXT := TG_ARGV[2]; + + tg_instance_id TEXT; + tg_parent_id TEXT; + + select_ids TEXT := format('SELECT ($1).%I, ($1).%I', tg_instance_id_column, tg_parent_id_column); +BEGIN + EXECUTE select_ids INTO tg_instance_id, tg_parent_id USING OLD; + + DELETE FROM projections.resource_counts + WHERE instance_id = tg_instance_id + AND parent_type = tg_parent_type + AND parent_id = tg_parent_id; + + RETURN OLD; +END +$$; diff --git a/cmd/setup/config.go b/cmd/setup/config.go index bd2abde9ea..dd59ba3f07 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -153,6 +153,7 @@ type Steps struct { s54InstancePositionIndex *InstancePositionIndex s55ExecutionHandlerStart *ExecutionHandlerStart s56IDPTemplate6SAMLFederatedLogout *IDPTemplate6SAMLFederatedLogout + s57CreateResourceCounts *CreateResourceCounts } func MustNewSteps(v *viper.Viper) *Steps { diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index c84976f282..1465180a6b 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -215,6 +215,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s54InstancePositionIndex = &InstancePositionIndex{dbClient: dbClient} steps.s55ExecutionHandlerStart = &ExecutionHandlerStart{dbClient: dbClient} steps.s56IDPTemplate6SAMLFederatedLogout = &IDPTemplate6SAMLFederatedLogout{dbClient: dbClient} + steps.s57CreateResourceCounts = &CreateResourceCounts{dbClient: dbClient} err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil) logging.OnError(err).Fatal("unable to start projections") @@ -260,6 +261,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s54InstancePositionIndex, steps.s55ExecutionHandlerStart, steps.s56IDPTemplate6SAMLFederatedLogout, + steps.s57CreateResourceCounts, } { setupErr = executeMigration(ctx, eventstoreClient, step, "migration failed") if setupErr != nil { @@ -296,6 +298,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) client: dbClient, }, } + repeatableSteps = append(repeatableSteps, triggerSteps(dbClient)...) for _, repeatableStep := range repeatableSteps { setupErr = executeMigration(ctx, eventstoreClient, repeatableStep, "unable to migrate repeatable step") diff --git a/cmd/setup/trigger_steps.go b/cmd/setup/trigger_steps.go new file mode 100644 index 0000000000..163a8fdb59 --- /dev/null +++ b/cmd/setup/trigger_steps.go @@ -0,0 +1,125 @@ +package setup + +import ( + "fmt" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/migration" + "github.com/zitadel/zitadel/internal/query/projection" +) + +// triggerSteps defines the repeatable migrations that set up triggers +// for counting resources in the database. +func triggerSteps(db *database.DB) []migration.RepeatableMigration { + return []migration.RepeatableMigration{ + // Delete parent count triggers for instances and organizations + migration.DeleteParentCountsTrigger(db, + projection.InstanceProjectionTable, + domain.CountParentTypeInstance, + projection.InstanceColumnID, + projection.InstanceColumnID, + "instance", + ), + migration.DeleteParentCountsTrigger(db, + projection.OrgProjectionTable, + domain.CountParentTypeOrganization, + projection.OrgColumnInstanceID, + projection.OrgColumnID, + "organization", + ), + + // Count triggers for all the resources + migration.CountTrigger(db, + projection.OrgProjectionTable, + domain.CountParentTypeInstance, + projection.OrgColumnInstanceID, + projection.OrgColumnInstanceID, + "organization", + ), + migration.CountTrigger(db, + projection.ProjectProjectionTable, + domain.CountParentTypeOrganization, + projection.ProjectColumnInstanceID, + projection.ProjectColumnResourceOwner, + "project", + ), + migration.CountTrigger(db, + projection.UserTable, + domain.CountParentTypeOrganization, + projection.UserInstanceIDCol, + projection.UserResourceOwnerCol, + "user", + ), + migration.CountTrigger(db, + projection.InstanceMemberProjectionTable, + domain.CountParentTypeInstance, + projection.MemberInstanceID, + projection.MemberResourceOwner, + "iam_admin", + ), + migration.CountTrigger(db, + projection.IDPTable, + domain.CountParentTypeInstance, + projection.IDPInstanceIDCol, + projection.IDPInstanceIDCol, + "identity_provider", + ), + migration.CountTrigger(db, + projection.IDPTemplateLDAPTable, + domain.CountParentTypeInstance, + projection.LDAPInstanceIDCol, + projection.LDAPInstanceIDCol, + "identity_provider_ldap", + ), + migration.CountTrigger(db, + projection.ActionTable, + domain.CountParentTypeInstance, + projection.ActionInstanceIDCol, + projection.ActionInstanceIDCol, + "action_v1", + ), + migration.CountTrigger(db, + projection.ExecutionTable, + domain.CountParentTypeInstance, + projection.ExecutionInstanceIDCol, + projection.ExecutionInstanceIDCol, + "execution", + ), + migration.CountTrigger(db, + fmt.Sprintf("%s_%s", projection.ExecutionTable, projection.ExecutionTargetSuffix), + domain.CountParentTypeInstance, + projection.ExecutionTargetInstanceIDCol, + projection.ExecutionTargetInstanceIDCol, + "execution_target", + ), + migration.CountTrigger(db, + projection.LoginPolicyTable, + domain.CountParentTypeInstance, + projection.LoginPolicyInstanceIDCol, + projection.LoginPolicyInstanceIDCol, + "login_policy", + ), + migration.CountTrigger(db, + projection.PasswordComplexityTable, + domain.CountParentTypeInstance, + projection.ComplexityPolicyInstanceIDCol, + projection.ComplexityPolicyInstanceIDCol, + "password_complexity_policy", + ), + migration.CountTrigger(db, + projection.PasswordAgeTable, + domain.CountParentTypeInstance, + projection.AgePolicyInstanceIDCol, + projection.AgePolicyInstanceIDCol, + "password_expiry_policy", + ), + migration.CountTrigger(db, + projection.LockoutPolicyTable, + domain.CountParentTypeInstance, + projection.LockoutPolicyInstanceIDCol, + projection.LockoutPolicyInstanceIDCol, + "lockout_policy", + ), + } +} diff --git a/internal/domain/count_trigger.go b/internal/domain/count_trigger.go new file mode 100644 index 0000000000..a29d125fe9 --- /dev/null +++ b/internal/domain/count_trigger.go @@ -0,0 +1,9 @@ +package domain + +//go:generate enumer -type CountParentType -transform lower -trimprefix CountParentType -sql +type CountParentType int + +const ( + CountParentTypeInstance CountParentType = iota + CountParentTypeOrganization +) diff --git a/internal/domain/countparenttype_enumer.go b/internal/domain/countparenttype_enumer.go new file mode 100644 index 0000000000..8691d97e62 --- /dev/null +++ b/internal/domain/countparenttype_enumer.go @@ -0,0 +1,109 @@ +// Code generated by "enumer -type CountParentType -transform lower -trimprefix CountParentType -sql"; DO NOT EDIT. + +package domain + +import ( + "database/sql/driver" + "fmt" + "strings" +) + +const _CountParentTypeName = "instanceorganization" + +var _CountParentTypeIndex = [...]uint8{0, 8, 20} + +const _CountParentTypeLowerName = "instanceorganization" + +func (i CountParentType) String() string { + if i < 0 || i >= CountParentType(len(_CountParentTypeIndex)-1) { + return fmt.Sprintf("CountParentType(%d)", i) + } + return _CountParentTypeName[_CountParentTypeIndex[i]:_CountParentTypeIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _CountParentTypeNoOp() { + var x [1]struct{} + _ = x[CountParentTypeInstance-(0)] + _ = x[CountParentTypeOrganization-(1)] +} + +var _CountParentTypeValues = []CountParentType{CountParentTypeInstance, CountParentTypeOrganization} + +var _CountParentTypeNameToValueMap = map[string]CountParentType{ + _CountParentTypeName[0:8]: CountParentTypeInstance, + _CountParentTypeLowerName[0:8]: CountParentTypeInstance, + _CountParentTypeName[8:20]: CountParentTypeOrganization, + _CountParentTypeLowerName[8:20]: CountParentTypeOrganization, +} + +var _CountParentTypeNames = []string{ + _CountParentTypeName[0:8], + _CountParentTypeName[8:20], +} + +// CountParentTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func CountParentTypeString(s string) (CountParentType, error) { + if val, ok := _CountParentTypeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _CountParentTypeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to CountParentType values", s) +} + +// CountParentTypeValues returns all values of the enum +func CountParentTypeValues() []CountParentType { + return _CountParentTypeValues +} + +// CountParentTypeStrings returns a slice of all String values of the enum +func CountParentTypeStrings() []string { + strs := make([]string, len(_CountParentTypeNames)) + copy(strs, _CountParentTypeNames) + return strs +} + +// IsACountParentType returns "true" if the value is listed in the enum definition. "false" otherwise +func (i CountParentType) IsACountParentType() bool { + for _, v := range _CountParentTypeValues { + if i == v { + return true + } + } + return false +} + +func (i CountParentType) Value() (driver.Value, error) { + return i.String(), nil +} + +func (i *CountParentType) Scan(value interface{}) error { + if value == nil { + return nil + } + + var str string + switch v := value.(type) { + case []byte: + str = string(v) + case string: + str = v + case fmt.Stringer: + str = v.String() + default: + return fmt.Errorf("invalid value of CountParentType: %[1]T(%[1]v)", value) + } + + val, err := CountParentTypeString(str) + if err != nil { + return err + } + + *i = val + return nil +} diff --git a/internal/domain/secretgeneratortype_enumer.go b/internal/domain/secretgeneratortype_enumer.go index f819bafc1f..db66715670 100644 --- a/internal/domain/secretgeneratortype_enumer.go +++ b/internal/domain/secretgeneratortype_enumer.go @@ -4,11 +4,14 @@ package domain import ( "fmt" + "strings" ) -const _SecretGeneratorTypeName = "unspecifiedinit_codeverify_email_codeverify_phone_codeverify_domainpassword_reset_codepasswordless_init_codeapp_secretotpsmsotp_emailinvite_codesecret_generator_type_count" +const _SecretGeneratorTypeName = "unspecifiedinit_codeverify_email_codeverify_phone_codeverify_domainpassword_reset_codepasswordless_init_codeapp_secretotpsmsotp_emailinvite_codesigning_keysecret_generator_type_count" -var _SecretGeneratorTypeIndex = [...]uint8{0, 11, 20, 37, 54, 67, 86, 108, 118, 124, 133, 144, 171} +var _SecretGeneratorTypeIndex = [...]uint8{0, 11, 20, 37, 54, 67, 86, 108, 118, 124, 133, 144, 155, 182} + +const _SecretGeneratorTypeLowerName = "unspecifiedinit_codeverify_email_codeverify_phone_codeverify_domainpassword_reset_codepasswordless_init_codeapp_secretotpsmsotp_emailinvite_codesigning_keysecret_generator_type_count" func (i SecretGeneratorType) String() string { if i < 0 || i >= SecretGeneratorType(len(_SecretGeneratorTypeIndex)-1) { @@ -17,21 +20,70 @@ func (i SecretGeneratorType) String() string { return _SecretGeneratorTypeName[_SecretGeneratorTypeIndex[i]:_SecretGeneratorTypeIndex[i+1]] } -var _SecretGeneratorTypeValues = []SecretGeneratorType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _SecretGeneratorTypeNoOp() { + var x [1]struct{} + _ = x[SecretGeneratorTypeUnspecified-(0)] + _ = x[SecretGeneratorTypeInitCode-(1)] + _ = x[SecretGeneratorTypeVerifyEmailCode-(2)] + _ = x[SecretGeneratorTypeVerifyPhoneCode-(3)] + _ = x[SecretGeneratorTypeVerifyDomain-(4)] + _ = x[SecretGeneratorTypePasswordResetCode-(5)] + _ = x[SecretGeneratorTypePasswordlessInitCode-(6)] + _ = x[SecretGeneratorTypeAppSecret-(7)] + _ = x[SecretGeneratorTypeOTPSMS-(8)] + _ = x[SecretGeneratorTypeOTPEmail-(9)] + _ = x[SecretGeneratorTypeInviteCode-(10)] + _ = x[SecretGeneratorTypeSigningKey-(11)] + _ = x[secretGeneratorTypeCount-(12)] +} + +var _SecretGeneratorTypeValues = []SecretGeneratorType{SecretGeneratorTypeUnspecified, SecretGeneratorTypeInitCode, SecretGeneratorTypeVerifyEmailCode, SecretGeneratorTypeVerifyPhoneCode, SecretGeneratorTypeVerifyDomain, SecretGeneratorTypePasswordResetCode, SecretGeneratorTypePasswordlessInitCode, SecretGeneratorTypeAppSecret, SecretGeneratorTypeOTPSMS, SecretGeneratorTypeOTPEmail, SecretGeneratorTypeInviteCode, SecretGeneratorTypeSigningKey, secretGeneratorTypeCount} var _SecretGeneratorTypeNameToValueMap = map[string]SecretGeneratorType{ - _SecretGeneratorTypeName[0:11]: 0, - _SecretGeneratorTypeName[11:20]: 1, - _SecretGeneratorTypeName[20:37]: 2, - _SecretGeneratorTypeName[37:54]: 3, - _SecretGeneratorTypeName[54:67]: 4, - _SecretGeneratorTypeName[67:86]: 5, - _SecretGeneratorTypeName[86:108]: 6, - _SecretGeneratorTypeName[108:118]: 7, - _SecretGeneratorTypeName[118:124]: 8, - _SecretGeneratorTypeName[124:133]: 9, - _SecretGeneratorTypeName[133:144]: 10, - _SecretGeneratorTypeName[144:171]: 11, + _SecretGeneratorTypeName[0:11]: SecretGeneratorTypeUnspecified, + _SecretGeneratorTypeLowerName[0:11]: SecretGeneratorTypeUnspecified, + _SecretGeneratorTypeName[11:20]: SecretGeneratorTypeInitCode, + _SecretGeneratorTypeLowerName[11:20]: SecretGeneratorTypeInitCode, + _SecretGeneratorTypeName[20:37]: SecretGeneratorTypeVerifyEmailCode, + _SecretGeneratorTypeLowerName[20:37]: SecretGeneratorTypeVerifyEmailCode, + _SecretGeneratorTypeName[37:54]: SecretGeneratorTypeVerifyPhoneCode, + _SecretGeneratorTypeLowerName[37:54]: SecretGeneratorTypeVerifyPhoneCode, + _SecretGeneratorTypeName[54:67]: SecretGeneratorTypeVerifyDomain, + _SecretGeneratorTypeLowerName[54:67]: SecretGeneratorTypeVerifyDomain, + _SecretGeneratorTypeName[67:86]: SecretGeneratorTypePasswordResetCode, + _SecretGeneratorTypeLowerName[67:86]: SecretGeneratorTypePasswordResetCode, + _SecretGeneratorTypeName[86:108]: SecretGeneratorTypePasswordlessInitCode, + _SecretGeneratorTypeLowerName[86:108]: SecretGeneratorTypePasswordlessInitCode, + _SecretGeneratorTypeName[108:118]: SecretGeneratorTypeAppSecret, + _SecretGeneratorTypeLowerName[108:118]: SecretGeneratorTypeAppSecret, + _SecretGeneratorTypeName[118:124]: SecretGeneratorTypeOTPSMS, + _SecretGeneratorTypeLowerName[118:124]: SecretGeneratorTypeOTPSMS, + _SecretGeneratorTypeName[124:133]: SecretGeneratorTypeOTPEmail, + _SecretGeneratorTypeLowerName[124:133]: SecretGeneratorTypeOTPEmail, + _SecretGeneratorTypeName[133:144]: SecretGeneratorTypeInviteCode, + _SecretGeneratorTypeLowerName[133:144]: SecretGeneratorTypeInviteCode, + _SecretGeneratorTypeName[144:155]: SecretGeneratorTypeSigningKey, + _SecretGeneratorTypeLowerName[144:155]: SecretGeneratorTypeSigningKey, + _SecretGeneratorTypeName[155:182]: secretGeneratorTypeCount, + _SecretGeneratorTypeLowerName[155:182]: secretGeneratorTypeCount, +} + +var _SecretGeneratorTypeNames = []string{ + _SecretGeneratorTypeName[0:11], + _SecretGeneratorTypeName[11:20], + _SecretGeneratorTypeName[20:37], + _SecretGeneratorTypeName[37:54], + _SecretGeneratorTypeName[54:67], + _SecretGeneratorTypeName[67:86], + _SecretGeneratorTypeName[86:108], + _SecretGeneratorTypeName[108:118], + _SecretGeneratorTypeName[118:124], + _SecretGeneratorTypeName[124:133], + _SecretGeneratorTypeName[133:144], + _SecretGeneratorTypeName[144:155], + _SecretGeneratorTypeName[155:182], } // SecretGeneratorTypeString retrieves an enum value from the enum constants string name. @@ -40,6 +92,10 @@ func SecretGeneratorTypeString(s string) (SecretGeneratorType, error) { if val, ok := _SecretGeneratorTypeNameToValueMap[s]; ok { return val, nil } + + if val, ok := _SecretGeneratorTypeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } return 0, fmt.Errorf("%s does not belong to SecretGeneratorType values", s) } @@ -48,6 +104,13 @@ func SecretGeneratorTypeValues() []SecretGeneratorType { return _SecretGeneratorTypeValues } +// SecretGeneratorTypeStrings returns a slice of all String values of the enum +func SecretGeneratorTypeStrings() []string { + strs := make([]string, len(_SecretGeneratorTypeNames)) + copy(strs, _SecretGeneratorTypeNames) + return strs +} + // IsASecretGeneratorType returns "true" if the value is listed in the enum definition. "false" otherwise func (i SecretGeneratorType) IsASecretGeneratorType() bool { for _, v := range _SecretGeneratorTypeValues { diff --git a/internal/migration/count_trigger.sql b/internal/migration/count_trigger.sql new file mode 100644 index 0000000000..4b521094ab --- /dev/null +++ b/internal/migration/count_trigger.sql @@ -0,0 +1,43 @@ +{{ define "count_trigger" -}} +CREATE OR REPLACE TRIGGER count_{{ .Resource }} + AFTER INSERT OR DELETE + ON {{ .Table }} + FOR EACH ROW + EXECUTE FUNCTION projections.count_resource( + '{{ .ParentType }}', + '{{ .InstanceIDColumn }}', + '{{ .ParentIDColumn }}', + '{{ .Resource }}' + ); + +CREATE OR REPLACE TRIGGER truncate_{{ .Resource }}_counts + AFTER TRUNCATE + ON {{ .Table }} + FOR EACH STATEMENT + EXECUTE FUNCTION projections.delete_table_counts(); + +-- Prevent inserts and deletes while we populate the counts. +LOCK TABLE {{ .Table }} IN SHARE MODE; + +-- Populate the resource counts for the existing data in the table. +INSERT INTO projections.resource_counts( + instance_id, + table_name, + parent_type, + parent_id, + resource_name, + amount +) +SELECT + {{ .InstanceIDColumn }}, + '{{ .Table }}', + '{{ .ParentType }}', + {{ .ParentIDColumn }}, + '{{ .Resource }}', + COUNT(*) AS amount +FROM {{ .Table }} +GROUP BY ({{ .InstanceIDColumn }}, {{ .ParentIDColumn }}) +ON CONFLICT (instance_id, table_name, parent_type, parent_id) DO +UPDATE SET updated_at = now(), amount = EXCLUDED.amount; + +{{- end -}} diff --git a/internal/migration/delete_parent_counts_trigger.sql b/internal/migration/delete_parent_counts_trigger.sql new file mode 100644 index 0000000000..a2e9df6626 --- /dev/null +++ b/internal/migration/delete_parent_counts_trigger.sql @@ -0,0 +1,13 @@ +{{ define "delete_parent_counts_trigger" -}} + +CREATE OR REPLACE TRIGGER delete_parent_counts_trigger + AFTER DELETE + ON {{ .Table }} + FOR EACH ROW + EXECUTE FUNCTION projections.delete_parent_counts( + '{{ .ParentType }}', + '{{ .InstanceIDColumn }}', + '{{ .ParentIDColumn }}' + ); + +{{- end -}} diff --git a/internal/migration/migration.go b/internal/migration/migration.go index a2224340a7..3aeb2f0612 100644 --- a/internal/migration/migration.go +++ b/internal/migration/migration.go @@ -36,7 +36,10 @@ type errCheckerMigration interface { type RepeatableMigration interface { Migration - Check(lastRun map[string]interface{}) bool + + // Check if the migration should be executed again. + // True will repeat the migration, false will not. + Check(lastRun map[string]any) bool } func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration) (err error) { diff --git a/internal/migration/trigger.go b/internal/migration/trigger.go new file mode 100644 index 0000000000..bd06afd5c5 --- /dev/null +++ b/internal/migration/trigger.go @@ -0,0 +1,127 @@ +package migration + +import ( + "context" + "embed" + "fmt" + "strings" + "text/template" + + "github.com/mitchellh/mapstructure" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + countTriggerTmpl = "count_trigger" + deleteParentCountsTmpl = "delete_parent_counts_trigger" +) + +var ( + //go:embed *.sql + templateFS embed.FS + templates = template.Must(template.ParseFS(templateFS, "*.sql")) +) + +// CountTrigger registers the existing projections.count_trigger function. +// The trigger than takes care of keeping count of existing +// rows in the source table. +// It also pre-populates the projections.resource_counts table with +// the counts for the given table. +// +// During the population of the resource_counts table, +// the source table is share-locked to prevent concurrent modifications. +// Projection handlers will be halted until the lock is released. +// SELECT statements are not blocked by the lock. +// +// This migration repeats when any of the arguments are changed, +// such as renaming of a projection table. +func CountTrigger( + db *database.DB, + table string, + parentType domain.CountParentType, + instanceIDColumn string, + parentIDColumn string, + resource string, +) RepeatableMigration { + return &triggerMigration{ + triggerConfig: triggerConfig{ + Table: table, + ParentType: parentType.String(), + InstanceIDColumn: instanceIDColumn, + ParentIDColumn: parentIDColumn, + Resource: resource, + }, + db: db, + templateName: countTriggerTmpl, + } +} + +// DeleteParentCountsTrigger +// +// This migration repeats when any of the arguments are changed, +// such as renaming of a projection table. +func DeleteParentCountsTrigger( + db *database.DB, + table string, + parentType domain.CountParentType, + instanceIDColumn string, + parentIDColumn string, + resource string, +) RepeatableMigration { + return &triggerMigration{ + triggerConfig: triggerConfig{ + Table: table, + ParentType: parentType.String(), + InstanceIDColumn: instanceIDColumn, + ParentIDColumn: parentIDColumn, + Resource: resource, + }, + db: db, + templateName: deleteParentCountsTmpl, + } +} + +type triggerMigration struct { + triggerConfig + db *database.DB + templateName string +} + +// String implements [Migration] and [fmt.Stringer]. +func (m *triggerMigration) String() string { + return fmt.Sprintf("repeatable_%s_%s", m.Resource, m.templateName) +} + +// Execute implements [Migration] +func (m *triggerMigration) Execute(ctx context.Context, _ eventstore.Event) error { + var query strings.Builder + err := templates.ExecuteTemplate(&query, m.templateName, m.triggerConfig) + if err != nil { + return fmt.Errorf("%s: execute trigger template: %w", m, err) + } + _, err = m.db.ExecContext(ctx, query.String()) + if err != nil { + return fmt.Errorf("%s: exec trigger query: %w", m, err) + } + return nil +} + +type triggerConfig struct { + Table string `json:"table,omitempty" mapstructure:"table"` + ParentType string `json:"parent_type,omitempty" mapstructure:"parent_type"` + InstanceIDColumn string `json:"instance_id_column,omitempty" mapstructure:"instance_id_column"` + ParentIDColumn string `json:"parent_id_column,omitempty" mapstructure:"parent_id_column"` + Resource string `json:"resource,omitempty" mapstructure:"resource"` +} + +// Check implements [RepeatableMigration]. +func (c *triggerConfig) Check(lastRun map[string]any) bool { + var dst triggerConfig + if err := mapstructure.Decode(lastRun, &dst); err != nil { + panic(err) + } + return dst != *c +} diff --git a/internal/migration/trigger_test.go b/internal/migration/trigger_test.go new file mode 100644 index 0000000000..5799526428 --- /dev/null +++ b/internal/migration/trigger_test.go @@ -0,0 +1,253 @@ +package migration + +import ( + "context" + "regexp" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/database" +) + +const ( + expCountTriggerQuery = `CREATE OR REPLACE TRIGGER count_resource + AFTER INSERT OR DELETE + ON table + FOR EACH ROW + EXECUTE FUNCTION projections.count_resource( + 'instance', + 'instance_id', + 'parent_id', + 'resource' + ); + +CREATE OR REPLACE TRIGGER truncate_resource_counts + AFTER TRUNCATE + ON table + FOR EACH STATEMENT + EXECUTE FUNCTION projections.delete_table_counts(); + +-- Prevent inserts and deletes while we populate the counts. +LOCK TABLE table IN SHARE MODE; + +-- Populate the resource counts for the existing data in the table. +INSERT INTO projections.resource_counts( + instance_id, + table_name, + parent_type, + parent_id, + resource_name, + amount +) +SELECT + instance_id, + 'table', + 'instance', + parent_id, + 'resource', + COUNT(*) AS amount +FROM table +GROUP BY (instance_id, parent_id) +ON CONFLICT (instance_id, table_name, parent_type, parent_id) DO +UPDATE SET updated_at = now(), amount = EXCLUDED.amount;` + + expDeleteParentCountsQuery = `CREATE OR REPLACE TRIGGER delete_parent_counts_trigger + AFTER DELETE + ON table + FOR EACH ROW + EXECUTE FUNCTION projections.delete_parent_counts( + 'instance', + 'instance_id', + 'parent_id' + );` +) + +func Test_triggerMigration_Execute(t *testing.T) { + type fields struct { + triggerConfig triggerConfig + templateName string + } + tests := []struct { + name string + fields fields + expects func(sqlmock.Sqlmock) + wantErr bool + }{ + { + name: "template error", + fields: fields{ + triggerConfig: triggerConfig{ + Table: "table", + ParentType: "instance", + InstanceIDColumn: "instance_id", + ParentIDColumn: "parent_id", + Resource: "resource", + }, + templateName: "foo", + }, + expects: func(_ sqlmock.Sqlmock) {}, + wantErr: true, + }, + { + name: "db error", + fields: fields{ + triggerConfig: triggerConfig{ + Table: "table", + ParentType: "instance", + InstanceIDColumn: "instance_id", + ParentIDColumn: "parent_id", + Resource: "resource", + }, + templateName: countTriggerTmpl, + }, + expects: func(mock sqlmock.Sqlmock) { + mock.ExpectExec(regexp.QuoteMeta(expCountTriggerQuery)). + WillReturnError(assert.AnError) + }, + wantErr: true, + }, + { + name: "count trigger", + fields: fields{ + triggerConfig: triggerConfig{ + Table: "table", + ParentType: "instance", + InstanceIDColumn: "instance_id", + ParentIDColumn: "parent_id", + Resource: "resource", + }, + templateName: countTriggerTmpl, + }, + expects: func(mock sqlmock.Sqlmock) { + mock.ExpectExec(regexp.QuoteMeta(expCountTriggerQuery)). + WithoutArgs(). + WillReturnResult( + sqlmock.NewResult(1, 1), + ) + }, + }, + { + name: "count trigger", + fields: fields{ + triggerConfig: triggerConfig{ + Table: "table", + ParentType: "instance", + InstanceIDColumn: "instance_id", + ParentIDColumn: "parent_id", + Resource: "resource", + }, + templateName: deleteParentCountsTmpl, + }, + expects: func(mock sqlmock.Sqlmock) { + mock.ExpectExec(regexp.QuoteMeta(expDeleteParentCountsQuery)). + WithoutArgs(). + WillReturnResult( + sqlmock.NewResult(1, 1), + ) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { + err := mock.ExpectationsWereMet() + require.NoError(t, err) + }() + defer db.Close() + tt.expects(mock) + mock.ExpectClose() + + m := &triggerMigration{ + db: &database.DB{ + DB: db, + }, + triggerConfig: tt.fields.triggerConfig, + templateName: tt.fields.templateName, + } + err = m.Execute(context.Background(), nil) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func Test_triggerConfig_Check(t *testing.T) { + type fields struct { + Table string + ParentType string + InstanceIDColumn string + ParentIDColumn string + Resource string + } + type args struct { + lastRun map[string]any + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + { + name: "should", + fields: fields{ + Table: "users2", + ParentType: "instance", + InstanceIDColumn: "instance_id", + ParentIDColumn: "parent_id", + Resource: "user", + }, + args: args{ + lastRun: map[string]any{ + "table": "users1", + "parent_type": "instance", + "instance_id_column": "instance_id", + "parent_id_column": "parent_id", + "resource": "user", + }, + }, + want: true, + }, + { + name: "should not", + fields: fields{ + Table: "users1", + ParentType: "instance", + InstanceIDColumn: "instance_id", + ParentIDColumn: "parent_id", + Resource: "user", + }, + args: args{ + lastRun: map[string]any{ + "table": "users1", + "parent_type": "instance", + "instance_id_column": "instance_id", + "parent_id_column": "parent_id", + "resource": "user", + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &triggerConfig{ + Table: tt.fields.Table, + ParentType: tt.fields.ParentType, + InstanceIDColumn: tt.fields.InstanceIDColumn, + ParentIDColumn: tt.fields.ParentIDColumn, + Resource: tt.fields.Resource, + } + got := c.Check(tt.args.lastRun) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/query/resource_counts.go b/internal/query/resource_counts.go new file mode 100644 index 0000000000..9d486e0b90 --- /dev/null +++ b/internal/query/resource_counts.go @@ -0,0 +1,61 @@ +package query + +import ( + "context" + "database/sql" + _ "embed" + "time" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/zerrors" +) + +var ( + //go:embed resource_counts_list.sql + resourceCountsListQuery string +) + +type ResourceCount struct { + ID int // Primary key, used for pagination + InstanceID string + TableName string + ParentType domain.CountParentType + ParentID string + Resource string + UpdatedAt time.Time + Amount int +} + +// ListResourceCounts retrieves all resource counts. +// It supports pagination using lastID and limit parameters. +// +// TODO: Currently only a proof of concept, filters may be implemented later if required. +func (q *Queries) ListResourceCounts(ctx context.Context, lastID, limit int) (result []ResourceCount, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + for rows.Next() { + var count ResourceCount + err := rows.Scan( + &count.ID, + &count.InstanceID, + &count.TableName, + &count.ParentType, + &count.ParentID, + &count.Resource, + &count.UpdatedAt, + &count.Amount) + if err != nil { + return zerrors.ThrowInternal(err, "QUERY-2f4g5", "Errors.Internal") + } + result = append(result, count) + } + return nil + }, resourceCountsListQuery, lastID, limit) + if err != nil { + return nil, zerrors.ThrowInternal(err, "QUERY-3f4g5", "Errors.Internal") + } + return result, nil +} diff --git a/internal/query/resource_counts_list.sql b/internal/query/resource_counts_list.sql new file mode 100644 index 0000000000..0d4abf87eb --- /dev/null +++ b/internal/query/resource_counts_list.sql @@ -0,0 +1,12 @@ +SELECT id, + instance_id, + table_name, + parent_type, + parent_id, + resource_name, + updated_at, + amount +FROM projections.resource_counts +WHERE id > $1 +ORDER BY id +LIMIT $2; diff --git a/internal/query/resource_counts_test.go b/internal/query/resource_counts_test.go new file mode 100644 index 0000000000..2829a660ef --- /dev/null +++ b/internal/query/resource_counts_test.go @@ -0,0 +1,109 @@ +package query + +import ( + "context" + _ "embed" + "regexp" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" +) + +func TestQueries_ListResourceCounts(t *testing.T) { + columns := []string{"id", "instance_id", "table_name", "parent_type", "parent_id", "resource_name", "updated_at", "amount"} + type args struct { + lastID int + limit int + } + tests := []struct { + name string + args args + expects func(sqlmock.Sqlmock) + wantResult []ResourceCount + wantErr bool + }{ + { + name: "query error", + args: args{ + lastID: 0, + limit: 10, + }, + expects: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(regexp.QuoteMeta(resourceCountsListQuery)). + WithArgs(0, 10). + WillReturnError(assert.AnError) + }, + wantErr: true, + }, + { + name: "success", + args: args{ + lastID: 0, + limit: 10, + }, + expects: func(mock sqlmock.Sqlmock) { + mock.ExpectQuery(regexp.QuoteMeta(resourceCountsListQuery)). + WithArgs(0, 10). + WillReturnRows( + sqlmock.NewRows(columns). + AddRow(1, "instance_1", "table", "instance", "parent_1", "resource_name", time.Unix(1, 2), 5). + AddRow(2, "instance_2", "table", "instance", "parent_2", "resource_name", time.Unix(1, 2), 6), + ) + }, + wantResult: []ResourceCount{ + { + ID: 1, + InstanceID: "instance_1", + TableName: "table", + ParentType: domain.CountParentTypeInstance, + ParentID: "parent_1", + Resource: "resource_name", + UpdatedAt: time.Unix(1, 2), + Amount: 5, + }, + { + ID: 2, + InstanceID: "instance_2", + TableName: "table", + ParentType: domain.CountParentTypeInstance, + ParentID: "parent_2", + Resource: "resource_name", + UpdatedAt: time.Unix(1, 2), + Amount: 6, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { + err := mock.ExpectationsWereMet() + require.NoError(t, err) + }() + defer db.Close() + tt.expects(mock) + mock.ExpectClose() + q := &Queries{ + client: &database.DB{ + DB: db, + }, + } + + gotResult, err := q.ListResourceCounts(context.Background(), tt.args.lastID, tt.args.limit) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantResult, gotResult, "ListResourceCounts() result mismatch") + }) + } +}