Add ID column to domains table and update repository implementation

Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-07-14 20:11:28 +00:00
parent d3de8a2150
commit 655f9be015
6 changed files with 537 additions and 37 deletions

View File

@@ -30,16 +30,17 @@ type TX interface {
}
const (
domainsTable = "zitadel.domains"
domainsInstanceIDCol = "instance_id"
domainsOrgIDCol = "org_id"
domainsDomainCol = "domain"
domainsIsVerifiedCol = "is_verified"
domainsIsPrimaryCol = "is_primary"
domainsTable = "zitadel.domains"
domainsIDCol = "id"
domainsInstanceIDCol = "instance_id"
domainsOrgIDCol = "org_id"
domainsDomainCol = "domain"
domainsIsVerifiedCol = "is_verified"
domainsIsPrimaryCol = "is_primary"
domainsValidationTypeCol = "validation_type"
domainsCreatedAtCol = "created_at"
domainsUpdatedAtCol = "updated_at"
domainsDeletedAtCol = "deleted_at"
domainsCreatedAtCol = "created_at"
domainsUpdatedAtCol = "updated_at"
domainsDeletedAtCol = "deleted_at"
)
// DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository
@@ -67,6 +68,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
domainsUpdatedAtCol,
).
Values(instanceID, domain, true, false, now, now).
Suffix("RETURNING " + domainsIDCol).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
@@ -74,12 +76,14 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal")
}
_, err = r.client.ExecContext(ctx, stmt, args...)
var id string
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal")
}
return &v2domain.Domain{
ID: id,
InstanceID: instanceID,
OrganizationID: nil,
Domain: domain,
@@ -93,7 +97,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
// SetPrimary sets the primary domain for an instance
func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error {
return r.withTransaction(ctx, func(tx database.Tx) error {
return r.withTransaction(ctx, func(tx TX) error {
// First, unset any existing primary domain for this instance
unsetQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, false).
@@ -201,6 +205,7 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID
domainsUpdatedAtCol,
).
Values(instanceID, organizationID, domain, false, false, int(validationType), now, now).
Suffix("RETURNING " + domainsIDCol).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
@@ -208,12 +213,14 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID
return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal")
}
_, err = r.client.ExecContext(ctx, stmt, args...)
var id string
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal")
}
return &v2domain.Domain{
ID: id,
InstanceID: instanceID,
OrganizationID: &organizationID,
Domain: domain,
@@ -262,7 +269,7 @@ func (r *DomainRepository) SetOrganizationDomainVerified(ctx context.Context, in
// SetPrimary sets the primary domain for an organization
func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error {
return r.withTransaction(ctx, func(tx database.Tx) error {
return r.withTransaction(ctx, func(tx TX) error {
// First, unset any existing primary domain for this organization
unsetQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, false).
@@ -443,6 +450,7 @@ func (r *DomainRepository) List(ctx context.Context, criteria v2domain.DomainSea
func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder {
query := squirrel.Select(
domainsIDCol,
domainsInstanceIDCol,
domainsOrgIDCol,
domainsDomainCol,
@@ -474,8 +482,7 @@ func (r *DomainRepository) applySearchCriteria(query squirrel.SelectBuilder, cri
query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL"))
if criteria.ID != nil {
// Note: Our table doesn't have an ID column. This might need to be reconsidered
// For now, we'll ignore this criterion since the spec doesn't define where ID comes from
query = query.Where(squirrel.Eq{domainsIDCol: *criteria.ID})
}
if criteria.Domain != nil {
@@ -536,20 +543,21 @@ func (r *DomainRepository) applyPagination(query squirrel.SelectBuilder, paginat
}
func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) {
var domain v2domain.Domain
var domainRecord v2domain.Domain
var orgID sql.NullString
var validationType sql.NullInt32
var deletedAt sql.NullTime
err := rows.Scan(
&domain.InstanceID,
&domainRecord.ID,
&domainRecord.InstanceID,
&orgID,
&domain.Domain,
&domain.IsVerified,
&domain.IsPrimary,
&domainRecord.Domain,
&domainRecord.IsVerified,
&domainRecord.IsPrimary,
&validationType,
&domain.CreatedAt,
&domain.UpdatedAt,
&domainRecord.CreatedAt,
&domainRecord.UpdatedAt,
&deletedAt,
)
if err != nil {
@@ -557,19 +565,19 @@ func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error)
}
if orgID.Valid {
domain.OrganizationID = &orgID.String
domainRecord.OrganizationID = &orgID.String
}
if validationType.Valid {
validationTypeValue := domain.OrgDomainValidationType(validationType.Int32)
domain.ValidationType = &validationTypeValue
domainRecord.ValidationType = &validationTypeValue
}
if deletedAt.Valid {
domain.DeletedAt = &deletedAt.Time
domainRecord.DeletedAt = &deletedAt.Time
}
return &domain, nil
return &domainRecord, nil
}
func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error {

View File

@@ -25,14 +25,16 @@ func TestDomainRepository_AddInstanceDomain(t *testing.T) {
instanceID := "test-instance-id"
domainName := "test.example.com"
expectedID := "domain-id-123"
mock.ExpectExec(`INSERT INTO zitadel\.domains`).
mock.ExpectQuery(`INSERT INTO zitadel\.domains`).
WithArgs(instanceID, domainName, true, false, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(expectedID))
result, err := repo.AddInstanceDomain(context.Background(), instanceID, domainName)
require.NoError(t, err)
assert.Equal(t, expectedID, result.ID)
assert.Equal(t, instanceID, result.InstanceID)
assert.Nil(t, result.OrganizationID)
assert.Equal(t, domainName, result.Domain)
@@ -54,14 +56,16 @@ func TestDomainRepository_AddOrganizationDomain(t *testing.T) {
orgID := "test-org-id"
domainName := "test.example.com"
validationType := domain.OrgDomainValidationTypeHTTP
expectedID := "domain-id-456"
mock.ExpectExec(`INSERT INTO zitadel\.domains`).
mock.ExpectQuery(`INSERT INTO zitadel\.domains`).
WithArgs(instanceID, orgID, domainName, false, false, int(validationType), sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(expectedID))
result, err := repo.AddOrganizationDomain(context.Background(), instanceID, orgID, domainName, validationType)
require.NoError(t, err)
assert.Equal(t, expectedID, result.ID)
assert.Equal(t, instanceID, result.InstanceID)
assert.Equal(t, orgID, *result.OrganizationID)
assert.Equal(t, domainName, result.Domain)
@@ -120,8 +124,8 @@ func TestDomainRepository_Get(t *testing.T) {
}
rows := sqlmock.NewRows([]string{
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
}).AddRow(instanceID, nil, domainName, true, false, nil, now, now, nil)
"id", "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
}).AddRow("domain-123", instanceID, nil, domainName, true, false, nil, now, now, nil)
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains`).
WithArgs(domainName, instanceID).
@@ -130,6 +134,7 @@ func TestDomainRepository_Get(t *testing.T) {
result, err := repo.Get(context.Background(), criteria)
require.NoError(t, err)
assert.Equal(t, "domain-123", result.ID)
assert.Equal(t, instanceID, result.InstanceID)
assert.Nil(t, result.OrganizationID)
assert.Equal(t, domainName, result.Domain)
@@ -168,10 +173,10 @@ func TestDomainRepository_List(t *testing.T) {
// Mock data query
rows := sqlmock.NewRows([]string{
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
"id", "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
}).
AddRow(instanceID, nil, "instance.example.com", true, true, nil, now, now, nil).
AddRow(instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil)
AddRow("domain-instance", instanceID, nil, "instance.example.com", true, true, nil, now, now, nil).
AddRow("domain-org", instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil)
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains.*ORDER BY domain ASC.*LIMIT 10`).
WithArgs(instanceID).
@@ -184,6 +189,7 @@ func TestDomainRepository_List(t *testing.T) {
assert.Len(t, result.Domains, 2)
// Check first domain (instance domain)
assert.Equal(t, "domain-instance", result.Domains[0].ID)
assert.Equal(t, instanceID, result.Domains[0].InstanceID)
assert.Nil(t, result.Domains[0].OrganizationID)
assert.Equal(t, "instance.example.com", result.Domains[0].Domain)
@@ -191,6 +197,7 @@ func TestDomainRepository_List(t *testing.T) {
assert.True(t, result.Domains[0].IsPrimary)
// Check second domain (org domain)
assert.Equal(t, "domain-org", result.Domains[1].ID)
assert.Equal(t, instanceID, result.Domains[1].InstanceID)
assert.Equal(t, "org-id", *result.Domains[1].OrganizationID)
assert.Equal(t, "org.example.com", result.Domains[1].Domain)