diff --git a/backend/v3/storage/database/repository/instance_domain.go b/backend/v3/storage/database/repository/instance_domain.go index 5ce4c12e7b..2f7ea6170b 100644 --- a/backend/v3/storage/database/repository/instance_domain.go +++ b/backend/v3/storage/database/repository/instance_domain.go @@ -72,16 +72,6 @@ func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDoma return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) } -// Remove implements [domain.InstanceDomainRepository]. -func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) { - var builder database.StatementBuilder - - builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `) - condition.Write(&builder) - - return i.client.Exec(ctx, builder.String(), builder.Args()...) -} - // Update implements [domain.InstanceDomainRepository]. // Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance. func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { @@ -98,6 +88,16 @@ func (i *instanceDomain) Update(ctx context.Context, condition database.Conditio return i.client.Exec(ctx, builder.String(), builder.Args()...) } +// Remove implements [domain.InstanceDomainRepository]. +func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) { + var builder database.StatementBuilder + + builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `) + condition.Write(&builder) + + return i.client.Exec(ctx, builder.String(), builder.Args()...) +} + // ------------------------------------------------------------- // changes // ------------------------------------------------------------- @@ -213,12 +213,12 @@ func scanInstanceDomains(ctx context.Context, querier database.Querier, builder return nil, err } - var instanceDomains []*domain.InstanceDomain - if err := rows.(database.CollectableRows).Collect(&instanceDomains); err != nil { + var domains []*domain.InstanceDomain + if err := rows.(database.CollectableRows).Collect(&domains); err != nil { return nil, err } - return instanceDomains, nil + return domains, nil } func scanInstanceDomain(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.InstanceDomain, error) { @@ -226,10 +226,10 @@ func scanInstanceDomain(ctx context.Context, querier database.Querier, builder * if err != nil { return nil, err } - instanceDomain := new(domain.InstanceDomain) - if err := rows.(database.CollectableRows).CollectExactlyOneRow(instanceDomain); err != nil { + domain := new(domain.InstanceDomain) + if err := rows.(database.CollectableRows).CollectExactlyOneRow(domain); err != nil { return nil, err } - return instanceDomain, nil + return domain, nil } diff --git a/backend/v3/storage/database/repository/instance_domain_test.go b/backend/v3/storage/database/repository/instance_domain_test.go index 693d7f9a0b..cc18ada380 100644 --- a/backend/v3/storage/database/repository/instance_domain_test.go +++ b/backend/v3/storage/database/repository/instance_domain_test.go @@ -649,6 +649,16 @@ func TestInstanceDomainConditions(t *testing.T) { condition: domainRepo.IsPrimaryCondition(false), expected: "instance_domains.is_primary = $1", }, + { + name: "is type custom", + condition: domainRepo.TypeCondition(domain.DomainTypeCustom), + expected: "instance_domains.type = $1", + }, + { + name: "is type trusted", + condition: domainRepo.TypeCondition(domain.DomainTypeTrusted), + expected: "instance_domains.type = $1", + }, } for _, test := range tests { @@ -674,6 +684,11 @@ func TestInstanceDomainChanges(t *testing.T) { change: domainRepo.SetPrimary(), expected: "is_primary = $1", }, + { + name: "set type", + change: domainRepo.SetType(domain.DomainTypeCustom), + expected: "type = $1", + }, } for _, test := range tests { diff --git a/backend/v3/storage/database/repository/org_domain.go b/backend/v3/storage/database/repository/org_domain.go index f07e0fcf83..7a5db25371 100644 --- a/backend/v3/storage/database/repository/org_domain.go +++ b/backend/v3/storage/database/repository/org_domain.go @@ -54,10 +54,19 @@ func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]* // Add implements [domain.OrganizationDomainRepository]. func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error { - var builder database.StatementBuilder + var ( + builder database.StatementBuilder + createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction + ) + if !domain.CreatedAt.IsZero() { + createdAt = domain.CreatedAt + } + if !domain.UpdatedAt.IsZero() { + updatedAt = domain.UpdatedAt + } - builder.WriteString(`INSERT INTO zitadel.org_domains (instance_id, org_id, domain, is_verified, is_primary, validation_type) VALUES (`) - builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType) + builder.WriteString(`INSERT INTO zitadel.org_domains (instance_id, org_id, domain, is_verified, is_primary, validation_type, created_at, updated_at) VALUES (`) + builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt) builder.WriteString(`) RETURNING created_at, updated_at`) return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) @@ -224,11 +233,11 @@ func scanOrganizationDomain(ctx context.Context, client database.Querier, builde return nil, err } - organizationDomain := &domain.OrganizationDomain{} - if err := rows.(database.CollectableRows).CollectExactlyOneRow(organizationDomain); err != nil { + domain := &domain.OrganizationDomain{} + if err := rows.(database.CollectableRows).CollectExactlyOneRow(domain); err != nil { return nil, err } - return organizationDomain, nil + return domain, nil } func scanOrganizationDomains(ctx context.Context, client database.Querier, builder *database.StatementBuilder) ([]*domain.OrganizationDomain, error) { @@ -237,9 +246,9 @@ func scanOrganizationDomains(ctx context.Context, client database.Querier, build return nil, err } - var organizationDomains []*domain.OrganizationDomain - if err := rows.(database.CollectableRows).Collect(&organizationDomains); err != nil { + var domains []*domain.OrganizationDomain + if err := rows.(database.CollectableRows).Collect(&domains); err != nil { return nil, err } - return organizationDomains, nil + return domains, nil }