Add ID column to domains table and update repository implementation

Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-07-14 20:11:28 +00:00
parent d3de8a2150
commit 655f9be015
6 changed files with 537 additions and 37 deletions

View File

@@ -30,16 +30,17 @@ type TX interface {
}
const (
domainsTable = "zitadel.domains"
domainsInstanceIDCol = "instance_id"
domainsOrgIDCol = "org_id"
domainsDomainCol = "domain"
domainsIsVerifiedCol = "is_verified"
domainsIsPrimaryCol = "is_primary"
domainsTable = "zitadel.domains"
domainsIDCol = "id"
domainsInstanceIDCol = "instance_id"
domainsOrgIDCol = "org_id"
domainsDomainCol = "domain"
domainsIsVerifiedCol = "is_verified"
domainsIsPrimaryCol = "is_primary"
domainsValidationTypeCol = "validation_type"
domainsCreatedAtCol = "created_at"
domainsUpdatedAtCol = "updated_at"
domainsDeletedAtCol = "deleted_at"
domainsCreatedAtCol = "created_at"
domainsUpdatedAtCol = "updated_at"
domainsDeletedAtCol = "deleted_at"
)
// DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository
@@ -67,6 +68,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
domainsUpdatedAtCol,
).
Values(instanceID, domain, true, false, now, now).
Suffix("RETURNING " + domainsIDCol).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
@@ -74,12 +76,14 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal")
}
_, err = r.client.ExecContext(ctx, stmt, args...)
var id string
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal")
}
return &v2domain.Domain{
ID: id,
InstanceID: instanceID,
OrganizationID: nil,
Domain: domain,
@@ -93,7 +97,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
// SetPrimary sets the primary domain for an instance
func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error {
return r.withTransaction(ctx, func(tx database.Tx) error {
return r.withTransaction(ctx, func(tx TX) error {
// First, unset any existing primary domain for this instance
unsetQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, false).
@@ -201,6 +205,7 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID
domainsUpdatedAtCol,
).
Values(instanceID, organizationID, domain, false, false, int(validationType), now, now).
Suffix("RETURNING " + domainsIDCol).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
@@ -208,12 +213,14 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID
return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal")
}
_, err = r.client.ExecContext(ctx, stmt, args...)
var id string
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal")
}
return &v2domain.Domain{
ID: id,
InstanceID: instanceID,
OrganizationID: &organizationID,
Domain: domain,
@@ -262,7 +269,7 @@ func (r *DomainRepository) SetOrganizationDomainVerified(ctx context.Context, in
// SetPrimary sets the primary domain for an organization
func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error {
return r.withTransaction(ctx, func(tx database.Tx) error {
return r.withTransaction(ctx, func(tx TX) error {
// First, unset any existing primary domain for this organization
unsetQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, false).
@@ -443,6 +450,7 @@ func (r *DomainRepository) List(ctx context.Context, criteria v2domain.DomainSea
func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder {
query := squirrel.Select(
domainsIDCol,
domainsInstanceIDCol,
domainsOrgIDCol,
domainsDomainCol,
@@ -474,8 +482,7 @@ func (r *DomainRepository) applySearchCriteria(query squirrel.SelectBuilder, cri
query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL"))
if criteria.ID != nil {
// Note: Our table doesn't have an ID column. This might need to be reconsidered
// For now, we'll ignore this criterion since the spec doesn't define where ID comes from
query = query.Where(squirrel.Eq{domainsIDCol: *criteria.ID})
}
if criteria.Domain != nil {
@@ -536,20 +543,21 @@ func (r *DomainRepository) applyPagination(query squirrel.SelectBuilder, paginat
}
func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) {
var domain v2domain.Domain
var domainRecord v2domain.Domain
var orgID sql.NullString
var validationType sql.NullInt32
var deletedAt sql.NullTime
err := rows.Scan(
&domain.InstanceID,
&domainRecord.ID,
&domainRecord.InstanceID,
&orgID,
&domain.Domain,
&domain.IsVerified,
&domain.IsPrimary,
&domainRecord.Domain,
&domainRecord.IsVerified,
&domainRecord.IsPrimary,
&validationType,
&domain.CreatedAt,
&domain.UpdatedAt,
&domainRecord.CreatedAt,
&domainRecord.UpdatedAt,
&deletedAt,
)
if err != nil {
@@ -557,19 +565,19 @@ func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error)
}
if orgID.Valid {
domain.OrganizationID = &orgID.String
domainRecord.OrganizationID = &orgID.String
}
if validationType.Valid {
validationTypeValue := domain.OrgDomainValidationType(validationType.Int32)
domain.ValidationType = &validationTypeValue
domainRecord.ValidationType = &validationTypeValue
}
if deletedAt.Valid {
domain.DeletedAt = &deletedAt.Time
domainRecord.DeletedAt = &deletedAt.Time
}
return &domain, nil
return &domainRecord, nil
}
func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error {