diff --git a/backend/v3/domain/domain.go b/backend/v3/domain/domain.go index e6ac73dac1..fbe4f04192 100644 --- a/backend/v3/domain/domain.go +++ b/backend/v3/domain/domain.go @@ -2,28 +2,35 @@ package domain import "github.com/zitadel/zitadel/backend/v3/storage/database" -type DomainVerificationType string +type DomainValidationType string const ( - DomainVerificationTypeDNS DomainVerificationType = "dns" - DomainVerificationTypeHTTP DomainVerificationType = "http" + DomainValidationTypeDNS DomainValidationType = "dns" + DomainValidationTypeHTTP DomainValidationType = "http" ) type domainColumns interface { // InstanceIDColumn returns the column for the instance id field. - InstanceIDColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + InstanceIDColumn(qualified bool) database.Column // DomainColumn returns the column for the domain field. - DomainColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + DomainColumn(qualified bool) database.Column // IsVerifiedColumn returns the column for the is verified field. - IsVerifiedColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + IsVerifiedColumn(qualified bool) database.Column // IsPrimaryColumn returns the column for the is primary field. - IsPrimaryColumn() database.Column - // VerificationTypeColumn returns the column for the verification type field. - VerificationTypeColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + IsPrimaryColumn(qualified bool) database.Column + // ValidationTypeColumn returns the column for the verification type field. + // `qualified` indicates if the column should be qualified with the table name. + ValidationTypeColumn(qualified bool) database.Column // CreatedAtColumn returns the column for the created at field. - CreatedAtColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + CreatedAtColumn(qualified bool) database.Column // UpdatedAtColumn returns the column for the updated at field. - UpdatedAtColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + UpdatedAtColumn(qualified bool) database.Column } type domainConditions interface { @@ -51,9 +58,9 @@ type domainChanges interface { // - The domain is already primary. // - No domain matches the condition. SetPrimary() database.Change - // SetVerificationType sets the verification type column. + // SetValidationType sets the verification type column. // If the domain is already verified, this is a no-op. - SetVerificationType(verificationType DomainVerificationType) database.Change + SetValidationType(verificationType DomainValidationType) database.Change } // import ( diff --git a/backend/v3/domain/instance.go b/backend/v3/domain/instance.go index 53fced3745..cc08791d38 100644 --- a/backend/v3/domain/instance.go +++ b/backend/v3/domain/instance.go @@ -18,6 +18,8 @@ type Instance struct { DefaultLanguage string `json:"defaultLanguage,omitempty" db:"default_language"` CreatedAt time.Time `json:"createdAt" db:"created_at"` UpdatedAt time.Time `json:"updatedAt" db:"updated_at"` + + Domains []*InstanceDomain `json:"domains,omitempty" db:"-"` } type instanceCacheIndex uint8 @@ -40,23 +42,32 @@ var _ cache.Entry[instanceCacheIndex, string] = (*Instance)(nil) // instanceColumns define all the columns of the instance table. type instanceColumns interface { // IDColumn returns the column for the id field. - IDColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + IDColumn(qualified bool) database.Column // NameColumn returns the column for the name field. - NameColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + NameColumn(qualified bool) database.Column // DefaultOrgIDColumn returns the column for the default org id field - DefaultOrgIDColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + DefaultOrgIDColumn(qualified bool) database.Column // IAMProjectIDColumn returns the column for the default IAM org id field - IAMProjectIDColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + IAMProjectIDColumn(qualified bool) database.Column // ConsoleClientIDColumn returns the column for the default IAM org id field - ConsoleClientIDColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + ConsoleClientIDColumn(qualified bool) database.Column // ConsoleAppIDColumn returns the column for the console client id field - ConsoleAppIDColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + ConsoleAppIDColumn(qualified bool) database.Column // DefaultLanguageColumn returns the column for the default language field - DefaultLanguageColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + DefaultLanguageColumn(qualified bool) database.Column // CreatedAtColumn returns the column for the created at field. - CreatedAtColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + CreatedAtColumn(qualified bool) database.Column // UpdatedAtColumn returns the column for the updated at field. - UpdatedAtColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + UpdatedAtColumn(qualified bool) database.Column } // instanceConditions define all the conditions for the instance table. @@ -83,16 +94,27 @@ type InstanceRepository interface { // Member returns the member repository which is a sub repository of the instance repository. // Member() MemberRepository - Get(ctx context.Context, id string) (*Instance, error) - List(ctx context.Context, opts ...database.Condition) ([]*Instance, error) + Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error) + List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error) Create(ctx context.Context, instance *Instance) error Update(ctx context.Context, id string, changes ...database.Change) (int64, error) Delete(ctx context.Context, id string) (int64, error) - Domains() InstanceDomainRepository + // Domains returns the domain sub repository for the instance. + // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field. + // If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future. + Domains(shouldLoad bool) InstanceDomainRepository } type CreateInstance struct { Name string `json:"name"` } + +type InstanceQueryOption func(*InstanceQueryOpts) + +type InstanceQueryOpts struct { + database.QueryOpts + + JoinDomains bool +} \ No newline at end of file diff --git a/backend/v3/domain/instance_domain.go b/backend/v3/domain/instance_domain.go index db648b6740..eadbd09406 100644 --- a/backend/v3/domain/instance_domain.go +++ b/backend/v3/domain/instance_domain.go @@ -2,17 +2,23 @@ package domain import ( "context" + "encoding/json" "time" "github.com/zitadel/zitadel/backend/v3/storage/database" ) +type InstanceDomains struct { + domains []*InstanceDomain + Raw json.RawMessage +} + type InstanceDomain struct { InstanceID string `json:"instanceId,omitempty" db:"instance_id"` Domain string `json:"domain,omitempty" db:"domain"` IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` - VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"` + ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"` CreatedAt string `json:"createdAt,omitempty" db:"created_at"` UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"` @@ -23,7 +29,7 @@ type AddInstanceDomain struct { Domain string `json:"domain,omitempty" db:"domain"` IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` - VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"` + VerificationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"` // CreatedAt is the time when the domain was added. // It is set by the repository and should not be set by the caller. @@ -36,7 +42,8 @@ type AddInstanceDomain struct { type instanceDomainColumns interface { domainColumns // IsGeneratedColumn returns the column for the is generated field. - IsGeneratedColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + IsGeneratedColumn(qualified bool) database.Column } type instanceDomainConditions interface { diff --git a/backend/v3/domain/organization.go b/backend/v3/domain/organization.go index df497d2c43..255c564a75 100644 --- a/backend/v3/domain/organization.go +++ b/backend/v3/domain/organization.go @@ -19,8 +19,10 @@ type Organization struct { Name string `json:"name,omitempty" db:"name"` InstanceID string `json:"instanceId,omitempty" db:"instance_id"` State OrgState `json:"state,omitempty" db:"state"` - CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"` - UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"` + CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"` + + Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` } // OrgIdentifierCondition is used to help specify a single Organization, @@ -33,17 +35,23 @@ type OrgIdentifierCondition interface { // organizationColumns define all the columns of the instance table. type organizationColumns interface { // IDColumn returns the column for the id field. - IDColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + IDColumn(qualified bool) database.Column // NameColumn returns the column for the name field. - NameColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + NameColumn(qualified bool) database.Column // InstanceIDColumn returns the column for the default org id field - InstanceIDColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + InstanceIDColumn(qualified bool) database.Column // StateColumn returns the column for the name field. - StateColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + StateColumn(qualified bool) database.Column // CreatedAtColumn returns the column for the created at field. - CreatedAtColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + CreatedAtColumn(qualified bool) database.Column // UpdatedAtColumn returns the column for the updated at field. - UpdatedAtColumn() database.Column + // `qualified` indicates if the column should be qualified with the table name. + UpdatedAtColumn(qualified bool) database.Column } // organizationConditions define all the conditions for the instance table. @@ -72,15 +80,17 @@ type OrganizationRepository interface { organizationConditions organizationChanges - Get(ctx context.Context, id OrgIdentifierCondition, instance_id string, opts ...database.Condition) (*Organization, error) - List(ctx context.Context, conditions ...database.Condition) ([]*Organization, error) + Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error) + List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error) Create(ctx context.Context, instance *Organization) error Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error) Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error) // Domains returns the domain sub repository for the organization. - Domains() OrganizationDomainRepository + // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field. + // If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future. + Domains(shouldLoad bool) OrganizationDomainRepository } type CreateOrganization struct { diff --git a/backend/v3/domain/organization_domain.go b/backend/v3/domain/organization_domain.go index 7e244dd401..9f8aac9ccc 100644 --- a/backend/v3/domain/organization_domain.go +++ b/backend/v3/domain/organization_domain.go @@ -13,7 +13,7 @@ type OrganizationDomain struct { Domain string `json:"domain,omitempty" db:"domain"` IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` - VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"` + ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"` CreatedAt string `json:"createdAt,omitempty" db:"created_at"` UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"` @@ -25,7 +25,7 @@ type AddOrganizationDomain struct { Domain string `json:"domain,omitempty" db:"domain"` IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` - VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"` + ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"` // CreatedAt is the time when the domain was added. // It is set by the repository and should not be set by the caller. @@ -38,7 +38,7 @@ type AddOrganizationDomain struct { type organizationDomainColumns interface { domainColumns // OrgIDColumn returns the column for the org id field. - OrgIDColumn() database.Column + OrgIDColumn(qualified bool) database.Column } type organizationDomainConditions interface { diff --git a/backend/v3/storage/database/condition.go b/backend/v3/storage/database/condition.go index 55f1e862e6..276b60514a 100644 --- a/backend/v3/storage/database/condition.go +++ b/backend/v3/storage/database/condition.go @@ -113,6 +113,15 @@ func NewBooleanCondition[V Boolean](col Column, value V) Condition { }) } +// NewColumnCondition creates a condition that compares two columns on equality. +func NewColumnCondition(col1, col2 Column) Condition { + return valueCondition(func(builder *StatementBuilder) { + col1.Write(builder) + builder.WriteString(" = ") + col2.Write(builder) + }) +} + // Write implements [Condition]. func (c valueCondition) Write(builder *StatementBuilder) { c(builder) diff --git a/backend/v3/storage/database/errors.go b/backend/v3/storage/database/errors.go index 1856960bb1..4504efb1ae 100644 --- a/backend/v3/storage/database/errors.go +++ b/backend/v3/storage/database/errors.go @@ -1,9 +1,12 @@ package database import ( + "errors" "fmt" ) +var NoChangesError = errors.New("Update must contain a change") + // NoRowFoundError is returned when QueryRow does not find any row. // It wraps the dialect specific original error to provide more context. type NoRowFoundError struct { diff --git a/backend/v3/storage/database/query.go b/backend/v3/storage/database/query.go index 9ae8b10804..640a03d7e2 100644 --- a/backend/v3/storage/database/query.go +++ b/backend/v3/storage/database/query.go @@ -30,6 +30,36 @@ func WithOffset(offset uint32) QueryOption { } } +// WithGroupBy sets the columns to group the results by. +func WithGroupBy(groupBy ...Column) QueryOption { + return func(opts *QueryOpts) { + opts.GroupBy = groupBy + } +} + +// WithLeftJoin adds a LEFT JOIN to the query. +func WithLeftJoin(table string, columns Condition) QueryOption { + return func(opts *QueryOpts) { + opts.Joins = append(opts.Joins, join{ + table: table, + typ: JoinTypeLeft, + columns: columns, + }) + } +} + +type joinType string + +const ( + JoinTypeLeft joinType = "LEFT" +) + +type join struct { + table string + typ joinType + columns Condition +} + // QueryOpts holds the options for a query. // It is used to build the SQL SELECT statement. type QueryOpts struct { @@ -45,10 +75,19 @@ type QueryOpts struct { // Offset is the number of results to skip before returning the results. // It is used to build the OFFSET clause of the SQL statement. Offset uint32 + // GroupBy is the columns to group the results by. + // It is used to build the GROUP BY clause of the SQL statement. + GroupBy Columns + // Joins is a list of joins to be applied to the query. + // It is used to build the JOIN clauses of the SQL statement. + Joins []join } + func (opts *QueryOpts) Write(builder *StatementBuilder) { + opts.WriteLeftJoins(builder) opts.WriteCondition(builder) + opts.WriteGroupBy(builder) opts.WriteOrderBy(builder) opts.WriteLimit(builder) opts.WriteOffset(builder) @@ -85,3 +124,25 @@ func (opts *QueryOpts) WriteOffset(builder *StatementBuilder) { builder.WriteString(" OFFSET ") builder.WriteArg(opts.Offset) } + +func (opts *QueryOpts) WriteGroupBy(builder *StatementBuilder) { + if len(opts.GroupBy) == 0 { + return + } + builder.WriteString(" GROUP BY ") + opts.GroupBy.Write(builder) +} + +func (opts *QueryOpts) WriteLeftJoins(builder *StatementBuilder) { + if len(opts.Joins) == 0 { + return + } + for _, join := range opts.Joins { + builder.WriteString(" ") + builder.WriteString(string(join.typ)) + builder.WriteString(" JOIN ") + builder.WriteString(join.table) + builder.WriteString(" ON ") + join.columns.Write(builder) + } +} \ No newline at end of file diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go index 3fbc631a33..5476316fc3 100644 --- a/backend/v3/storage/database/repository/instance.go +++ b/backend/v3/storage/database/repository/instance.go @@ -2,7 +2,7 @@ package repository import ( "context" - "errors" + "encoding/json" "github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/storage/database" @@ -12,8 +12,8 @@ var _ domain.InstanceRepository = (*instance)(nil) type instance struct { repository - shouldJoinDomains bool - domainRepo domain.InstanceDomainRepository + shouldLoadDomains bool + domainRepo *instanceDomain } func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository { @@ -24,38 +24,71 @@ func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository } } + // ------------------------------------------------------------- // repository // ------------------------------------------------------------- -const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at` + - ` FROM zitadel.instances` +const ( + queryInstanceStmt = `SELECT instances.id, instances.name, instances.default_org_id, instances.iam_project_id, instances.console_client_id, instances.console_app_id, instances.default_language, instances.created_at, instances.updated_at` + + ` , CASE WHEN count(instance_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', instance_domains.domain, 'isVerified', instance_domains.is_verified, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'validationType', instance_domains.validation_type, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) ELSE NULL::JSONB END domains` + + ` FROM zitadel.instances` +) // Get implements [domain.InstanceRepository]. -func (i *instance) Get(ctx context.Context, id string) (*domain.Instance, error) { +func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) { + opts = append(opts, + i.joinDomains(), + database.WithGroupBy(i.IDColumn(true)), + ) + + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + var builder database.StatementBuilder - builder.WriteString(queryInstanceStmt) - - idCondition := i.IDCondition(id) - writeCondition(&builder, idCondition) + options.Write(&builder) return scanInstance(ctx, i.client, &builder) } // List implements [domain.InstanceRepository]. -func (i *instance) List(ctx context.Context, conditions ...database.Condition) ([]*domain.Instance, error) { - var builder database.StatementBuilder - - builder.WriteString(queryInstanceStmt) - - if conditions != nil { - writeCondition(&builder, database.And(conditions...)) +func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) { + opts = append(opts, + i.joinDomains(), + database.WithGroupBy(i.IDColumn(true)), + ) + + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) } + var builder database.StatementBuilder + builder.WriteString(queryInstanceStmt) + options.Write(&builder) + return scanInstances(ctx, i.client, &builder) } +func (i *instance) joinDomains() database.QueryOption { + columns := make([]database.Condition, 0, 2) + columns = append(columns, database.NewColumnCondition(i.IDColumn(true), i.Domains(false).InstanceIDColumn(true))) + + // If domains should not be joined, we make sure to return null for the domain columns + // the query optimizer of the dialect should optimize this away if no domains are requested + if !i.shouldLoadDomains { + columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn(true))) + } + + return database.WithLeftJoin( + "zitadel.instance_domains", + database.And(columns...), + ) +} + const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language)` + ` VALUES ($1, $2, $3, $4, $5, $6, $7)` + ` RETURNING created_at, updated_at` @@ -72,8 +105,8 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error // Update implements [domain.InstanceRepository]. func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) { - if changes == nil { - return 0, errors.New("Update must contain a change") + if len(changes) == 0 { + return 0, database.NoChangesError } var builder database.StatementBuilder @@ -107,7 +140,7 @@ func (i instance) Delete(ctx context.Context, id string) (int64, error) { // SetName implements [domain.instanceChanges]. func (i instance) SetName(name string) database.Change { - return database.NewChange(i.NameColumn(), name) + return database.NewChange(i.NameColumn(false), name) } // ------------------------------------------------------------- @@ -116,12 +149,12 @@ func (i instance) SetName(name string) database.Change { // IDCondition implements [domain.instanceConditions]. func (i instance) IDCondition(id string) database.Condition { - return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id) + return database.NewTextCondition(i.IDColumn(true), database.TextOperationEqual, id) } // NameCondition implements [domain.instanceConditions]. func (i instance) NameCondition(op database.TextOperation, name string) database.Condition { - return database.NewTextCondition(i.NameColumn(), op, name) + return database.NewTextCondition(i.NameColumn(true), op, name) } // ------------------------------------------------------------- @@ -129,74 +162,122 @@ func (i instance) NameCondition(op database.TextOperation, name string) database // ------------------------------------------------------------- // IDColumn implements [domain.instanceColumns]. -func (instance) IDColumn() database.Column { +func (instance) IDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.id") + } return database.NewColumn("id") } // NameColumn implements [domain.instanceColumns]. -func (instance) NameColumn() database.Column { +func (instance) NameColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.name") + } return database.NewColumn("name") } // CreatedAtColumn implements [domain.instanceColumns]. -func (instance) CreatedAtColumn() database.Column { +func (instance) CreatedAtColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.created_at") + } return database.NewColumn("created_at") } // DefaultOrgIdColumn implements [domain.instanceColumns]. -func (instance) DefaultOrgIDColumn() database.Column { +func (instance) DefaultOrgIDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.default_org_id") + } return database.NewColumn("default_org_id") } // IAMProjectIDColumn implements [domain.instanceColumns]. -func (instance) IAMProjectIDColumn() database.Column { +func (instance) IAMProjectIDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.iam_project_id") + } return database.NewColumn("iam_project_id") } // ConsoleClientIDColumn implements [domain.instanceColumns]. -func (instance) ConsoleClientIDColumn() database.Column { +func (instance) ConsoleClientIDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.console_client_id") + } return database.NewColumn("console_client_id") } // ConsoleAppIDColumn implements [domain.instanceColumns]. -func (instance) ConsoleAppIDColumn() database.Column { +func (instance) ConsoleAppIDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.console_app_id") + } return database.NewColumn("console_app_id") } // DefaultLanguageColumn implements [domain.instanceColumns]. -func (instance) DefaultLanguageColumn() database.Column { +func (instance) DefaultLanguageColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.default_language") + } return database.NewColumn("default_language") } // UpdatedAtColumn implements [domain.instanceColumns]. -func (instance) UpdatedAtColumn() database.Column { +func (instance) UpdatedAtColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instances.updated_at") + } return database.NewColumn("updated_at") } +type rawInstance struct { + *domain.Instance + RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"` +} + func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) { rows, err := querier.Query(ctx, builder.String(), builder.Args()...) if err != nil { return nil, err } - - instance := new(domain.Instance) - if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil { + + var instance rawInstance + if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil { return nil, err } - return instance, nil + if len(instance.RawDomains) > 0 { + if err := json.Unmarshal(instance.RawDomains, &instance.Domains); err != nil { + return nil, err + } + } + + return instance.Instance, nil } -func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) { +func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Instance, error) { rows, err := querier.Query(ctx, builder.String(), builder.Args()...) if err != nil { return nil, err } - if err := rows.(database.CollectableRows).Collect(&instances); err != nil { + var rawInstances []*rawInstance + if err := rows.(database.CollectableRows).Collect(&rawInstances); err != nil { return nil, err } + instances := make([]*domain.Instance, len(rawInstances)) + for i, instance := range rawInstances { + if len(instance.RawDomains) > 0 { + if err := json.Unmarshal(instance.RawDomains, &instance.Domains); err != nil { + return nil, err + } + } + instances[i] = instance.Instance + } return instances, nil } @@ -205,8 +286,10 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab // ------------------------------------------------------------- // Domains implements [domain.InstanceRepository]. -func (i *instance) Domains() domain.InstanceDomainRepository { - i.shouldJoinDomains = true +func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository { + if !i.shouldLoadDomains { + i.shouldLoadDomains = shouldLoad + } if i.domainRepo != nil { return i.domainRepo diff --git a/backend/v3/storage/database/repository/instance_domain.go b/backend/v3/storage/database/repository/instance_domain.go index ae3169e83e..db1ad509fa 100644 --- a/backend/v3/storage/database/repository/instance_domain.go +++ b/backend/v3/storage/database/repository/instance_domain.go @@ -18,8 +18,8 @@ type instanceDomain struct { // repository // ------------------------------------------------------------- -const queryInstanceDomainStmt = `SELECT instance_id, domain, is_verified, is_primary, verification_type, created_at, updated_at ` + - `FROM zitadel.instance_domains` +const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_domains.domain, instance_domains.is_verified, instance_domains.is_primary, instance_domains.validation_type, instance_domains.created_at, instance_domains.updated_at ` + + `FROM zitadel.instance_domains id` // Get implements [domain.InstanceDomainRepository]. // Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance. @@ -55,7 +55,7 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error { var builder database.StatementBuilder - builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_verified, is_primary, verification_type) ` + + builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_verified, is_primary, validation_type) ` + `VALUES ($1, $2, $3, $4, $5)` + ` RETURNING created_at, updated_at`) @@ -91,19 +91,19 @@ func (i *instanceDomain) Update(ctx context.Context, condition database.Conditio // changes // ------------------------------------------------------------- -// SetVerificationType implements [domain.InstanceDomainRepository]. -func (i instanceDomain) SetVerificationType(verificationType domain.DomainVerificationType) database.Change { - return database.NewChange(i.VerificationTypeColumn(), verificationType) +// SetValidationType implements [domain.InstanceDomainRepository]. +func (i instanceDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change { + return database.NewChange(i.ValidationTypeColumn(false), verificationType) } // SetPrimary implements [domain.InstanceDomainRepository]. func (i instanceDomain) SetPrimary() database.Change { - return database.NewChange(i.IsPrimaryColumn(), true) + return database.NewChange(i.IsPrimaryColumn(false), true) } // SetVerified implements [domain.InstanceDomainRepository]. func (i instanceDomain) SetVerified() database.Change { - return database.NewChange(i.IsVerifiedColumn(), true) + return database.NewChange(i.IsVerifiedColumn(false), true) } // ------------------------------------------------------------- @@ -112,22 +112,22 @@ func (i instanceDomain) SetVerified() database.Change { // DomainCondition implements [domain.InstanceDomainRepository]. func (i instanceDomain) DomainCondition(op database.TextOperation, domain string) database.Condition { - return database.NewTextCondition(i.DomainColumn(), op, domain) + return database.NewTextCondition(i.DomainColumn(true), op, domain) } // InstanceIDCondition implements [domain.InstanceDomainRepository]. func (i instanceDomain) InstanceIDCondition(instanceID string) database.Condition { - return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, instanceID) + return database.NewTextCondition(i.InstanceIDColumn(true), database.TextOperationEqual, instanceID) } // IsPrimaryCondition implements [domain.InstanceDomainRepository]. func (i instanceDomain) IsPrimaryCondition(isPrimary bool) database.Condition { - return database.NewBooleanCondition(i.IsPrimaryColumn(), isPrimary) + return database.NewBooleanCondition(i.IsPrimaryColumn(true), isPrimary) } // IsVerifiedCondition implements [domain.InstanceDomainRepository]. func (i instanceDomain) IsVerifiedCondition(isVerified bool) database.Condition { - return database.NewBooleanCondition(i.IsVerifiedColumn(), isVerified) + return database.NewBooleanCondition(i.IsVerifiedColumn(true), isVerified) } // ------------------------------------------------------------- @@ -136,43 +136,67 @@ func (i instanceDomain) IsVerifiedCondition(isVerified bool) database.Condition // CreatedAtColumn implements [domain.InstanceDomainRepository]. // Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance. -func (instanceDomain) CreatedAtColumn() database.Column { +func (instanceDomain) CreatedAtColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instance_domains.created_at") + } return database.NewColumn("created_at") } // DomainColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) DomainColumn() database.Column { +func (instanceDomain) DomainColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instance_domains.domain") + } return database.NewColumn("domain") } // InstanceIDColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) InstanceIDColumn() database.Column { +func (instanceDomain) InstanceIDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instance_domains.instance_id") + } return database.NewColumn("instance_id") } // IsPrimaryColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) IsPrimaryColumn() database.Column { +func (instanceDomain) IsPrimaryColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instance_domains.is_primary") + } return database.NewColumn("is_primary") } // IsVerifiedColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) IsVerifiedColumn() database.Column { +func (instanceDomain) IsVerifiedColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instance_domains.is_verified") + } return database.NewColumn("is_verified") } // UpdatedAtColumn implements [domain.InstanceDomainRepository]. // Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance. -func (instanceDomain) UpdatedAtColumn() database.Column { +func (instanceDomain) UpdatedAtColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instance_domains.updated_at") + } return database.NewColumn("updated_at") } -// VerificationTypeColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) VerificationTypeColumn() database.Column { - return database.NewColumn("verification_type") +// ValidationTypeColumn implements [domain.InstanceDomainRepository]. +func (instanceDomain) ValidationTypeColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instance_domains.validation_type") + } + return database.NewColumn("validation_type") } // IsGeneratedColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) IsGeneratedColumn() database.Column { +func (instanceDomain) IsGeneratedColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("instance_domains.is_generated") + } return database.NewColumn("is_generated") } diff --git a/backend/v3/storage/database/repository/instance_test.go b/backend/v3/storage/database/repository/instance_test.go index bdc914157b..e40f416099 100644 --- a/backend/v3/storage/database/repository/instance_test.go +++ b/backend/v3/storage/database/repository/instance_test.go @@ -171,7 +171,9 @@ func TestCreateInstance(t *testing.T) { // check instance values instance, err = instanceRepo.Get(ctx, - instance.ID, + database.WithCondition( + instanceRepo.IDCondition(instance.ID), + ), ) require.NoError(t, err) @@ -290,7 +292,9 @@ func TestUpdateInstance(t *testing.T) { // check instance values instance, err = instanceRepo.Get(ctx, - instance.ID, + database.WithCondition( + instanceRepo.IDCondition(instance.ID), + ), ) require.Equal(t, tt.getErr, err) @@ -356,7 +360,9 @@ func TestGetInstance(t *testing.T) { // check instance values returnedInstance, err := instanceRepo.Get(ctx, - instance.ID, + database.WithCondition( + instanceRepo.IDCondition(instance.ID), + ), ) if tt.err != nil { require.ErrorIs(t, err, tt.err) @@ -524,9 +530,15 @@ func TestListInstance(t *testing.T) { instanceRepo := repository.InstanceRepository(pool) + var condition database.Condition + if len(tt.conditionClauses) > 0 { + condition = database.And(tt.conditionClauses...) + } + // check instance values returnedInstances, err := instanceRepo.List(ctx, - tt.conditionClauses..., + database.WithCondition(condition), + database.WithOrderBy(instanceRepo.CreatedAtColumn(true)), ) require.NoError(t, err) if tt.noInstanceReturned { @@ -652,7 +664,9 @@ func TestDeleteInstance(t *testing.T) { // check instance was deleted instance, err := instanceRepo.Get(ctx, - tt.instanceID, + database.WithCondition( + instanceRepo.IDCondition(tt.instanceID), + ), ) require.ErrorIs(t, err, new(database.NoRowFoundError)) assert.Nil(t, instance) diff --git a/backend/v3/storage/database/repository/org.go b/backend/v3/storage/database/repository/org.go index 470bed8af9..cb80f240d7 100644 --- a/backend/v3/storage/database/repository/org.go +++ b/backend/v3/storage/database/repository/org.go @@ -2,7 +2,7 @@ package repository import ( "context" - "errors" + "encoding/json" "github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/storage/database" @@ -15,9 +15,9 @@ import ( var _ domain.OrganizationRepository = (*org)(nil) type org struct { - shouldJoinDomains bool repository - domainRepo domain.OrganizationDomainRepository + shouldLoadDomains bool + domainRepo domain.OrganizationDomainRepository } func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository { @@ -28,39 +28,68 @@ func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRe } } -const queryOrganizationStmt = `SELECT id, name, instance_id, state, created_at, updated_at` + +const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` + + ` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` + ` FROM zitadel.organizations` // Get implements [domain.OrganizationRepository]. -func (o *org) Get(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, conditions ...database.Condition) (*domain.Organization, error) { - builder := database.StatementBuilder{} +func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) { + opts = append(opts, + o.joinDomains(), + database.WithGroupBy(o.InstanceIDColumn(true), o.IDColumn(true)), + ) + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder builder.WriteString(queryOrganizationStmt) - - instanceIDCondition := o.InstanceIDCondition(instanceID) - - conditions = append(conditions, id, instanceIDCondition) - writeCondition(&builder, database.And(conditions...)) + options.Write(&builder) return scanOrganization(ctx, o.client, &builder) } // List implements [domain.OrganizationRepository]. -func (o *org) List(ctx context.Context, conditions ...database.Condition) ([]*domain.Organization, error) { - builder := database.StatementBuilder{} +func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) { + opts = append(opts, + o.joinDomains(), + database.WithGroupBy(o.InstanceIDColumn(true), o.IDColumn(true)), + ) - builder.WriteString(queryOrganizationStmt) - - if conditions != nil { - writeCondition(&builder, database.And(conditions...)) + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) } - - orderBy := database.OrderBy(o.CreatedAtColumn()) - orderBy.Write(&builder) + + var builder database.StatementBuilder + builder.WriteString(queryOrganizationStmt) + options.Write(&builder) return scanOrganizations(ctx, o.client, &builder) } +func (o *org) joinDomains() database.QueryOption { + columns := make([]database.Condition, 0, 3) + columns = append(columns, + database.NewColumnCondition(o.InstanceIDColumn(true), o.Domains(false).InstanceIDColumn(true)), + database.NewColumnCondition(o.IDColumn(true), o.Domains(false).OrgIDColumn(true)), + ) + + // If domains should not be joined, we make sure to return null for the domain columns + // the query optimizer of the dialect should optimize this away if no domains are requested + if !o.shouldLoadDomains { + columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn(true))) + } + + return database.WithLeftJoin( + "zitadel.org_domains", + database.And(columns...), + ) +} + + const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` + ` VALUES ($1, $2, $3, $4)` + ` RETURNING created_at, updated_at` @@ -77,8 +106,8 @@ func (o *org) Create(ctx context.Context, organization *domain.Organization) err // Update implements [domain.OrganizationRepository]. func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) { - if changes == nil { - return 0, errors.New("Update must contain a condition") // (otherwise ALL organizations will be updated) + if len(changes) == 0 { + return 0, database.NoChangesError } builder := database.StatementBuilder{} builder.WriteString(`UPDATE zitadel.organizations SET `) @@ -115,12 +144,12 @@ func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, inst // SetName implements [domain.organizationChanges]. func (o org) SetName(name string) database.Change { - return database.NewChange(o.NameColumn(), name) + return database.NewChange(o.NameColumn(false), name) } // SetState implements [domain.organizationChanges]. func (o org) SetState(state domain.OrgState) database.Change { - return database.NewChange(o.StateColumn(), state) + return database.NewChange(o.StateColumn(false), state) } // ------------------------------------------------------------- @@ -129,22 +158,22 @@ func (o org) SetState(state domain.OrgState) database.Change { // IDCondition implements [domain.organizationConditions]. func (o org) IDCondition(id string) domain.OrgIdentifierCondition { - return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id) + return database.NewTextCondition(o.IDColumn(true), database.TextOperationEqual, id) } // NameCondition implements [domain.organizationConditions]. func (o org) NameCondition(name string) domain.OrgIdentifierCondition { - return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name) + return database.NewTextCondition(o.NameColumn(true), database.TextOperationEqual, name) } // InstanceIDCondition implements [domain.organizationConditions]. func (o org) InstanceIDCondition(instanceID string) database.Condition { - return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID) + return database.NewTextCondition(o.InstanceIDColumn(true), database.TextOperationEqual, instanceID) } // StateCondition implements [domain.organizationConditions]. func (o org) StateCondition(state domain.OrgState) database.Condition { - return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state) + return database.NewTextCondition(o.StateColumn(true), database.TextOperationEqual, state) } // ------------------------------------------------------------- @@ -152,32 +181,50 @@ func (o org) StateCondition(state domain.OrgState) database.Condition { // ------------------------------------------------------------- // IDColumn implements [domain.organizationColumns]. -func (org) IDColumn() database.Column { +func (org) IDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("organizations.id") + } return database.NewColumn("id") } // NameColumn implements [domain.organizationColumns]. -func (org) NameColumn() database.Column { +func (org) NameColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("organizations.name") + } return database.NewColumn("name") } // InstanceIDColumn implements [domain.organizationColumns]. -func (org) InstanceIDColumn() database.Column { +func (org) InstanceIDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("organizations.instance_id") + } return database.NewColumn("instance_id") } // StateColumn implements [domain.organizationColumns]. -func (org) StateColumn() database.Column { +func (org) StateColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("organizations.state") + } return database.NewColumn("state") } // CreatedAtColumn implements [domain.organizationColumns]. -func (org) CreatedAtColumn() database.Column { +func (org) CreatedAtColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("organizations.created_at") + } return database.NewColumn("created_at") } // UpdatedAtColumn implements [domain.organizationColumns]. -func (org) UpdatedAtColumn() database.Column { +func (org) UpdatedAtColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("organizations.updated_at") + } return database.NewColumn("updated_at") } @@ -185,18 +232,28 @@ func (org) UpdatedAtColumn() database.Column { // scanners // ------------------------------------------------------------- +type rawOrganization struct { + *domain.Organization + RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"` +} + func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) { rows, err := querier.Query(ctx, builder.String(), builder.Args()...) if err != nil { return nil, err } - organization := &domain.Organization{} - if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil { + var org rawOrganization + if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil { return nil, err } + if len(org.RawDomains) > 0 { + if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil { + return nil, err + } + } - return organization, nil + return org.Organization, nil } func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) { @@ -205,10 +262,20 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d return nil, err } - organizations := []*domain.Organization{} - if err := rows.(database.CollectableRows).Collect(&organizations); err != nil { + var rawOrgs []*rawOrganization + if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil { return nil, err } + + organizations := make([]*domain.Organization, len(rawOrgs)) + for i, org := range rawOrgs { + if len(org.RawDomains) > 0 { + if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil { + return nil, err + } + } + organizations[i] = org.Organization + } return organizations, nil } @@ -216,8 +283,11 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d // sub repositories // ------------------------------------------------------------- -func (o *org) Domains() domain.OrganizationDomainRepository { - o.shouldJoinDomains = true +// Domains implements [domain.OrganizationRepository]. +func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository { + if !o.shouldLoadDomains { + o.shouldLoadDomains = shouldLoad + } if o.domainRepo != nil { return o.domainRepo diff --git a/backend/v3/storage/database/repository/org_domain.go b/backend/v3/storage/database/repository/org_domain.go index 74b6c403a2..f114785bbd 100644 --- a/backend/v3/storage/database/repository/org_domain.go +++ b/backend/v3/storage/database/repository/org_domain.go @@ -61,7 +61,7 @@ func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomai `VALUES ($1, $2, $3, $4, $5, $6)` + ` RETURNING created_at, updated_at`) - builder.AppendArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.VerificationType) + builder.AppendArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType) return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) } @@ -95,17 +95,17 @@ func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (i // SetPrimary implements [domain.OrganizationDomainRepository]. func (o orgDomain) SetPrimary() database.Change { - return database.NewChange(o.IsPrimaryColumn(), true) + return database.NewChange(o.IsPrimaryColumn(false), true) } -// SetVerificationType implements [domain.OrganizationDomainRepository]. -func (o orgDomain) SetVerificationType(verificationType domain.DomainVerificationType) database.Change { - return database.NewChange(o.VerificationTypeColumn(), verificationType) +// SetValidationType implements [domain.OrganizationDomainRepository]. +func (o orgDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change { + return database.NewChange(o.ValidationTypeColumn(false), verificationType) } // SetVerified implements [domain.OrganizationDomainRepository]. func (o orgDomain) SetVerified() database.Change { - return database.NewChange(o.IsVerifiedColumn(), true) + return database.NewChange(o.IsVerifiedColumn(false), true) } // ------------------------------------------------------------- @@ -114,28 +114,28 @@ func (o orgDomain) SetVerified() database.Change { // DomainCondition implements [domain.OrganizationDomainRepository]. func (o orgDomain) DomainCondition(op database.TextOperation, domain string) database.Condition { - return database.NewTextCondition(o.DomainColumn(), op, domain) + return database.NewTextCondition(o.DomainColumn(true), op, domain) } // InstanceIDCondition implements [domain.OrganizationDomainRepository]. // Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDCondition of orgDomain.org. func (o orgDomain) InstanceIDCondition(instanceID string) database.Condition { - return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID) + return database.NewTextCondition(o.InstanceIDColumn(true), database.TextOperationEqual, instanceID) } // IsPrimaryCondition implements [domain.OrganizationDomainRepository]. func (o orgDomain) IsPrimaryCondition(isPrimary bool) database.Condition { - return database.NewBooleanCondition(o.IsPrimaryColumn(), isPrimary) + return database.NewBooleanCondition(o.IsPrimaryColumn(true), isPrimary) } // IsVerifiedCondition implements [domain.OrganizationDomainRepository]. func (o orgDomain) IsVerifiedCondition(isVerified bool) database.Condition { - return database.NewBooleanCondition(o.IsVerifiedColumn(), isVerified) + return database.NewBooleanCondition(o.IsVerifiedColumn(true), isVerified) } // OrgIDCondition implements [domain.OrganizationDomainRepository]. func (o orgDomain) OrgIDCondition(orgID string) database.Condition { - return database.NewTextCondition(o.OrgIDColumn(), database.TextOperationEqual, orgID) + return database.NewTextCondition(o.OrgIDColumn(true), database.TextOperationEqual, orgID) } // ------------------------------------------------------------- @@ -144,45 +144,69 @@ func (o orgDomain) OrgIDCondition(orgID string) database.Condition { // CreatedAtColumn implements [domain.OrganizationDomainRepository]. // Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org. -func (orgDomain) CreatedAtColumn() database.Column { +func (orgDomain) CreatedAtColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("org_domains.created_at") + } return database.NewColumn("created_at") } // DomainColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) DomainColumn() database.Column { +func (orgDomain) DomainColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("org_domains.domain") + } return database.NewColumn("domain") } // InstanceIDColumn implements [domain.OrganizationDomainRepository]. // Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org. -func (orgDomain) InstanceIDColumn() database.Column { +func (orgDomain) InstanceIDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("org_domains.instance_id") + } return database.NewColumn("instance_id") } // IsPrimaryColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) IsPrimaryColumn() database.Column { +func (orgDomain) IsPrimaryColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("org_domains.is_primary") + } return database.NewColumn("is_primary") } // IsVerifiedColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) IsVerifiedColumn() database.Column { +func (orgDomain) IsVerifiedColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("org_domains.is_verified") + } return database.NewColumn("is_verified") } // OrgIDColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) OrgIDColumn() database.Column { +func (orgDomain) OrgIDColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("org_domains.org_id") + } return database.NewColumn("org_id") } // UpdatedAtColumn implements [domain.OrganizationDomainRepository]. // Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org. -func (orgDomain) UpdatedAtColumn() database.Column { +func (orgDomain) UpdatedAtColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("org_domains.updated_at") + } return database.NewColumn("updated_at") } -// VerificationTypeColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) VerificationTypeColumn() database.Column { - return database.NewColumn("verification_type") +// ValidationTypeColumn implements [domain.OrganizationDomainRepository]. +func (orgDomain) ValidationTypeColumn(qualified bool) database.Column { + if qualified { + return database.NewColumn("org_domains.validation_type") + } + return database.NewColumn("validation_type") } // ------------------------------------------------------------- diff --git a/backend/v3/storage/database/repository/org_test.go b/backend/v3/storage/database/repository/org_test.go index 74d72c9e59..26b7add8d9 100644 --- a/backend/v3/storage/database/repository/org_test.go +++ b/backend/v3/storage/database/repository/org_test.go @@ -235,8 +235,12 @@ func TestCreateOrganization(t *testing.T) { // check organization values organization, err = organizationRepo.Get(ctx, - organizationRepo.IDCondition(organization.ID), - organization.InstanceID, + database.WithCondition( + database.And( + organizationRepo.IDCondition(organization.ID), + organizationRepo.InstanceIDCondition(organization.InstanceID), + ), + ), ) require.NoError(t, err) @@ -389,8 +393,12 @@ func TestUpdateOrganization(t *testing.T) { // check organization values organization, err := organizationRepo.Get(ctx, - organizationRepo.IDCondition(createdOrg.ID), - createdOrg.InstanceID, + database.WithCondition( + database.And( + organizationRepo.IDCondition(createdOrg.ID), + organizationRepo.InstanceIDCondition(createdOrg.InstanceID), + ), + ), ) require.NoError(t, err) @@ -511,13 +519,18 @@ func TestGetOrganization(t *testing.T) { // get org values returnedOrg, err := orgRepo.Get(ctx, - tt.orgIdentifierCondition, - org.InstanceID, + database.WithCondition( + database.And( + tt.orgIdentifierCondition, + orgRepo.InstanceIDCondition(org.InstanceID), + ), + ), ) if tt.err != nil { require.ErrorIs(t, tt.err, err) return } + require.NoError(t, err) if org.Name == "non existent org" { assert.Nil(t, returnedOrg) @@ -764,9 +777,15 @@ func TestListOrganization(t *testing.T) { organizations := tt.testFunc(ctx, t) + var condition database.Condition + if len(tt.conditionClauses) > 0 { + condition = database.And(tt.conditionClauses...) + } + // check organization values returnedOrgs, err := organizationRepo.List(ctx, - tt.conditionClauses..., + database.WithCondition(condition), + database.WithOrderBy(organizationRepo.CreatedAtColumn(true)), ) require.NoError(t, err) if tt.noOrganizationReturned { @@ -929,8 +948,12 @@ func TestDeleteOrganization(t *testing.T) { // check organization was deleted organization, err := organizationRepo.Get(ctx, - tt.orgIdentifierCondition, - instanceId, + database.WithCondition( + database.And( + tt.orgIdentifierCondition, + organizationRepo.InstanceIDCondition(instanceId), + ), + ), ) require.ErrorIs(t, err, new(database.NoRowFoundError)) assert.Nil(t, organization)