From 655f9be015cd17511838f81fd8e92adb7f42bc75 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:11:28 +0000 Subject: [PATCH] Add ID column to domains table and update repository implementation Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com> --- DOMAINS_IMPLEMENTATION.md | 126 +++++++ cmd/setup/61/01_create_domains_table.sql | 3 +- internal/query/projection/domains.go | 4 +- internal/query/projection/domains_test.go | 356 ++++++++++++++++++ internal/v2/readmodel/domain_repository.go | 60 +-- .../v2/readmodel/domain_repository_test.go | 25 +- 6 files changed, 537 insertions(+), 37 deletions(-) create mode 100644 DOMAINS_IMPLEMENTATION.md create mode 100644 internal/query/projection/domains_test.go diff --git a/DOMAINS_IMPLEMENTATION.md b/DOMAINS_IMPLEMENTATION.md new file mode 100644 index 0000000000..c8f0076aad --- /dev/null +++ b/DOMAINS_IMPLEMENTATION.md @@ -0,0 +1,126 @@ +# Unified Domains Table Implementation + +This implementation provides a unified domains table (`zitadel.domains`) that consolidates both organization and instance domains into a single table structure. + +## Architecture + +The implementation follows Zitadel's established patterns: + +### Database Layer +- **Migration 61**: Creates the unified domains table with proper constraints +- **Table Structure**: Uses nullable `org_id` to distinguish between instance domains (NULL) and organization domains (NOT NULL) +- **Soft Deletes**: Implements `deleted_at` for data preservation +- **Unique Constraints**: Ensures domain uniqueness within instance/organization scope + +### Domain Layer +- **Interfaces**: Clean separation between instance and organization domain operations +- **Models**: Unified domain model with optional organization ID + +### Repository Layer +- **Implementation**: Single repository handling both instance and organization domains +- **Transactions**: Atomic operations for primary domain changes +- **Query Building**: Type-safe SQL generation using squirrel + +### Projection Layer +- **Event Handling**: Processes both org and instance domain events +- **Data Synchronization**: Maintains consistency with event sourcing + +## Event Mapping + +### Organization Domain Events +- `org.domain.added` → Creates domain record with org_id +- `org.domain.verification.added` → Updates validation_type +- `org.domain.verified` → Sets is_verified = true +- `org.domain.primary.set` → Manages primary domain flags +- `org.domain.removed` → Soft deletes domain +- `org.removed` → Soft deletes all org domains + +### Instance Domain Events +- `instance.domain.added` → Creates domain record with org_id = NULL, is_verified = true +- `instance.domain.primary.set` → Manages primary domain flags +- `instance.domain.removed` → Soft deletes domain +- `instance.removed` → Soft deletes all instance domains + +## Usage Examples + +### Instance Domain Operations +```go +// Add instance domain (always verified) +domain, err := repo.AddInstanceDomain(ctx, "instance-123", "api.example.com") + +// Set primary instance domain +err := repo.SetInstanceDomainPrimary(ctx, "instance-123", "api.example.com") + +// Remove instance domain +err := repo.RemoveInstanceDomain(ctx, "instance-123", "api.example.com") + +// List instance domains +criteria := v2domain.DomainSearchCriteria{ + InstanceID: &instanceID, +} +pagination := v2domain.DomainPagination{ + Limit: 10, + SortBy: v2domain.DomainSortFieldDomain, + Order: database.SortOrderAsc, +} +list, err := repo.List(ctx, criteria, pagination) +``` + +### Organization Domain Operations +```go +// Add organization domain +domain, err := repo.AddOrganizationDomain(ctx, "instance-123", "org-456", "company.com", domain.OrgDomainValidationTypeHTTP) + +// Verify organization domain +err := repo.SetOrganizationDomainVerified(ctx, "instance-123", "org-456", "company.com") + +// Set primary organization domain +err := repo.SetOrganizationDomainPrimary(ctx, "instance-123", "org-456", "company.com") + +// Find specific domain +criteria := v2domain.DomainSearchCriteria{ + Domain: &domainName, + InstanceID: &instanceID, +} +domain, err := repo.Get(ctx, criteria) +``` + +## Testing + +The implementation includes comprehensive tests: + +### Repository Tests +- CRUD operations for both instance and organization domains +- Transaction behavior verification +- Error handling and edge cases +- SQL query parameter validation + +### Projection Tests +- Event reduction logic for all supported events +- Column and condition verification +- Multi-statement handling for primary domain changes +- Soft delete behavior + +## Migration Strategy + +This table is designed to work alongside existing tables initially: + +1. **Phase 1** (This PR): Create unified table and maintain via projections +2. **Phase 2** (Future): Migrate query logic to use unified table +3. **Phase 3** (Future): Deprecate separate org_domains2 and instance_domains tables + +## Performance Considerations + +- **Indexing**: The unique constraint provides efficient domain lookups +- **Queries**: Nullable org_id allows efficient filtering between domain types +- **Pagination**: Supports sorting by created_at, updated_at, and domain name +- **Soft Deletes**: WHERE deleted_at IS NULL conditions optimize active domain queries + +## Validation + +The table enforces: +- Domain length between 1-255 characters +- Non-negative validation_type values +- Foreign key integrity to instances and organizations +- Unique domain constraints per instance/organization scope +- Automatic updated_at timestamp management \ No newline at end of file diff --git a/cmd/setup/61/01_create_domains_table.sql b/cmd/setup/61/01_create_domains_table.sql index ffbd3c25f4..3c54542715 100644 --- a/cmd/setup/61/01_create_domains_table.sql +++ b/cmd/setup/61/01_create_domains_table.sql @@ -1,5 +1,6 @@ CREATE TABLE IF NOT EXISTS zitadel.domains( - instance_id TEXT NOT NULL + id TEXT NOT NULL PRIMARY KEY DEFAULT generate_ulid() + , instance_id TEXT NOT NULL , org_id TEXT , domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255) , is_verified BOOLEAN NOT NULL DEFAULT FALSE diff --git a/internal/query/projection/domains.go b/internal/query/projection/domains.go index 5a014ecf0f..437fcc74f2 100644 --- a/internal/query/projection/domains.go +++ b/internal/query/projection/domains.go @@ -15,6 +15,7 @@ import ( const ( DomainsTable = "zitadel.domains" + DomainsIDCol = "id" DomainsInstanceIDCol = "instance_id" DomainsOrgIDCol = "org_id" DomainsDomainCol = "domain" @@ -40,6 +41,7 @@ func (*domainsProjection) Init() *old_handler.Check { // The table is created by migration, so we just need to check it exists return handler.NewTableCheck( handler.NewTable([]*handler.InitColumn{ + handler.NewColumn(DomainsIDCol, handler.ColumnTypeText), handler.NewColumn(DomainsInstanceIDCol, handler.ColumnTypeText), handler.NewColumn(DomainsOrgIDCol, handler.ColumnTypeText), handler.NewColumn(DomainsDomainCol, handler.ColumnTypeText), @@ -50,7 +52,7 @@ func (*domainsProjection) Init() *old_handler.Check { handler.NewColumn(DomainsUpdatedAtCol, handler.ColumnTypeTimestamp), handler.NewColumn(DomainsDeletedAtCol, handler.ColumnTypeTimestamp), }, - // Note: We don't define primary key here since the table is created by migration + handler.NewPrimaryKey(DomainsIDCol), ), ) } diff --git a/internal/query/projection/domains_test.go b/internal/query/projection/domains_test.go new file mode 100644 index 0000000000..4442080f20 --- /dev/null +++ b/internal/query/projection/domains_test.go @@ -0,0 +1,356 @@ +package projection + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/org" +) + +func TestDomainsProjection_reduceOrgDomainAdded(t *testing.T) { + projection := &domainsProjection{} + + event := &org.DomainAddedEvent{ + BaseEvent: &eventstore.BaseEvent{ + Agg: &eventstore.Aggregate{ + ID: "org-id", + InstanceID: "instance-id", + Type: org.AggregateType, + ResourceOwner: "instance-id", + }, + CreationDate: time.Now(), + }, + Domain: "test.example.com", + } + + stmt, err := projection.reduceOrgDomainAdded(event) + + require.NoError(t, err) + assert.NotNil(t, stmt) + + createStmt, ok := stmt.(*handler.CreateStatement) + require.True(t, ok) + + // Verify the columns being set + expectedColumns := map[string]interface{}{ + DomainsInstanceIDCol: "instance-id", + DomainsOrgIDCol: "org-id", + DomainsDomainCol: "test.example.com", + DomainsIsVerifiedCol: false, + DomainsIsPrimaryCol: false, + DomainsValidationTypeCol: domain.OrgDomainValidationTypeUnspecified, + DomainsCreatedAtCol: event.CreationDate(), + DomainsUpdatedAtCol: event.CreationDate(), + } + + assert.Len(t, createStmt.Cols, len(expectedColumns)) + + for i, col := range createStmt.Cols { + switch col.Name { + case DomainsInstanceIDCol: + assert.Equal(t, "instance-id", col.Value) + case DomainsOrgIDCol: + assert.Equal(t, "org-id", col.Value) + case DomainsDomainCol: + assert.Equal(t, "test.example.com", col.Value) + case DomainsIsVerifiedCol: + assert.Equal(t, false, col.Value) + case DomainsIsPrimaryCol: + assert.Equal(t, false, col.Value) + case DomainsValidationTypeCol: + assert.Equal(t, domain.OrgDomainValidationTypeUnspecified, col.Value) + case DomainsCreatedAtCol, DomainsUpdatedAtCol: + assert.Equal(t, event.CreationDate(), col.Value) + default: + t.Errorf("Unexpected column: %s at index %d", col.Name, i) + } + } +} + +func TestDomainsProjection_reduceOrgDomainVerified(t *testing.T) { + projection := &domainsProjection{} + + event := &org.DomainVerifiedEvent{ + BaseEvent: &eventstore.BaseEvent{ + Agg: &eventstore.Aggregate{ + ID: "org-id", + InstanceID: "instance-id", + Type: org.AggregateType, + ResourceOwner: "instance-id", + }, + CreationDate: time.Now(), + }, + Domain: "test.example.com", + } + + stmt, err := projection.reduceOrgDomainVerified(event) + + require.NoError(t, err) + assert.NotNil(t, stmt) + + updateStmt, ok := stmt.(*handler.UpdateStatement) + require.True(t, ok) + + // Verify update columns + assert.Len(t, updateStmt.Cols, 2) + assert.Equal(t, DomainsUpdatedAtCol, updateStmt.Cols[0].Name) + assert.Equal(t, event.CreationDate(), updateStmt.Cols[0].Value) + assert.Equal(t, DomainsIsVerifiedCol, updateStmt.Cols[1].Name) + assert.Equal(t, true, updateStmt.Cols[1].Value) + + // Verify conditions + assert.Len(t, updateStmt.Conditions, 4) + + conditionMap := make(map[string]interface{}) + for _, cond := range updateStmt.Conditions { + conditionMap[cond.Name] = cond.Value + } + + assert.Equal(t, "instance-id", conditionMap[DomainsInstanceIDCol]) + assert.Equal(t, "org-id", conditionMap[DomainsOrgIDCol]) + assert.Equal(t, "test.example.com", conditionMap[DomainsDomainCol]) + assert.Nil(t, conditionMap[DomainsDeletedAtCol]) +} + +func TestDomainsProjection_reduceOrgPrimaryDomainSet(t *testing.T) { + projection := &domainsProjection{} + + event := &org.DomainPrimarySetEvent{ + BaseEvent: &eventstore.BaseEvent{ + Agg: &eventstore.Aggregate{ + ID: "org-id", + InstanceID: "instance-id", + Type: org.AggregateType, + ResourceOwner: "instance-id", + }, + CreationDate: time.Now(), + }, + Domain: "test.example.com", + } + + stmt, err := projection.reduceOrgPrimaryDomainSet(event) + + require.NoError(t, err) + assert.NotNil(t, stmt) + + multiStmt, ok := stmt.(*handler.MultiStatement) + require.True(t, ok) + + // Should have 2 update statements: unset old primary, set new primary + assert.Len(t, multiStmt.Statements, 2) + + // First statement: unset existing primary + unsetStmt, ok := multiStmt.Statements[0].(*handler.UpdateStatement) + require.True(t, ok) + + assert.Len(t, unsetStmt.Cols, 2) + assert.Equal(t, DomainsUpdatedAtCol, unsetStmt.Cols[0].Name) + assert.Equal(t, DomainsIsPrimaryCol, unsetStmt.Cols[1].Name) + assert.Equal(t, false, unsetStmt.Cols[1].Value) + + // Second statement: set new primary + setStmt, ok := multiStmt.Statements[1].(*handler.UpdateStatement) + require.True(t, ok) + + assert.Len(t, setStmt.Cols, 2) + assert.Equal(t, DomainsUpdatedAtCol, setStmt.Cols[0].Name) + assert.Equal(t, DomainsIsPrimaryCol, setStmt.Cols[1].Name) + assert.Equal(t, true, setStmt.Cols[1].Value) +} + +func TestDomainsProjection_reduceInstanceDomainAdded(t *testing.T) { + projection := &domainsProjection{} + + event := &instance.DomainAddedEvent{ + BaseEvent: &eventstore.BaseEvent{ + Agg: &eventstore.Aggregate{ + ID: "instance-id", + InstanceID: "instance-id", + Type: instance.AggregateType, + ResourceOwner: "instance-id", + }, + CreationDate: time.Now(), + }, + Domain: "instance.example.com", + Generated: false, + } + + stmt, err := projection.reduceInstanceDomainAdded(event) + + require.NoError(t, err) + assert.NotNil(t, stmt) + + createStmt, ok := stmt.(*handler.CreateStatement) + require.True(t, ok) + + // Verify the columns being set for instance domain + expectedColumns := map[string]interface{}{ + DomainsInstanceIDCol: "instance-id", + DomainsOrgIDCol: nil, // Instance domains have no org_id + DomainsDomainCol: "instance.example.com", + DomainsIsVerifiedCol: true, // Instance domains are always verified + DomainsIsPrimaryCol: false, + DomainsValidationTypeCol: nil, // Instance domains have no validation type + DomainsCreatedAtCol: event.CreationDate(), + DomainsUpdatedAtCol: event.CreationDate(), + } + + assert.Len(t, createStmt.Cols, len(expectedColumns)) + + for _, col := range createStmt.Cols { + switch col.Name { + case DomainsInstanceIDCol: + assert.Equal(t, "instance-id", col.Value) + case DomainsOrgIDCol: + assert.Nil(t, col.Value) + case DomainsDomainCol: + assert.Equal(t, "instance.example.com", col.Value) + case DomainsIsVerifiedCol: + assert.Equal(t, true, col.Value) + case DomainsIsPrimaryCol: + assert.Equal(t, false, col.Value) + case DomainsValidationTypeCol: + assert.Nil(t, col.Value) + case DomainsCreatedAtCol, DomainsUpdatedAtCol: + assert.Equal(t, event.CreationDate(), col.Value) + default: + t.Errorf("Unexpected column: %s", col.Name) + } + } +} + +func TestDomainsProjection_reduceOrgDomainRemoved(t *testing.T) { + projection := &domainsProjection{} + + event := &org.DomainRemovedEvent{ + BaseEvent: &eventstore.BaseEvent{ + Agg: &eventstore.Aggregate{ + ID: "org-id", + InstanceID: "instance-id", + Type: org.AggregateType, + ResourceOwner: "instance-id", + }, + CreationDate: time.Now(), + }, + Domain: "test.example.com", + } + + stmt, err := projection.reduceOrgDomainRemoved(event) + + require.NoError(t, err) + assert.NotNil(t, stmt) + + updateStmt, ok := stmt.(*handler.UpdateStatement) + require.True(t, ok) + + // Should update updated_at and deleted_at + assert.Len(t, updateStmt.Cols, 2) + assert.Equal(t, DomainsUpdatedAtCol, updateStmt.Cols[0].Name) + assert.Equal(t, event.CreationDate(), updateStmt.Cols[0].Value) + assert.Equal(t, DomainsDeletedAtCol, updateStmt.Cols[1].Name) + assert.Equal(t, event.CreationDate(), updateStmt.Cols[1].Value) + + // Verify conditions include instance, org, and domain + assert.Len(t, updateStmt.Conditions, 3) + + conditionMap := make(map[string]interface{}) + for _, cond := range updateStmt.Conditions { + conditionMap[cond.Name] = cond.Value + } + + assert.Equal(t, "instance-id", conditionMap[DomainsInstanceIDCol]) + assert.Equal(t, "org-id", conditionMap[DomainsOrgIDCol]) + assert.Equal(t, "test.example.com", conditionMap[DomainsDomainCol]) +} + +func TestDomainsProjection_reduceOrgRemoved(t *testing.T) { + projection := &domainsProjection{} + + event := &org.OrgRemovedEvent{ + BaseEvent: &eventstore.BaseEvent{ + Agg: &eventstore.Aggregate{ + ID: "org-id", + InstanceID: "instance-id", + Type: org.AggregateType, + ResourceOwner: "instance-id", + }, + CreationDate: time.Now(), + }, + } + + stmt, err := projection.reduceOrgRemoved(event) + + require.NoError(t, err) + assert.NotNil(t, stmt) + + updateStmt, ok := stmt.(*handler.UpdateStatement) + require.True(t, ok) + + // Should update updated_at and deleted_at + assert.Len(t, updateStmt.Cols, 2) + assert.Equal(t, DomainsUpdatedAtCol, updateStmt.Cols[0].Name) + assert.Equal(t, event.CreationDate(), updateStmt.Cols[0].Value) + assert.Equal(t, DomainsDeletedAtCol, updateStmt.Cols[1].Name) + assert.Equal(t, event.CreationDate(), updateStmt.Cols[1].Value) + + // Should soft delete all domains for the org + assert.Len(t, updateStmt.Conditions, 3) + + conditionMap := make(map[string]interface{}) + for _, cond := range updateStmt.Conditions { + conditionMap[cond.Name] = cond.Value + } + + assert.Equal(t, "instance-id", conditionMap[DomainsInstanceIDCol]) + assert.Equal(t, "org-id", conditionMap[DomainsOrgIDCol]) + assert.Nil(t, conditionMap[DomainsDeletedAtCol]) // Only affect non-deleted domains +} + +func TestDomainsProjection_reduceInstanceRemoved(t *testing.T) { + projection := &domainsProjection{} + + event := &instance.InstanceRemovedEvent{ + BaseEvent: &eventstore.BaseEvent{ + Agg: &eventstore.Aggregate{ + ID: "instance-id", + InstanceID: "instance-id", + Type: instance.AggregateType, + ResourceOwner: "instance-id", + }, + CreationDate: time.Now(), + }, + } + + stmt, err := projection.reduceInstanceRemoved(event) + + require.NoError(t, err) + assert.NotNil(t, stmt) + + updateStmt, ok := stmt.(*handler.UpdateStatement) + require.True(t, ok) + + // Should update updated_at and deleted_at + assert.Len(t, updateStmt.Cols, 2) + assert.Equal(t, DomainsUpdatedAtCol, updateStmt.Cols[0].Name) + assert.Equal(t, event.CreationDate(), updateStmt.Cols[0].Value) + assert.Equal(t, DomainsDeletedAtCol, updateStmt.Cols[1].Name) + assert.Equal(t, event.CreationDate(), updateStmt.Cols[1].Value) + + // Should soft delete all domains for the instance + assert.Len(t, updateStmt.Conditions, 2) + + conditionMap := make(map[string]interface{}) + for _, cond := range updateStmt.Conditions { + conditionMap[cond.Name] = cond.Value + } + + assert.Equal(t, "instance-id", conditionMap[DomainsInstanceIDCol]) + assert.Nil(t, conditionMap[DomainsDeletedAtCol]) // Only affect non-deleted domains +} \ No newline at end of file diff --git a/internal/v2/readmodel/domain_repository.go b/internal/v2/readmodel/domain_repository.go index 6acdef87a8..be38635752 100644 --- a/internal/v2/readmodel/domain_repository.go +++ b/internal/v2/readmodel/domain_repository.go @@ -30,16 +30,17 @@ type TX interface { } const ( - domainsTable = "zitadel.domains" - domainsInstanceIDCol = "instance_id" - domainsOrgIDCol = "org_id" - domainsDomainCol = "domain" - domainsIsVerifiedCol = "is_verified" - domainsIsPrimaryCol = "is_primary" + domainsTable = "zitadel.domains" + domainsIDCol = "id" + domainsInstanceIDCol = "instance_id" + domainsOrgIDCol = "org_id" + domainsDomainCol = "domain" + domainsIsVerifiedCol = "is_verified" + domainsIsPrimaryCol = "is_primary" domainsValidationTypeCol = "validation_type" - domainsCreatedAtCol = "created_at" - domainsUpdatedAtCol = "updated_at" - domainsDeletedAtCol = "deleted_at" + domainsCreatedAtCol = "created_at" + domainsUpdatedAtCol = "updated_at" + domainsDeletedAtCol = "deleted_at" ) // DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository @@ -67,6 +68,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do domainsUpdatedAtCol, ). Values(instanceID, domain, true, false, now, now). + Suffix("RETURNING " + domainsIDCol). PlaceholderFormat(squirrel.Dollar) stmt, args, err := query.ToSql() @@ -74,12 +76,14 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal") } - _, err = r.client.ExecContext(ctx, stmt, args...) + var id string + err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id) if err != nil { return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal") } return &v2domain.Domain{ + ID: id, InstanceID: instanceID, OrganizationID: nil, Domain: domain, @@ -93,7 +97,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do // SetPrimary sets the primary domain for an instance func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error { - return r.withTransaction(ctx, func(tx database.Tx) error { + return r.withTransaction(ctx, func(tx TX) error { // First, unset any existing primary domain for this instance unsetQuery := squirrel.Update(domainsTable). Set(domainsIsPrimaryCol, false). @@ -201,6 +205,7 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID domainsUpdatedAtCol, ). Values(instanceID, organizationID, domain, false, false, int(validationType), now, now). + Suffix("RETURNING " + domainsIDCol). PlaceholderFormat(squirrel.Dollar) stmt, args, err := query.ToSql() @@ -208,12 +213,14 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal") } - _, err = r.client.ExecContext(ctx, stmt, args...) + var id string + err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id) if err != nil { return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal") } return &v2domain.Domain{ + ID: id, InstanceID: instanceID, OrganizationID: &organizationID, Domain: domain, @@ -262,7 +269,7 @@ func (r *DomainRepository) SetOrganizationDomainVerified(ctx context.Context, in // SetPrimary sets the primary domain for an organization func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error { - return r.withTransaction(ctx, func(tx database.Tx) error { + return r.withTransaction(ctx, func(tx TX) error { // First, unset any existing primary domain for this organization unsetQuery := squirrel.Update(domainsTable). Set(domainsIsPrimaryCol, false). @@ -443,6 +450,7 @@ func (r *DomainRepository) List(ctx context.Context, criteria v2domain.DomainSea func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder { query := squirrel.Select( + domainsIDCol, domainsInstanceIDCol, domainsOrgIDCol, domainsDomainCol, @@ -474,8 +482,7 @@ func (r *DomainRepository) applySearchCriteria(query squirrel.SelectBuilder, cri query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")) if criteria.ID != nil { - // Note: Our table doesn't have an ID column. This might need to be reconsidered - // For now, we'll ignore this criterion since the spec doesn't define where ID comes from + query = query.Where(squirrel.Eq{domainsIDCol: *criteria.ID}) } if criteria.Domain != nil { @@ -536,20 +543,21 @@ func (r *DomainRepository) applyPagination(query squirrel.SelectBuilder, paginat } func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) { - var domain v2domain.Domain + var domainRecord v2domain.Domain var orgID sql.NullString var validationType sql.NullInt32 var deletedAt sql.NullTime err := rows.Scan( - &domain.InstanceID, + &domainRecord.ID, + &domainRecord.InstanceID, &orgID, - &domain.Domain, - &domain.IsVerified, - &domain.IsPrimary, + &domainRecord.Domain, + &domainRecord.IsVerified, + &domainRecord.IsPrimary, &validationType, - &domain.CreatedAt, - &domain.UpdatedAt, + &domainRecord.CreatedAt, + &domainRecord.UpdatedAt, &deletedAt, ) if err != nil { @@ -557,19 +565,19 @@ func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) } if orgID.Valid { - domain.OrganizationID = &orgID.String + domainRecord.OrganizationID = &orgID.String } if validationType.Valid { validationTypeValue := domain.OrgDomainValidationType(validationType.Int32) - domain.ValidationType = &validationTypeValue + domainRecord.ValidationType = &validationTypeValue } if deletedAt.Valid { - domain.DeletedAt = &deletedAt.Time + domainRecord.DeletedAt = &deletedAt.Time } - return &domain, nil + return &domainRecord, nil } func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error { diff --git a/internal/v2/readmodel/domain_repository_test.go b/internal/v2/readmodel/domain_repository_test.go index e1c2194a43..92c24ab348 100644 --- a/internal/v2/readmodel/domain_repository_test.go +++ b/internal/v2/readmodel/domain_repository_test.go @@ -25,14 +25,16 @@ func TestDomainRepository_AddInstanceDomain(t *testing.T) { instanceID := "test-instance-id" domainName := "test.example.com" + expectedID := "domain-id-123" - mock.ExpectExec(`INSERT INTO zitadel\.domains`). + mock.ExpectQuery(`INSERT INTO zitadel\.domains`). WithArgs(instanceID, domainName, true, false, sqlmock.AnyArg(), sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(expectedID)) result, err := repo.AddInstanceDomain(context.Background(), instanceID, domainName) require.NoError(t, err) + assert.Equal(t, expectedID, result.ID) assert.Equal(t, instanceID, result.InstanceID) assert.Nil(t, result.OrganizationID) assert.Equal(t, domainName, result.Domain) @@ -54,14 +56,16 @@ func TestDomainRepository_AddOrganizationDomain(t *testing.T) { orgID := "test-org-id" domainName := "test.example.com" validationType := domain.OrgDomainValidationTypeHTTP + expectedID := "domain-id-456" - mock.ExpectExec(`INSERT INTO zitadel\.domains`). + mock.ExpectQuery(`INSERT INTO zitadel\.domains`). WithArgs(instanceID, orgID, domainName, false, false, int(validationType), sqlmock.AnyArg(), sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(expectedID)) result, err := repo.AddOrganizationDomain(context.Background(), instanceID, orgID, domainName, validationType) require.NoError(t, err) + assert.Equal(t, expectedID, result.ID) assert.Equal(t, instanceID, result.InstanceID) assert.Equal(t, orgID, *result.OrganizationID) assert.Equal(t, domainName, result.Domain) @@ -120,8 +124,8 @@ func TestDomainRepository_Get(t *testing.T) { } rows := sqlmock.NewRows([]string{ - "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at", - }).AddRow(instanceID, nil, domainName, true, false, nil, now, now, nil) + "id", "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at", + }).AddRow("domain-123", instanceID, nil, domainName, true, false, nil, now, now, nil) mock.ExpectQuery(`SELECT .* FROM zitadel\.domains`). WithArgs(domainName, instanceID). @@ -130,6 +134,7 @@ func TestDomainRepository_Get(t *testing.T) { result, err := repo.Get(context.Background(), criteria) require.NoError(t, err) + assert.Equal(t, "domain-123", result.ID) assert.Equal(t, instanceID, result.InstanceID) assert.Nil(t, result.OrganizationID) assert.Equal(t, domainName, result.Domain) @@ -168,10 +173,10 @@ func TestDomainRepository_List(t *testing.T) { // Mock data query rows := sqlmock.NewRows([]string{ - "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at", + "id", "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at", }). - AddRow(instanceID, nil, "instance.example.com", true, true, nil, now, now, nil). - AddRow(instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil) + AddRow("domain-instance", instanceID, nil, "instance.example.com", true, true, nil, now, now, nil). + AddRow("domain-org", instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil) mock.ExpectQuery(`SELECT .* FROM zitadel\.domains.*ORDER BY domain ASC.*LIMIT 10`). WithArgs(instanceID). @@ -184,6 +189,7 @@ func TestDomainRepository_List(t *testing.T) { assert.Len(t, result.Domains, 2) // Check first domain (instance domain) + assert.Equal(t, "domain-instance", result.Domains[0].ID) assert.Equal(t, instanceID, result.Domains[0].InstanceID) assert.Nil(t, result.Domains[0].OrganizationID) assert.Equal(t, "instance.example.com", result.Domains[0].Domain) @@ -191,6 +197,7 @@ func TestDomainRepository_List(t *testing.T) { assert.True(t, result.Domains[0].IsPrimary) // Check second domain (org domain) + assert.Equal(t, "domain-org", result.Domains[1].ID) assert.Equal(t, instanceID, result.Domains[1].InstanceID) assert.Equal(t, "org-id", *result.Domains[1].OrganizationID) assert.Equal(t, "org.example.com", result.Domains[1].Domain)