Add ID column to domains table and update repository implementation

Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-07-14 20:11:28 +00:00
parent d3de8a2150
commit 655f9be015
6 changed files with 537 additions and 37 deletions

126
DOMAINS_IMPLEMENTATION.md Normal file
View File

@@ -0,0 +1,126 @@
# Unified Domains Table Implementation
This implementation provides a unified domains table (`zitadel.domains`) that consolidates both organization and instance domains into a single table structure.
## Architecture
The implementation follows Zitadel's established patterns:
### Database Layer
- **Migration 61**: Creates the unified domains table with proper constraints
- **Table Structure**: Uses nullable `org_id` to distinguish between instance domains (NULL) and organization domains (NOT NULL)
- **Soft Deletes**: Implements `deleted_at` for data preservation
- **Unique Constraints**: Ensures domain uniqueness within instance/organization scope
### Domain Layer
- **Interfaces**: Clean separation between instance and organization domain operations
- **Models**: Unified domain model with optional organization ID
### Repository Layer
- **Implementation**: Single repository handling both instance and organization domains
- **Transactions**: Atomic operations for primary domain changes
- **Query Building**: Type-safe SQL generation using squirrel
### Projection Layer
- **Event Handling**: Processes both org and instance domain events
- **Data Synchronization**: Maintains consistency with event sourcing
## Event Mapping
### Organization Domain Events
- `org.domain.added` → Creates domain record with org_id
- `org.domain.verification.added` → Updates validation_type
- `org.domain.verified` → Sets is_verified = true
- `org.domain.primary.set` → Manages primary domain flags
- `org.domain.removed` → Soft deletes domain
- `org.removed` → Soft deletes all org domains
### Instance Domain Events
- `instance.domain.added` → Creates domain record with org_id = NULL, is_verified = true
- `instance.domain.primary.set` → Manages primary domain flags
- `instance.domain.removed` → Soft deletes domain
- `instance.removed` → Soft deletes all instance domains
## Usage Examples
### Instance Domain Operations
```go
// Add instance domain (always verified)
domain, err := repo.AddInstanceDomain(ctx, "instance-123", "api.example.com")
// Set primary instance domain
err := repo.SetInstanceDomainPrimary(ctx, "instance-123", "api.example.com")
// Remove instance domain
err := repo.RemoveInstanceDomain(ctx, "instance-123", "api.example.com")
// List instance domains
criteria := v2domain.DomainSearchCriteria{
InstanceID: &instanceID,
}
pagination := v2domain.DomainPagination{
Limit: 10,
SortBy: v2domain.DomainSortFieldDomain,
Order: database.SortOrderAsc,
}
list, err := repo.List(ctx, criteria, pagination)
```
### Organization Domain Operations
```go
// Add organization domain
domain, err := repo.AddOrganizationDomain(ctx, "instance-123", "org-456", "company.com", domain.OrgDomainValidationTypeHTTP)
// Verify organization domain
err := repo.SetOrganizationDomainVerified(ctx, "instance-123", "org-456", "company.com")
// Set primary organization domain
err := repo.SetOrganizationDomainPrimary(ctx, "instance-123", "org-456", "company.com")
// Find specific domain
criteria := v2domain.DomainSearchCriteria{
Domain: &domainName,
InstanceID: &instanceID,
}
domain, err := repo.Get(ctx, criteria)
```
## Testing
The implementation includes comprehensive tests:
### Repository Tests
- CRUD operations for both instance and organization domains
- Transaction behavior verification
- Error handling and edge cases
- SQL query parameter validation
### Projection Tests
- Event reduction logic for all supported events
- Column and condition verification
- Multi-statement handling for primary domain changes
- Soft delete behavior
## Migration Strategy
This table is designed to work alongside existing tables initially:
1. **Phase 1** (This PR): Create unified table and maintain via projections
2. **Phase 2** (Future): Migrate query logic to use unified table
3. **Phase 3** (Future): Deprecate separate org_domains2 and instance_domains tables
## Performance Considerations
- **Indexing**: The unique constraint provides efficient domain lookups
- **Queries**: Nullable org_id allows efficient filtering between domain types
- **Pagination**: Supports sorting by created_at, updated_at, and domain name
- **Soft Deletes**: WHERE deleted_at IS NULL conditions optimize active domain queries
## Validation
The table enforces:
- Domain length between 1-255 characters
- Non-negative validation_type values
- Foreign key integrity to instances and organizations
- Unique domain constraints per instance/organization scope
- Automatic updated_at timestamp management

