feat(projections): resource counters (#9979)

# Which Problems Are Solved

Add the ability to keep track of the current counts of projection
resources. We want to prevent calling `SELECT COUNT(*)` on tables, as
that forces a full scan and sudden spikes of DB resource uses.

# How the Problems Are Solved

- A resource_counts table is added
- Triggers that increment and decrement the counted values on inserts
and deletes
- Triggers that delete all counts of a table when the source table is
TRUNCATEd. This is not in the business logic, but prevents wrong counts
in case someone want to force a re-projection.
- Triggers that delete all counts if the parent resource is deleted
- Script to pre-populate the resource_counts table when a new source
table is added.

The triggers are reusable for any type of resource, in case we choose to
add more in the future.
Counts are aggregated by a given parent. Currently only `instance` and
`organization` are defined as possible parent. This can later be
extended to other types, such as `project`, should the need arise.

I deliberately chose to use `parent_id` to distinguish from the
de-factor `resource_owner` which is usually an organization ID. For
example:

- For users the parent is an organization and the `parent_id` matches
`resource_owner`.
- For organizations the parent is an instance, but the `resource_owner`
is the `org_id`. In this case the `parent_id` is the `instance_id`.
- Applications would have a similar problem, where the parent is a
project, but the `resource_owner` is the `org_id`


# Additional Context

Closes https://github.com/zitadel/zitadel/issues/9957
This commit is contained in:
Tim Möhlmann
2025-06-03 17:15:30 +03:00
committed by GitHub
parent b8ff83454e
commit b9c1cdf4ad
16 changed files with 1080 additions and 16 deletions

27
cmd/setup/57.go Normal file
View File

@@ -0,0 +1,27 @@
package setup
import (
"context"
_ "embed"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed 57.sql
createResourceCounts string
)
type CreateResourceCounts struct {
dbClient *database.DB
}
func (mig *CreateResourceCounts) Execute(ctx context.Context, _ eventstore.Event) error {
_, err := mig.dbClient.ExecContext(ctx, createResourceCounts)
return err
}
func (mig *CreateResourceCounts) String() string {
return "57_create_resource_counts"
}

106
cmd/setup/57.sql Normal file
View File

@@ -0,0 +1,106 @@
CREATE TABLE IF NOT EXISTS projections.resource_counts
(
id SERIAL PRIMARY KEY, -- allows for easy pagination
instance_id TEXT NOT NULL,
table_name TEXT NOT NULL, -- needed for trigger matching, not in reports
parent_type TEXT NOT NULL,
parent_id TEXT NOT NULL,
resource_name TEXT NOT NULL, -- friendly name for reporting
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
amount INTEGER NOT NULL DEFAULT 1 CHECK (amount >= 0),
UNIQUE (instance_id, parent_type, parent_id, table_name)
);
-- count_resource is a trigger function which increases or decreases the count of a resource.
-- When creating the trigger the following required arguments (TG_ARGV) can be passed:
-- 1. The type of the parent
-- 2. The column name of the instance id
-- 3. The column name of the owner id
-- 4. The name of the resource
CREATE OR REPLACE FUNCTION projections.count_resource()
RETURNS trigger
LANGUAGE 'plpgsql' VOLATILE
AS $$
DECLARE
-- trigger variables
tg_table_name TEXT := TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME;
tg_parent_type TEXT := TG_ARGV[0];
tg_instance_id_column TEXT := TG_ARGV[1];
tg_parent_id_column TEXT := TG_ARGV[2];
tg_resource_name TEXT := TG_ARGV[3];
tg_instance_id TEXT;
tg_parent_id TEXT;
select_ids TEXT := format('SELECT ($1).%I, ($1).%I', tg_instance_id_column, tg_parent_id_column);
BEGIN
IF (TG_OP = 'INSERT') THEN
EXECUTE select_ids INTO tg_instance_id, tg_parent_id USING NEW;
INSERT INTO projections.resource_counts(instance_id, table_name, parent_type, parent_id, resource_name)
VALUES (tg_instance_id, tg_table_name, tg_parent_type, tg_parent_id, tg_resource_name)
ON CONFLICT (instance_id, table_name, parent_type, parent_id) DO
UPDATE SET updated_at = now(), amount = projections.resource_counts.amount + 1;
RETURN NEW;
ELSEIF (TG_OP = 'DELETE') THEN
EXECUTE select_ids INTO tg_instance_id, tg_parent_id USING OLD;
UPDATE projections.resource_counts
SET updated_at = now(), amount = amount - 1
WHERE instance_id = tg_instance_id
AND table_name = tg_table_name
AND parent_type = tg_parent_type
AND parent_id = tg_parent_id
AND resource_name = tg_resource_name
AND amount > 0; -- prevent check failure on negative amount.
RETURN OLD;
END IF;
END
$$;
-- delete_table_counts removes all resource counts for a TRUNCATED table.
CREATE OR REPLACE FUNCTION projections.delete_table_counts()
RETURNS trigger
LANGUAGE 'plpgsql'
AS $$
DECLARE
-- trigger variables
tg_table_name TEXT := TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME;
BEGIN
DELETE FROM projections.resource_counts
WHERE table_name = tg_table_name;
END
$$;
-- delete_parent_counts removes all resource counts for a deleted parent.
-- 1. The type of the parent
-- 2. The column name of the instance id
-- 3. The column name of the owner id
CREATE OR REPLACE FUNCTION projections.delete_parent_counts()
RETURNS trigger
LANGUAGE 'plpgsql'
AS $$
DECLARE
-- trigger variables
tg_parent_type TEXT := TG_ARGV[0];
tg_instance_id_column TEXT := TG_ARGV[1];
tg_parent_id_column TEXT := TG_ARGV[2];
tg_instance_id TEXT;
tg_parent_id TEXT;
select_ids TEXT := format('SELECT ($1).%I, ($1).%I', tg_instance_id_column, tg_parent_id_column);
BEGIN
EXECUTE select_ids INTO tg_instance_id, tg_parent_id USING OLD;
DELETE FROM projections.resource_counts
WHERE instance_id = tg_instance_id
AND parent_type = tg_parent_type
AND parent_id = tg_parent_id;
RETURN OLD;
END
$$;

View File

@@ -153,6 +153,7 @@ type Steps struct {
s54InstancePositionIndex *InstancePositionIndex
s55ExecutionHandlerStart *ExecutionHandlerStart
s56IDPTemplate6SAMLFederatedLogout *IDPTemplate6SAMLFederatedLogout
s57CreateResourceCounts *CreateResourceCounts
}
func MustNewSteps(v *viper.Viper) *Steps {

View File

@@ -215,6 +215,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s54InstancePositionIndex = &InstancePositionIndex{dbClient: dbClient}
steps.s55ExecutionHandlerStart = &ExecutionHandlerStart{dbClient: dbClient}
steps.s56IDPTemplate6SAMLFederatedLogout = &IDPTemplate6SAMLFederatedLogout{dbClient: dbClient}
steps.s57CreateResourceCounts = &CreateResourceCounts{dbClient: dbClient}
err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil)
logging.OnError(err).Fatal("unable to start projections")
@@ -260,6 +261,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s54InstancePositionIndex,
steps.s55ExecutionHandlerStart,
steps.s56IDPTemplate6SAMLFederatedLogout,
steps.s57CreateResourceCounts,
} {
setupErr = executeMigration(ctx, eventstoreClient, step, "migration failed")
if setupErr != nil {
@@ -296,6 +298,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
client: dbClient,
},
}
repeatableSteps = append(repeatableSteps, triggerSteps(dbClient)...)
for _, repeatableStep := range repeatableSteps {
setupErr = executeMigration(ctx, eventstoreClient, repeatableStep, "unable to migrate repeatable step")

125
cmd/setup/trigger_steps.go Normal file
View File

@@ -0,0 +1,125 @@
package setup
import (
"fmt"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/migration"
"github.com/zitadel/zitadel/internal/query/projection"
)
// triggerSteps defines the repeatable migrations that set up triggers
// for counting resources in the database.
func triggerSteps(db *database.DB) []migration.RepeatableMigration {
return []migration.RepeatableMigration{
// Delete parent count triggers for instances and organizations
migration.DeleteParentCountsTrigger(db,
projection.InstanceProjectionTable,
domain.CountParentTypeInstance,
projection.InstanceColumnID,
projection.InstanceColumnID,
"instance",
),
migration.DeleteParentCountsTrigger(db,
projection.OrgProjectionTable,
domain.CountParentTypeOrganization,
projection.OrgColumnInstanceID,
projection.OrgColumnID,
"organization",
),
// Count triggers for all the resources
migration.CountTrigger(db,
projection.OrgProjectionTable,
domain.CountParentTypeInstance,
projection.OrgColumnInstanceID,
projection.OrgColumnInstanceID,
"organization",
),
migration.CountTrigger(db,
projection.ProjectProjectionTable,
domain.CountParentTypeOrganization,
projection.ProjectColumnInstanceID,
projection.ProjectColumnResourceOwner,
"project",
),
migration.CountTrigger(db,
projection.UserTable,
domain.CountParentTypeOrganization,
projection.UserInstanceIDCol,
projection.UserResourceOwnerCol,
"user",
),
migration.CountTrigger(db,
projection.InstanceMemberProjectionTable,
domain.CountParentTypeInstance,
projection.MemberInstanceID,
projection.MemberResourceOwner,
"iam_admin",
),
migration.CountTrigger(db,
projection.IDPTable,
domain.CountParentTypeInstance,
projection.IDPInstanceIDCol,
projection.IDPInstanceIDCol,
"identity_provider",
),
migration.CountTrigger(db,
projection.IDPTemplateLDAPTable,
domain.CountParentTypeInstance,
projection.LDAPInstanceIDCol,
projection.LDAPInstanceIDCol,
"identity_provider_ldap",
),
migration.CountTrigger(db,
projection.ActionTable,
domain.CountParentTypeInstance,
projection.ActionInstanceIDCol,
projection.ActionInstanceIDCol,
"action_v1",
),
migration.CountTrigger(db,
projection.ExecutionTable,
domain.CountParentTypeInstance,
projection.ExecutionInstanceIDCol,
projection.ExecutionInstanceIDCol,
"execution",
),
migration.CountTrigger(db,
fmt.Sprintf("%s_%s", projection.ExecutionTable, projection.ExecutionTargetSuffix),
domain.CountParentTypeInstance,
projection.ExecutionTargetInstanceIDCol,
projection.ExecutionTargetInstanceIDCol,
"execution_target",
),
migration.CountTrigger(db,
projection.LoginPolicyTable,
domain.CountParentTypeInstance,
projection.LoginPolicyInstanceIDCol,
projection.LoginPolicyInstanceIDCol,
"login_policy",
),
migration.CountTrigger(db,
projection.PasswordComplexityTable,
domain.CountParentTypeInstance,
projection.ComplexityPolicyInstanceIDCol,
projection.ComplexityPolicyInstanceIDCol,
"password_complexity_policy",
),
migration.CountTrigger(db,
projection.PasswordAgeTable,
domain.CountParentTypeInstance,
projection.AgePolicyInstanceIDCol,
projection.AgePolicyInstanceIDCol,
"password_expiry_policy",
),
migration.CountTrigger(db,
projection.LockoutPolicyTable,
domain.CountParentTypeInstance,
projection.LockoutPolicyInstanceIDCol,
projection.LockoutPolicyInstanceIDCol,
"lockout_policy",
),
}
}

View File

@@ -0,0 +1,9 @@
package domain
//go:generate enumer -type CountParentType -transform lower -trimprefix CountParentType -sql
type CountParentType int
const (
CountParentTypeInstance CountParentType = iota
CountParentTypeOrganization
)

View File

@@ -0,0 +1,109 @@
// Code generated by "enumer -type CountParentType -transform lower -trimprefix CountParentType -sql"; DO NOT EDIT.
package domain
import (
"database/sql/driver"
"fmt"
"strings"
)
const _CountParentTypeName = "instanceorganization"
var _CountParentTypeIndex = [...]uint8{0, 8, 20}
const _CountParentTypeLowerName = "instanceorganization"
func (i CountParentType) String() string {
if i < 0 || i >= CountParentType(len(_CountParentTypeIndex)-1) {
return fmt.Sprintf("CountParentType(%d)", i)
}
return _CountParentTypeName[_CountParentTypeIndex[i]:_CountParentTypeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _CountParentTypeNoOp() {
var x [1]struct{}
_ = x[CountParentTypeInstance-(0)]
_ = x[CountParentTypeOrganization-(1)]
}
var _CountParentTypeValues = []CountParentType{CountParentTypeInstance, CountParentTypeOrganization}
var _CountParentTypeNameToValueMap = map[string]CountParentType{
_CountParentTypeName[0:8]: CountParentTypeInstance,
_CountParentTypeLowerName[0:8]: CountParentTypeInstance,
_CountParentTypeName[8:20]: CountParentTypeOrganization,
_CountParentTypeLowerName[8:20]: CountParentTypeOrganization,
}
var _CountParentTypeNames = []string{
_CountParentTypeName[0:8],
_CountParentTypeName[8:20],
}
// CountParentTypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func CountParentTypeString(s string) (CountParentType, error) {
if val, ok := _CountParentTypeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _CountParentTypeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to CountParentType values", s)
}
// CountParentTypeValues returns all values of the enum
func CountParentTypeValues() []CountParentType {
return _CountParentTypeValues
}
// CountParentTypeStrings returns a slice of all String values of the enum
func CountParentTypeStrings() []string {
strs := make([]string, len(_CountParentTypeNames))
copy(strs, _CountParentTypeNames)
return strs
}
// IsACountParentType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i CountParentType) IsACountParentType() bool {
for _, v := range _CountParentTypeValues {
if i == v {
return true
}
}
return false
}
func (i CountParentType) Value() (driver.Value, error) {
return i.String(), nil
}
func (i *CountParentType) Scan(value interface{}) error {
if value == nil {
return nil
}
var str string
switch v := value.(type) {
case []byte:
str = string(v)
case string:
str = v
case fmt.Stringer:
str = v.String()
default:
return fmt.Errorf("invalid value of CountParentType: %[1]T(%[1]v)", value)
}
val, err := CountParentTypeString(str)
if err != nil {
return err
}
*i = val
return nil
}

View File

@@ -4,11 +4,14 @@ package domain
import (
"fmt"
"strings"
)
const _SecretGeneratorTypeName = "unspecifiedinit_codeverify_email_codeverify_phone_codeverify_domainpassword_reset_codepasswordless_init_codeapp_secretotpsmsotp_emailinvite_codesecret_generator_type_count"
const _SecretGeneratorTypeName = "unspecifiedinit_codeverify_email_codeverify_phone_codeverify_domainpassword_reset_codepasswordless_init_codeapp_secretotpsmsotp_emailinvite_codesigning_keysecret_generator_type_count"
var _SecretGeneratorTypeIndex = [...]uint8{0, 11, 20, 37, 54, 67, 86, 108, 118, 124, 133, 144, 171}
var _SecretGeneratorTypeIndex = [...]uint8{0, 11, 20, 37, 54, 67, 86, 108, 118, 124, 133, 144, 155, 182}
const _SecretGeneratorTypeLowerName = "unspecifiedinit_codeverify_email_codeverify_phone_codeverify_domainpassword_reset_codepasswordless_init_codeapp_secretotpsmsotp_emailinvite_codesigning_keysecret_generator_type_count"
func (i SecretGeneratorType) String() string {
if i < 0 || i >= SecretGeneratorType(len(_SecretGeneratorTypeIndex)-1) {
@@ -17,21 +20,70 @@ func (i SecretGeneratorType) String() string {
return _SecretGeneratorTypeName[_SecretGeneratorTypeIndex[i]:_SecretGeneratorTypeIndex[i+1]]
}
var _SecretGeneratorTypeValues = []SecretGeneratorType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _SecretGeneratorTypeNoOp() {
var x [1]struct{}
_ = x[SecretGeneratorTypeUnspecified-(0)]
_ = x[SecretGeneratorTypeInitCode-(1)]
_ = x[SecretGeneratorTypeVerifyEmailCode-(2)]
_ = x[SecretGeneratorTypeVerifyPhoneCode-(3)]
_ = x[SecretGeneratorTypeVerifyDomain-(4)]
_ = x[SecretGeneratorTypePasswordResetCode-(5)]
_ = x[SecretGeneratorTypePasswordlessInitCode-(6)]
_ = x[SecretGeneratorTypeAppSecret-(7)]
_ = x[SecretGeneratorTypeOTPSMS-(8)]
_ = x[SecretGeneratorTypeOTPEmail-(9)]
_ = x[SecretGeneratorTypeInviteCode-(10)]
_ = x[SecretGeneratorTypeSigningKey-(11)]
_ = x[secretGeneratorTypeCount-(12)]
}
var _SecretGeneratorTypeValues = []SecretGeneratorType{SecretGeneratorTypeUnspecified, SecretGeneratorTypeInitCode, SecretGeneratorTypeVerifyEmailCode, SecretGeneratorTypeVerifyPhoneCode, SecretGeneratorTypeVerifyDomain, SecretGeneratorTypePasswordResetCode, SecretGeneratorTypePasswordlessInitCode, SecretGeneratorTypeAppSecret, SecretGeneratorTypeOTPSMS, SecretGeneratorTypeOTPEmail, SecretGeneratorTypeInviteCode, SecretGeneratorTypeSigningKey, secretGeneratorTypeCount}
var _SecretGeneratorTypeNameToValueMap = map[string]SecretGeneratorType{
_SecretGeneratorTypeName[0:11]: 0,
_SecretGeneratorTypeName[11:20]: 1,
_SecretGeneratorTypeName[20:37]: 2,
_SecretGeneratorTypeName[37:54]: 3,
_SecretGeneratorTypeName[54:67]: 4,
_SecretGeneratorTypeName[67:86]: 5,
_SecretGeneratorTypeName[86:108]: 6,
_SecretGeneratorTypeName[108:118]: 7,
_SecretGeneratorTypeName[118:124]: 8,
_SecretGeneratorTypeName[124:133]: 9,
_SecretGeneratorTypeName[133:144]: 10,
_SecretGeneratorTypeName[144:171]: 11,
_SecretGeneratorTypeName[0:11]: SecretGeneratorTypeUnspecified,
_SecretGeneratorTypeLowerName[0:11]: SecretGeneratorTypeUnspecified,
_SecretGeneratorTypeName[11:20]: SecretGeneratorTypeInitCode,
_SecretGeneratorTypeLowerName[11:20]: SecretGeneratorTypeInitCode,
_SecretGeneratorTypeName[20:37]: SecretGeneratorTypeVerifyEmailCode,
_SecretGeneratorTypeLowerName[20:37]: SecretGeneratorTypeVerifyEmailCode,
_SecretGeneratorTypeName[37:54]: SecretGeneratorTypeVerifyPhoneCode,
_SecretGeneratorTypeLowerName[37:54]: SecretGeneratorTypeVerifyPhoneCode,
_SecretGeneratorTypeName[54:67]: SecretGeneratorTypeVerifyDomain,
_SecretGeneratorTypeLowerName[54:67]: SecretGeneratorTypeVerifyDomain,
_SecretGeneratorTypeName[67:86]: SecretGeneratorTypePasswordResetCode,
_SecretGeneratorTypeLowerName[67:86]: SecretGeneratorTypePasswordResetCode,
_SecretGeneratorTypeName[86:108]: SecretGeneratorTypePasswordlessInitCode,
_SecretGeneratorTypeLowerName[86:108]: SecretGeneratorTypePasswordlessInitCode,
_SecretGeneratorTypeName[108:118]: SecretGeneratorTypeAppSecret,
_SecretGeneratorTypeLowerName[108:118]: SecretGeneratorTypeAppSecret,
_SecretGeneratorTypeName[118:124]: SecretGeneratorTypeOTPSMS,
_SecretGeneratorTypeLowerName[118:124]: SecretGeneratorTypeOTPSMS,
_SecretGeneratorTypeName[124:133]: SecretGeneratorTypeOTPEmail,
_SecretGeneratorTypeLowerName[124:133]: SecretGeneratorTypeOTPEmail,
_SecretGeneratorTypeName[133:144]: SecretGeneratorTypeInviteCode,
_SecretGeneratorTypeLowerName[133:144]: SecretGeneratorTypeInviteCode,
_SecretGeneratorTypeName[144:155]: SecretGeneratorTypeSigningKey,
_SecretGeneratorTypeLowerName[144:155]: SecretGeneratorTypeSigningKey,
_SecretGeneratorTypeName[155:182]: secretGeneratorTypeCount,
_SecretGeneratorTypeLowerName[155:182]: secretGeneratorTypeCount,
}
var _SecretGeneratorTypeNames = []string{
_SecretGeneratorTypeName[0:11],
_SecretGeneratorTypeName[11:20],
_SecretGeneratorTypeName[20:37],
_SecretGeneratorTypeName[37:54],
_SecretGeneratorTypeName[54:67],
_SecretGeneratorTypeName[67:86],
_SecretGeneratorTypeName[86:108],
_SecretGeneratorTypeName[108:118],
_SecretGeneratorTypeName[118:124],
_SecretGeneratorTypeName[124:133],
_SecretGeneratorTypeName[133:144],
_SecretGeneratorTypeName[144:155],
_SecretGeneratorTypeName[155:182],
}
// SecretGeneratorTypeString retrieves an enum value from the enum constants string name.
@@ -40,6 +92,10 @@ func SecretGeneratorTypeString(s string) (SecretGeneratorType, error) {
if val, ok := _SecretGeneratorTypeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _SecretGeneratorTypeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to SecretGeneratorType values", s)
}
@@ -48,6 +104,13 @@ func SecretGeneratorTypeValues() []SecretGeneratorType {
return _SecretGeneratorTypeValues
}
// SecretGeneratorTypeStrings returns a slice of all String values of the enum
func SecretGeneratorTypeStrings() []string {
strs := make([]string, len(_SecretGeneratorTypeNames))
copy(strs, _SecretGeneratorTypeNames)
return strs
}
// IsASecretGeneratorType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i SecretGeneratorType) IsASecretGeneratorType() bool {
for _, v := range _SecretGeneratorTypeValues {

View File

@@ -0,0 +1,43 @@
{{ define "count_trigger" -}}
CREATE OR REPLACE TRIGGER count_{{ .Resource }}
AFTER INSERT OR DELETE
ON {{ .Table }}
FOR EACH ROW
EXECUTE FUNCTION projections.count_resource(
'{{ .ParentType }}',
'{{ .InstanceIDColumn }}',
'{{ .ParentIDColumn }}',
'{{ .Resource }}'
);
CREATE OR REPLACE TRIGGER truncate_{{ .Resource }}_counts
AFTER TRUNCATE
ON {{ .Table }}
FOR EACH STATEMENT
EXECUTE FUNCTION projections.delete_table_counts();
-- Prevent inserts and deletes while we populate the counts.
LOCK TABLE {{ .Table }} IN SHARE MODE;
-- Populate the resource counts for the existing data in the table.
INSERT INTO projections.resource_counts(
instance_id,
table_name,
parent_type,
parent_id,
resource_name,
amount
)
SELECT
{{ .InstanceIDColumn }},
'{{ .Table }}',
'{{ .ParentType }}',
{{ .ParentIDColumn }},
'{{ .Resource }}',
COUNT(*) AS amount
FROM {{ .Table }}
GROUP BY ({{ .InstanceIDColumn }}, {{ .ParentIDColumn }})
ON CONFLICT (instance_id, table_name, parent_type, parent_id) DO
UPDATE SET updated_at = now(), amount = EXCLUDED.amount;
{{- end -}}

View File

@@ -0,0 +1,13 @@
{{ define "delete_parent_counts_trigger" -}}
CREATE OR REPLACE TRIGGER delete_parent_counts_trigger
AFTER DELETE
ON {{ .Table }}
FOR EACH ROW
EXECUTE FUNCTION projections.delete_parent_counts(
'{{ .ParentType }}',
'{{ .InstanceIDColumn }}',
'{{ .ParentIDColumn }}'
);
{{- end -}}

View File

@@ -36,7 +36,10 @@ type errCheckerMigration interface {
type RepeatableMigration interface {
Migration
Check(lastRun map[string]interface{}) bool
// Check if the migration should be executed again.
// True will repeat the migration, false will not.
Check(lastRun map[string]any) bool
}
func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration) (err error) {

View File

@@ -0,0 +1,127 @@
package migration
import (
"context"
"embed"
"fmt"
"strings"
"text/template"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
countTriggerTmpl = "count_trigger"
deleteParentCountsTmpl = "delete_parent_counts_trigger"
)
var (
//go:embed *.sql
templateFS embed.FS
templates = template.Must(template.ParseFS(templateFS, "*.sql"))
)
// CountTrigger registers the existing projections.count_trigger function.
// The trigger than takes care of keeping count of existing
// rows in the source table.
// It also pre-populates the projections.resource_counts table with
// the counts for the given table.
//
// During the population of the resource_counts table,
// the source table is share-locked to prevent concurrent modifications.
// Projection handlers will be halted until the lock is released.
// SELECT statements are not blocked by the lock.
//
// This migration repeats when any of the arguments are changed,
// such as renaming of a projection table.
func CountTrigger(
db *database.DB,
table string,
parentType domain.CountParentType,
instanceIDColumn string,
parentIDColumn string,
resource string,
) RepeatableMigration {
return &triggerMigration{
triggerConfig: triggerConfig{
Table: table,
ParentType: parentType.String(),
InstanceIDColumn: instanceIDColumn,
ParentIDColumn: parentIDColumn,
Resource: resource,
},
db: db,
templateName: countTriggerTmpl,
}
}
// DeleteParentCountsTrigger
//
// This migration repeats when any of the arguments are changed,
// such as renaming of a projection table.
func DeleteParentCountsTrigger(
db *database.DB,
table string,
parentType domain.CountParentType,
instanceIDColumn string,
parentIDColumn string,
resource string,
) RepeatableMigration {
return &triggerMigration{
triggerConfig: triggerConfig{
Table: table,
ParentType: parentType.String(),
InstanceIDColumn: instanceIDColumn,
ParentIDColumn: parentIDColumn,
Resource: resource,
},
db: db,
templateName: deleteParentCountsTmpl,
}
}
type triggerMigration struct {
triggerConfig
db *database.DB
templateName string
}
// String implements [Migration] and [fmt.Stringer].
func (m *triggerMigration) String() string {
return fmt.Sprintf("repeatable_%s_%s", m.Resource, m.templateName)
}
// Execute implements [Migration]
func (m *triggerMigration) Execute(ctx context.Context, _ eventstore.Event) error {
var query strings.Builder
err := templates.ExecuteTemplate(&query, m.templateName, m.triggerConfig)
if err != nil {
return fmt.Errorf("%s: execute trigger template: %w", m, err)
}
_, err = m.db.ExecContext(ctx, query.String())
if err != nil {
return fmt.Errorf("%s: exec trigger query: %w", m, err)
}
return nil
}
type triggerConfig struct {
Table string `json:"table,omitempty" mapstructure:"table"`
ParentType string `json:"parent_type,omitempty" mapstructure:"parent_type"`
InstanceIDColumn string `json:"instance_id_column,omitempty" mapstructure:"instance_id_column"`
ParentIDColumn string `json:"parent_id_column,omitempty" mapstructure:"parent_id_column"`
Resource string `json:"resource,omitempty" mapstructure:"resource"`
}
// Check implements [RepeatableMigration].
func (c *triggerConfig) Check(lastRun map[string]any) bool {
var dst triggerConfig
if err := mapstructure.Decode(lastRun, &dst); err != nil {
panic(err)
}
return dst != *c
}

View File

@@ -0,0 +1,253 @@
package migration
import (
"context"
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
)
const (
expCountTriggerQuery = `CREATE OR REPLACE TRIGGER count_resource
AFTER INSERT OR DELETE
ON table
FOR EACH ROW
EXECUTE FUNCTION projections.count_resource(
'instance',
'instance_id',
'parent_id',
'resource'
);
CREATE OR REPLACE TRIGGER truncate_resource_counts
AFTER TRUNCATE
ON table
FOR EACH STATEMENT
EXECUTE FUNCTION projections.delete_table_counts();
-- Prevent inserts and deletes while we populate the counts.
LOCK TABLE table IN SHARE MODE;
-- Populate the resource counts for the existing data in the table.
INSERT INTO projections.resource_counts(
instance_id,
table_name,
parent_type,
parent_id,
resource_name,
amount
)
SELECT
instance_id,
'table',
'instance',
parent_id,
'resource',
COUNT(*) AS amount
FROM table
GROUP BY (instance_id, parent_id)
ON CONFLICT (instance_id, table_name, parent_type, parent_id) DO
UPDATE SET updated_at = now(), amount = EXCLUDED.amount;`
expDeleteParentCountsQuery = `CREATE OR REPLACE TRIGGER delete_parent_counts_trigger
AFTER DELETE
ON table
FOR EACH ROW
EXECUTE FUNCTION projections.delete_parent_counts(
'instance',
'instance_id',
'parent_id'
);`
)
func Test_triggerMigration_Execute(t *testing.T) {
type fields struct {
triggerConfig triggerConfig
templateName string
}
tests := []struct {
name string
fields fields
expects func(sqlmock.Sqlmock)
wantErr bool
}{
{
name: "template error",
fields: fields{
triggerConfig: triggerConfig{
Table: "table",
ParentType: "instance",
InstanceIDColumn: "instance_id",
ParentIDColumn: "parent_id",
Resource: "resource",
},
templateName: "foo",
},
expects: func(_ sqlmock.Sqlmock) {},
wantErr: true,
},
{
name: "db error",
fields: fields{
triggerConfig: triggerConfig{
Table: "table",
ParentType: "instance",
InstanceIDColumn: "instance_id",
ParentIDColumn: "parent_id",
Resource: "resource",
},
templateName: countTriggerTmpl,
},
expects: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(regexp.QuoteMeta(expCountTriggerQuery)).
WillReturnError(assert.AnError)
},
wantErr: true,
},
{
name: "count trigger",
fields: fields{
triggerConfig: triggerConfig{
Table: "table",
ParentType: "instance",
InstanceIDColumn: "instance_id",
ParentIDColumn: "parent_id",
Resource: "resource",
},
templateName: countTriggerTmpl,
},
expects: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(regexp.QuoteMeta(expCountTriggerQuery)).
WithoutArgs().
WillReturnResult(
sqlmock.NewResult(1, 1),
)
},
},
{
name: "count trigger",
fields: fields{
triggerConfig: triggerConfig{
Table: "table",
ParentType: "instance",
InstanceIDColumn: "instance_id",
ParentIDColumn: "parent_id",
Resource: "resource",
},
templateName: deleteParentCountsTmpl,
},
expects: func(mock sqlmock.Sqlmock) {
mock.ExpectExec(regexp.QuoteMeta(expDeleteParentCountsQuery)).
WithoutArgs().
WillReturnResult(
sqlmock.NewResult(1, 1),
)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() {
err := mock.ExpectationsWereMet()
require.NoError(t, err)
}()
defer db.Close()
tt.expects(mock)
mock.ExpectClose()
m := &triggerMigration{
db: &database.DB{
DB: db,
},
triggerConfig: tt.fields.triggerConfig,
templateName: tt.fields.templateName,
}
err = m.Execute(context.Background(), nil)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func Test_triggerConfig_Check(t *testing.T) {
type fields struct {
Table string
ParentType string
InstanceIDColumn string
ParentIDColumn string
Resource string
}
type args struct {
lastRun map[string]any
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{
name: "should",
fields: fields{
Table: "users2",
ParentType: "instance",
InstanceIDColumn: "instance_id",
ParentIDColumn: "parent_id",
Resource: "user",
},
args: args{
lastRun: map[string]any{
"table": "users1",
"parent_type": "instance",
"instance_id_column": "instance_id",
"parent_id_column": "parent_id",
"resource": "user",
},
},
want: true,
},
{
name: "should not",
fields: fields{
Table: "users1",
ParentType: "instance",
InstanceIDColumn: "instance_id",
ParentIDColumn: "parent_id",
Resource: "user",
},
args: args{
lastRun: map[string]any{
"table": "users1",
"parent_type": "instance",
"instance_id_column": "instance_id",
"parent_id_column": "parent_id",
"resource": "user",
},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &triggerConfig{
Table: tt.fields.Table,
ParentType: tt.fields.ParentType,
InstanceIDColumn: tt.fields.InstanceIDColumn,
ParentIDColumn: tt.fields.ParentIDColumn,
Resource: tt.fields.Resource,
}
got := c.Check(tt.args.lastRun)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,61 @@
package query
import (
"context"
"database/sql"
_ "embed"
"time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
var (
//go:embed resource_counts_list.sql
resourceCountsListQuery string
)
type ResourceCount struct {
ID int // Primary key, used for pagination
InstanceID string
TableName string
ParentType domain.CountParentType
ParentID string
Resource string
UpdatedAt time.Time
Amount int
}
// ListResourceCounts retrieves all resource counts.
// It supports pagination using lastID and limit parameters.
//
// TODO: Currently only a proof of concept, filters may be implemented later if required.
func (q *Queries) ListResourceCounts(ctx context.Context, lastID, limit int) (result []ResourceCount, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
for rows.Next() {
var count ResourceCount
err := rows.Scan(
&count.ID,
&count.InstanceID,
&count.TableName,
&count.ParentType,
&count.ParentID,
&count.Resource,
&count.UpdatedAt,
&count.Amount)
if err != nil {
return zerrors.ThrowInternal(err, "QUERY-2f4g5", "Errors.Internal")
}
result = append(result, count)
}
return nil
}, resourceCountsListQuery, lastID, limit)
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-3f4g5", "Errors.Internal")
}
return result, nil
}

View File

@@ -0,0 +1,12 @@
SELECT id,
instance_id,
table_name,
parent_type,
parent_id,
resource_name,
updated_at,
amount
FROM projections.resource_counts
WHERE id > $1
ORDER BY id
LIMIT $2;

View File

@@ -0,0 +1,109 @@
package query
import (
"context"
_ "embed"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
)
func TestQueries_ListResourceCounts(t *testing.T) {
columns := []string{"id", "instance_id", "table_name", "parent_type", "parent_id", "resource_name", "updated_at", "amount"}
type args struct {
lastID int
limit int
}
tests := []struct {
name string
args args
expects func(sqlmock.Sqlmock)
wantResult []ResourceCount
wantErr bool
}{
{
name: "query error",
args: args{
lastID: 0,
limit: 10,
},
expects: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(regexp.QuoteMeta(resourceCountsListQuery)).
WithArgs(0, 10).
WillReturnError(assert.AnError)
},
wantErr: true,
},
{
name: "success",
args: args{
lastID: 0,
limit: 10,
},
expects: func(mock sqlmock.Sqlmock) {
mock.ExpectQuery(regexp.QuoteMeta(resourceCountsListQuery)).
WithArgs(0, 10).
WillReturnRows(
sqlmock.NewRows(columns).
AddRow(1, "instance_1", "table", "instance", "parent_1", "resource_name", time.Unix(1, 2), 5).
AddRow(2, "instance_2", "table", "instance", "parent_2", "resource_name", time.Unix(1, 2), 6),
)
},
wantResult: []ResourceCount{
{
ID: 1,
InstanceID: "instance_1",
TableName: "table",
ParentType: domain.CountParentTypeInstance,
ParentID: "parent_1",
Resource: "resource_name",
UpdatedAt: time.Unix(1, 2),
Amount: 5,
},
{
ID: 2,
InstanceID: "instance_2",
TableName: "table",
ParentType: domain.CountParentTypeInstance,
ParentID: "parent_2",
Resource: "resource_name",
UpdatedAt: time.Unix(1, 2),
Amount: 6,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() {
err := mock.ExpectationsWereMet()
require.NoError(t, err)
}()
defer db.Close()
tt.expects(mock)
mock.ExpectClose()
q := &Queries{
client: &database.DB{
DB: db,
},
}
gotResult, err := q.ListResourceCounts(context.Background(), tt.args.lastID, tt.args.limit)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantResult, gotResult, "ListResourceCounts() result mismatch")
})
}
}