View File

@@ -1,5 +1,6 @@
CREATE TABLE IF NOT EXISTS zitadel.domains( CREATE TABLE IF NOT EXISTS zitadel.domains(
instance_id TEXT NOT NULL id TEXT NOT NULL PRIMARY KEY DEFAULT generate_ulid()
, instance_id TEXT NOT NULL
, org_id TEXT , org_id TEXT
, domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255) , domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255)
, is_verified BOOLEAN NOT NULL DEFAULT FALSE , is_verified BOOLEAN NOT NULL DEFAULT FALSE

View File

@@ -15,6 +15,7 @@ import (
const ( const (
DomainsTable = "zitadel.domains" DomainsTable = "zitadel.domains"
DomainsIDCol = "id"
DomainsInstanceIDCol = "instance_id" DomainsInstanceIDCol = "instance_id"
DomainsOrgIDCol = "org_id" DomainsOrgIDCol = "org_id"
DomainsDomainCol = "domain" DomainsDomainCol = "domain"
@@ -40,6 +41,7 @@ func (*domainsProjection) Init() *old_handler.Check {
// The table is created by migration, so we just need to check it exists // The table is created by migration, so we just need to check it exists
return handler.NewTableCheck( return handler.NewTableCheck(
handler.NewTable([]*handler.InitColumn{ handler.NewTable([]*handler.InitColumn{
handler.NewColumn(DomainsIDCol, handler.ColumnTypeText),
handler.NewColumn(DomainsInstanceIDCol, handler.ColumnTypeText), handler.NewColumn(DomainsInstanceIDCol, handler.ColumnTypeText),
handler.NewColumn(DomainsOrgIDCol, handler.ColumnTypeText), handler.NewColumn(DomainsOrgIDCol, handler.ColumnTypeText),
handler.NewColumn(DomainsDomainCol, handler.ColumnTypeText), handler.NewColumn(DomainsDomainCol, handler.ColumnTypeText),
@@ -50,7 +52,7 @@ func (*domainsProjection) Init() *old_handler.Check {
handler.NewColumn(DomainsUpdatedAtCol, handler.ColumnTypeTimestamp), handler.NewColumn(DomainsUpdatedAtCol, handler.ColumnTypeTimestamp),
handler.NewColumn(DomainsDeletedAtCol, handler.ColumnTypeTimestamp), handler.NewColumn(DomainsDeletedAtCol, handler.ColumnTypeTimestamp),
}, },
// Note: We don't define primary key here since the table is created by migration handler.NewPrimaryKey(DomainsIDCol),
), ),
) )
} }

View File

@@ -0,0 +1,356 @@
package projection
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/org"
)
func TestDomainsProjection_reduceOrgDomainAdded(t *testing.T) {
projection := &domainsProjection{}
event := &org.DomainAddedEvent{
BaseEvent: &eventstore.BaseEvent{
Agg: &eventstore.Aggregate{
ID: "org-id",
InstanceID: "instance-id",
Type: org.AggregateType,
ResourceOwner: "instance-id",
},
CreationDate: time.Now(),
},
Domain: "test.example.com",
}
stmt, err := projection.reduceOrgDomainAdded(event)
require.NoError(t, err)
assert.NotNil(t, stmt)
createStmt, ok := stmt.(*handler.CreateStatement)
require.True(t, ok)
// Verify the columns being set
expectedColumns := map[string]interface{}{
DomainsInstanceIDCol: "instance-id",
DomainsOrgIDCol: "org-id",
DomainsDomainCol: "test.example.com",
DomainsIsVerifiedCol: false,
DomainsIsPrimaryCol: false,
DomainsValidationTypeCol: domain.OrgDomainValidationTypeUnspecified,
DomainsCreatedAtCol: event.CreationDate(),
DomainsUpdatedAtCol: event.CreationDate(),
}
assert.Len(t, createStmt.Cols, len(expectedColumns))
for i, col := range createStmt.Cols {
switch col.Name {
case DomainsInstanceIDCol:
assert.Equal(t, "instance-id", col.Value)
case DomainsOrgIDCol:
assert.Equal(t, "org-id", col.Value)
case DomainsDomainCol:
assert.Equal(t, "test.example.com", col.Value)
case DomainsIsVerifiedCol:
assert.Equal(t, false, col.Value)
case DomainsIsPrimaryCol:
assert.Equal(t, false, col.Value)
case DomainsValidationTypeCol:
assert.Equal(t, domain.OrgDomainValidationTypeUnspecified, col.Value)
case DomainsCreatedAtCol, DomainsUpdatedAtCol:
assert.Equal(t, event.CreationDate(), col.Value)
default:
t.Errorf("Unexpected column: %s at index %d", col.Name, i)
}
}
}
func TestDomainsProjection_reduceOrgDomainVerified(t *testing.T) {
projection := &domainsProjection{}
event := &org.DomainVerifiedEvent{
BaseEvent: &eventstore.BaseEvent{
Agg: &eventstore.Aggregate{
ID: "org-id",
InstanceID: "instance-id",
Type: org.AggregateType,
ResourceOwner: "instance-id",
},
CreationDate: time.Now(),
},
Domain: "test.example.com",
}
stmt, err := projection.reduceOrgDomainVerified(event)
require.NoError(t, err)
assert.NotNil(t, stmt)
updateStmt, ok := stmt.(*handler.UpdateStatement)
require.True(t, ok)
// Verify update columns
assert.Len(t, updateStmt.Cols, 2)
assert.Equal(t, DomainsUpdatedAtCol, updateStmt.Cols[0].Name)
assert.Equal(t, event.CreationDate(), updateStmt.Cols[0].Value)
assert.Equal(t, DomainsIsVerifiedCol, updateStmt.Cols[1].Name)
assert.Equal(t, true, updateStmt.Cols[1].Value)
// Verify conditions
assert.Len(t, updateStmt.Conditions, 4)
conditionMap := make(map[string]interface{})
for _, cond := range updateStmt.Conditions {
conditionMap[cond.Name] = cond.Value
}
assert.Equal(t, "instance-id", conditionMap[DomainsInstanceIDCol])
assert.Equal(t, "org-id", conditionMap[DomainsOrgIDCol])
assert.Equal(t, "test.example.com", conditionMap[DomainsDomainCol])
assert.Nil(t, conditionMap[DomainsDeletedAtCol])
}
func TestDomainsProjection_reduceOrgPrimaryDomainSet(t *testing.T) {
projection := &domainsProjection{}
event := &org.DomainPrimarySetEvent{
BaseEvent: &eventstore.BaseEvent{
Agg: &eventstore.Aggregate{
ID: "org-id",
InstanceID: "instance-id",
Type: org.AggregateType,
ResourceOwner: "instance-id",
},
CreationDate: time.Now(),
},
Domain: "test.example.com",
}
stmt, err := projection.reduceOrgPrimaryDomainSet(event)
require.NoError(t, err)
assert.NotNil(t, stmt)
multiStmt, ok := stmt.(*handler.MultiStatement)
require.True(t, ok)
// Should have 2 update statements: unset old primary, set new primary
assert.Len(t, multiStmt.Statements, 2)
// First statement: unset existing primary
unsetStmt, ok := multiStmt.Statements[0].(*handler.UpdateStatement)
require.True(t, ok)
assert.Len(t, unsetStmt.Cols, 2)
assert.Equal(t, DomainsUpdatedAtCol, unsetStmt.Cols[0].Name)
assert.Equal(t, DomainsIsPrimaryCol, unsetStmt.Cols[1].Name)
assert.Equal(t, false, unsetStmt.Cols[1].Value)
// Second statement: set new primary
setStmt, ok := multiStmt.Statements[1].(*handler.UpdateStatement)
require.True(t, ok)
assert.Len(t, setStmt.Cols, 2)
assert.Equal(t, DomainsUpdatedAtCol, setStmt.Cols[0].Name)
assert.Equal(t, DomainsIsPrimaryCol, setStmt.Cols[1].Name)
assert.Equal(t, true, setStmt.Cols[1].Value)
}
func TestDomainsProjection_reduceInstanceDomainAdded(t *testing.T) {
projection := &domainsProjection{}
event := &instance.DomainAddedEvent{
BaseEvent: &eventstore.BaseEvent{
Agg: &eventstore.Aggregate{
ID: "instance-id",
InstanceID: "instance-id",
Type: instance.AggregateType,
ResourceOwner: "instance-id",
},
CreationDate: time.Now(),
},
Domain: "instance.example.com",
Generated: false,
}
stmt, err := projection.reduceInstanceDomainAdded(event)
require.NoError(t, err)
assert.NotNil(t, stmt)
createStmt, ok := stmt.(*handler.CreateStatement)
require.True(t, ok)
// Verify the columns being set for instance domain
expectedColumns := map[string]interface{}{
DomainsInstanceIDCol: "instance-id",
DomainsOrgIDCol: nil, // Instance domains have no org_id
DomainsDomainCol: "instance.example.com",
DomainsIsVerifiedCol: true, // Instance domains are always verified
DomainsIsPrimaryCol: false,
DomainsValidationTypeCol: nil, // Instance domains have no validation type
DomainsCreatedAtCol: event.CreationDate(),
DomainsUpdatedAtCol: event.CreationDate(),
}
assert.Len(t, createStmt.Cols, len(expectedColumns))
for _, col := range createStmt.Cols {
switch col.Name {
case DomainsInstanceIDCol:
assert.Equal(t, "instance-id", col.Value)
case DomainsOrgIDCol:
assert.Nil(t, col.Value)
case DomainsDomainCol:
assert.Equal(t, "instance.example.com", col.Value)
case DomainsIsVerifiedCol:
assert.Equal(t, true, col.Value)
case DomainsIsPrimaryCol:
assert.Equal(t, false, col.Value)
case DomainsValidationTypeCol:
assert.Nil(t, col.Value)
case DomainsCreatedAtCol, DomainsUpdatedAtCol:
assert.Equal(t, event.CreationDate(), col.Value)
default:
t.Errorf("Unexpected column: %s", col.Name)
}
}
}
func TestDomainsProjection_reduceOrgDomainRemoved(t *testing.T) {
projection := &domainsProjection{}
event := &org.DomainRemovedEvent{
BaseEvent: &eventstore.BaseEvent{
Agg: &eventstore.Aggregate{
ID: "org-id",
InstanceID: "instance-id",
Type: org.AggregateType,
ResourceOwner: "instance-id",
},
CreationDate: time.Now(),
},
Domain: "test.example.com",
}
stmt, err := projection.reduceOrgDomainRemoved(event)
require.NoError(t, err)
assert.NotNil(t, stmt)
updateStmt, ok := stmt.(*handler.UpdateStatement)
require.True(t, ok)
// Should update updated_at and deleted_at
assert.Len(t, updateStmt.Cols, 2)
assert.Equal(t, DomainsUpdatedAtCol, updateStmt.Cols[0].Name)
assert.Equal(t, event.CreationDate(), updateStmt.Cols[0].Value)
assert.Equal(t, DomainsDeletedAtCol, updateStmt.Cols[1].Name)
assert.Equal(t, event.CreationDate(), updateStmt.Cols[1].Value)
// Verify conditions include instance, org, and domain
assert.Len(t, updateStmt.Conditions, 3)
conditionMap := make(map[string]interface{})
for _, cond := range updateStmt.Conditions {
conditionMap[cond.Name] = cond.Value
}
assert.Equal(t, "instance-id", conditionMap[DomainsInstanceIDCol])
assert.Equal(t, "org-id", conditionMap[DomainsOrgIDCol])
assert.Equal(t, "test.example.com", conditionMap[DomainsDomainCol])
}
func TestDomainsProjection_reduceOrgRemoved(t *testing.T) {
projection := &domainsProjection{}
event := &org.OrgRemovedEvent{
BaseEvent: &eventstore.BaseEvent{
Agg: &eventstore.Aggregate{
ID: "org-id",
InstanceID: "instance-id",
Type: org.AggregateType,
ResourceOwner: "instance-id",
},
CreationDate: time.Now(),
},
}
stmt, err := projection.reduceOrgRemoved(event)
require.NoError(t, err)
assert.NotNil(t, stmt)
updateStmt, ok := stmt.(*handler.UpdateStatement)
require.True(t, ok)
// Should update updated_at and deleted_at
assert.Len(t, updateStmt.Cols, 2)
assert.Equal(t, DomainsUpdatedAtCol, updateStmt.Cols[0].Name)
assert.Equal(t, event.CreationDate(), updateStmt.Cols[0].Value)
assert.Equal(t, DomainsDeletedAtCol, updateStmt.Cols[1].Name)
assert.Equal(t, event.CreationDate(), updateStmt.Cols[1].Value)
// Should soft delete all domains for the org
assert.Len(t, updateStmt.Conditions, 3)
conditionMap := make(map[string]interface{})
for _, cond := range updateStmt.Conditions {
conditionMap[cond.Name] = cond.Value
}
assert.Equal(t, "instance-id", conditionMap[DomainsInstanceIDCol])
assert.Equal(t, "org-id", conditionMap[DomainsOrgIDCol])
assert.Nil(t, conditionMap[DomainsDeletedAtCol]) // Only affect non-deleted domains
}
func TestDomainsProjection_reduceInstanceRemoved(t *testing.T) {
projection := &domainsProjection{}
event := &instance.InstanceRemovedEvent{
BaseEvent: &eventstore.BaseEvent{
Agg: &eventstore.Aggregate{
ID: "instance-id",
InstanceID: "instance-id",
Type: instance.AggregateType,
ResourceOwner: "instance-id",
},
CreationDate: time.Now(),
},
}
stmt, err := projection.reduceInstanceRemoved(event)
require.NoError(t, err)
assert.NotNil(t, stmt)
updateStmt, ok := stmt.(*handler.UpdateStatement)
require.True(t, ok)
// Should update updated_at and deleted_at
assert.Len(t, updateStmt.Cols, 2)
assert.Equal(t, DomainsUpdatedAtCol, updateStmt.Cols[0].Name)
assert.Equal(t, event.CreationDate(), updateStmt.Cols[0].Value)
assert.Equal(t, DomainsDeletedAtCol, updateStmt.Cols[1].Name)
assert.Equal(t, event.CreationDate(), updateStmt.Cols[1].Value)
// Should soft delete all domains for the instance
assert.Len(t, updateStmt.Conditions, 2)
conditionMap := make(map[string]interface{})
for _, cond := range updateStmt.Conditions {
conditionMap[cond.Name] = cond.Value
}
assert.Equal(t, "instance-id", conditionMap[DomainsInstanceIDCol])
assert.Nil(t, conditionMap[DomainsDeletedAtCol]) // Only affect non-deleted domains
}

View File

@@ -30,16 +30,17 @@ type TX interface {
} }
const ( const (
domainsTable = "zitadel.domains" domainsTable = "zitadel.domains"
domainsInstanceIDCol = "instance_id" domainsIDCol = "id"
domainsOrgIDCol = "org_id" domainsInstanceIDCol = "instance_id"
domainsDomainCol = "domain" domainsOrgIDCol = "org_id"
domainsIsVerifiedCol = "is_verified" domainsDomainCol = "domain"
domainsIsPrimaryCol = "is_primary" domainsIsVerifiedCol = "is_verified"
domainsIsPrimaryCol = "is_primary"
domainsValidationTypeCol = "validation_type" domainsValidationTypeCol = "validation_type"
domainsCreatedAtCol = "created_at" domainsCreatedAtCol = "created_at"
domainsUpdatedAtCol = "updated_at" domainsUpdatedAtCol = "updated_at"
domainsDeletedAtCol = "deleted_at" domainsDeletedAtCol = "deleted_at"
) )
// DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository // DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository
@@ -67,6 +68,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
domainsUpdatedAtCol, domainsUpdatedAtCol,
). ).
Values(instanceID, domain, true, false, now, now). Values(instanceID, domain, true, false, now, now).
Suffix("RETURNING " + domainsIDCol).
PlaceholderFormat(squirrel.Dollar) PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql() stmt, args, err := query.ToSql()
@@ -74,12 +76,14 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal")
} }
_, err = r.client.ExecContext(ctx, stmt, args...) var id string
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil { if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal")
} }
return &v2domain.Domain{ return &v2domain.Domain{
ID: id,
InstanceID: instanceID, InstanceID: instanceID,
OrganizationID: nil, OrganizationID: nil,
Domain: domain, Domain: domain,
@@ -93,7 +97,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
// SetPrimary sets the primary domain for an instance // SetPrimary sets the primary domain for an instance
func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error { func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error {
return r.withTransaction(ctx, func(tx database.Tx) error { return r.withTransaction(ctx, func(tx TX) error {
// First, unset any existing primary domain for this instance // First, unset any existing primary domain for this instance
unsetQuery := squirrel.Update(domainsTable). unsetQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, false). Set(domainsIsPrimaryCol, false).
@@ -201,6 +205,7 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID
domainsUpdatedAtCol, domainsUpdatedAtCol,
). ).
Values(instanceID, organizationID, domain, false, false, int(validationType), now, now). Values(instanceID, organizationID, domain, false, false, int(validationType), now, now).
Suffix("RETURNING " + domainsIDCol).
PlaceholderFormat(squirrel.Dollar) PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql() stmt, args, err := query.ToSql()
@@ -208,12 +213,14 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID
return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal")
} }
_, err = r.client.ExecContext(ctx, stmt, args...) var id string
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil { if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal")
} }
return &v2domain.Domain{ return &v2domain.Domain{
ID: id,
InstanceID: instanceID, InstanceID: instanceID,
OrganizationID: &organizationID, OrganizationID: &organizationID,
Domain: domain, Domain: domain,
@@ -262,7 +269,7 @@ func (r *DomainRepository) SetOrganizationDomainVerified(ctx context.Context, in
// SetPrimary sets the primary domain for an organization // SetPrimary sets the primary domain for an organization
func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error { func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error {
return r.withTransaction(ctx, func(tx database.Tx) error { return r.withTransaction(ctx, func(tx TX) error {
// First, unset any existing primary domain for this organization // First, unset any existing primary domain for this organization
unsetQuery := squirrel.Update(domainsTable). unsetQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, false). Set(domainsIsPrimaryCol, false).
@@ -443,6 +450,7 @@ func (r *DomainRepository) List(ctx context.Context, criteria v2domain.DomainSea
func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder { func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder {
query := squirrel.Select( query := squirrel.Select(
domainsIDCol,
domainsInstanceIDCol, domainsInstanceIDCol,
domainsOrgIDCol, domainsOrgIDCol,
domainsDomainCol, domainsDomainCol,
@@ -474,8 +482,7 @@ func (r *DomainRepository) applySearchCriteria(query squirrel.SelectBuilder, cri
query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")) query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL"))
if criteria.ID != nil { if criteria.ID != nil {
// Note: Our table doesn't have an ID column. This might need to be reconsidered query = query.Where(squirrel.Eq{domainsIDCol: *criteria.ID})
// For now, we'll ignore this criterion since the spec doesn't define where ID comes from
} }
if criteria.Domain != nil { if criteria.Domain != nil {
@@ -536,20 +543,21 @@ func (r *DomainRepository) applyPagination(query squirrel.SelectBuilder, paginat
} }
func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) { func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) {
var domain v2domain.Domain var domainRecord v2domain.Domain
var orgID sql.NullString var orgID sql.NullString
var validationType sql.NullInt32 var validationType sql.NullInt32
var deletedAt sql.NullTime var deletedAt sql.NullTime
err := rows.Scan( err := rows.Scan(
&domain.InstanceID, &domainRecord.ID,
&domainRecord.InstanceID,
&orgID, &orgID,
&domain.Domain, &domainRecord.Domain,
&domain.IsVerified, &domainRecord.IsVerified,
&domain.IsPrimary, &domainRecord.IsPrimary,
&validationType, &validationType,
&domain.CreatedAt, &domainRecord.CreatedAt,
&domain.UpdatedAt, &domainRecord.UpdatedAt,
&deletedAt, &deletedAt,
) )
if err != nil { if err != nil {
@@ -557,19 +565,19 @@ func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error)
} }
if orgID.Valid { if orgID.Valid {
domain.OrganizationID = &orgID.String domainRecord.OrganizationID = &orgID.String
} }
if validationType.Valid { if validationType.Valid {
validationTypeValue := domain.OrgDomainValidationType(validationType.Int32) validationTypeValue := domain.OrgDomainValidationType(validationType.Int32)
domain.ValidationType = &validationTypeValue domainRecord.ValidationType = &validationTypeValue
} }
if deletedAt.Valid { if deletedAt.Valid {
domain.DeletedAt = &deletedAt.Time domainRecord.DeletedAt = &deletedAt.Time
} }
return &domain, nil return &domainRecord, nil
} }
func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error { func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error {

View File

@@ -25,14 +25,16 @@ func TestDomainRepository_AddInstanceDomain(t *testing.T) {
instanceID := "test-instance-id" instanceID := "test-instance-id"
domainName := "test.example.com" domainName := "test.example.com"
expectedID := "domain-id-123"
mock.ExpectExec(`INSERT INTO zitadel\.domains`). mock.ExpectQuery(`INSERT INTO zitadel\.domains`).
WithArgs(instanceID, domainName, true, false, sqlmock.AnyArg(), sqlmock.AnyArg()). WithArgs(instanceID, domainName, true, false, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(expectedID))
result, err := repo.AddInstanceDomain(context.Background(), instanceID, domainName) result, err := repo.AddInstanceDomain(context.Background(), instanceID, domainName)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedID, result.ID)
assert.Equal(t, instanceID, result.InstanceID) assert.Equal(t, instanceID, result.InstanceID)
assert.Nil(t, result.OrganizationID) assert.Nil(t, result.OrganizationID)
assert.Equal(t, domainName, result.Domain) assert.Equal(t, domainName, result.Domain)
@@ -54,14 +56,16 @@ func TestDomainRepository_AddOrganizationDomain(t *testing.T) {
orgID := "test-org-id" orgID := "test-org-id"
domainName := "test.example.com" domainName := "test.example.com"
validationType := domain.OrgDomainValidationTypeHTTP validationType := domain.OrgDomainValidationTypeHTTP
expectedID := "domain-id-456"
mock.ExpectExec(`INSERT INTO zitadel\.domains`). mock.ExpectQuery(`INSERT INTO zitadel\.domains`).
WithArgs(instanceID, orgID, domainName, false, false, int(validationType), sqlmock.AnyArg(), sqlmock.AnyArg()). WithArgs(instanceID, orgID, domainName, false, false, int(validationType), sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(expectedID))
result, err := repo.AddOrganizationDomain(context.Background(), instanceID, orgID, domainName, validationType) result, err := repo.AddOrganizationDomain(context.Background(), instanceID, orgID, domainName, validationType)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expectedID, result.ID)
assert.Equal(t, instanceID, result.InstanceID) assert.Equal(t, instanceID, result.InstanceID)
assert.Equal(t, orgID, *result.OrganizationID) assert.Equal(t, orgID, *result.OrganizationID)
assert.Equal(t, domainName, result.Domain) assert.Equal(t, domainName, result.Domain)
@@ -120,8 +124,8 @@ func TestDomainRepository_Get(t *testing.T) {
} }
rows := sqlmock.NewRows([]string{ rows := sqlmock.NewRows([]string{
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at", "id", "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
}).AddRow(instanceID, nil, domainName, true, false, nil, now, now, nil) }).AddRow("domain-123", instanceID, nil, domainName, true, false, nil, now, now, nil)
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains`). mock.ExpectQuery(`SELECT .* FROM zitadel\.domains`).
WithArgs(domainName, instanceID). WithArgs(domainName, instanceID).
@@ -130,6 +134,7 @@ func TestDomainRepository_Get(t *testing.T) {
result, err := repo.Get(context.Background(), criteria) result, err := repo.Get(context.Background(), criteria)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "domain-123", result.ID)
assert.Equal(t, instanceID, result.InstanceID) assert.Equal(t, instanceID, result.InstanceID)
assert.Nil(t, result.OrganizationID) assert.Nil(t, result.OrganizationID)
assert.Equal(t, domainName, result.Domain) assert.Equal(t, domainName, result.Domain)
@@ -168,10 +173,10 @@ func TestDomainRepository_List(t *testing.T) {
// Mock data query // Mock data query
rows := sqlmock.NewRows([]string{ rows := sqlmock.NewRows([]string{
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at", "id", "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
}). }).
AddRow(instanceID, nil, "instance.example.com", true, true, nil, now, now, nil). AddRow("domain-instance", instanceID, nil, "instance.example.com", true, true, nil, now, now, nil).
AddRow(instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil) AddRow("domain-org", instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil)
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains.*ORDER BY domain ASC.*LIMIT 10`). mock.ExpectQuery(`SELECT .* FROM zitadel\.domains.*ORDER BY domain ASC.*LIMIT 10`).
WithArgs(instanceID). WithArgs(instanceID).
@@ -184,6 +189,7 @@ func TestDomainRepository_List(t *testing.T) {
assert.Len(t, result.Domains, 2) assert.Len(t, result.Domains, 2)
// Check first domain (instance domain) // Check first domain (instance domain)
assert.Equal(t, "domain-instance", result.Domains[0].ID)
assert.Equal(t, instanceID, result.Domains[0].InstanceID) assert.Equal(t, instanceID, result.Domains[0].InstanceID)
assert.Nil(t, result.Domains[0].OrganizationID) assert.Nil(t, result.Domains[0].OrganizationID)
assert.Equal(t, "instance.example.com", result.Domains[0].Domain) assert.Equal(t, "instance.example.com", result.Domains[0].Domain)
@@ -191,6 +197,7 @@ func TestDomainRepository_List(t *testing.T) {
assert.True(t, result.Domains[0].IsPrimary) assert.True(t, result.Domains[0].IsPrimary)
// Check second domain (org domain) // Check second domain (org domain)
assert.Equal(t, "domain-org", result.Domains[1].ID)
assert.Equal(t, instanceID, result.Domains[1].InstanceID) assert.Equal(t, instanceID, result.Domains[1].InstanceID)
assert.Equal(t, "org-id", *result.Domains[1].OrganizationID) assert.Equal(t, "org-id", *result.Domains[1].OrganizationID)
assert.Equal(t, "org.example.com", result.Domains[1].Domain) assert.Equal(t, "org.example.com", result.Domains[1].Domain)