mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 10:57:35 +00:00
Add ID column to domains table and update repository implementation
Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com>
This commit is contained in:
@@ -30,16 +30,17 @@ type TX interface {
|
||||
}
|
||||
|
||||
const (
|
||||
domainsTable = "zitadel.domains"
|
||||
domainsInstanceIDCol = "instance_id"
|
||||
domainsOrgIDCol = "org_id"
|
||||
domainsDomainCol = "domain"
|
||||
domainsIsVerifiedCol = "is_verified"
|
||||
domainsIsPrimaryCol = "is_primary"
|
||||
domainsTable = "zitadel.domains"
|
||||
domainsIDCol = "id"
|
||||
domainsInstanceIDCol = "instance_id"
|
||||
domainsOrgIDCol = "org_id"
|
||||
domainsDomainCol = "domain"
|
||||
domainsIsVerifiedCol = "is_verified"
|
||||
domainsIsPrimaryCol = "is_primary"
|
||||
domainsValidationTypeCol = "validation_type"
|
||||
domainsCreatedAtCol = "created_at"
|
||||
domainsUpdatedAtCol = "updated_at"
|
||||
domainsDeletedAtCol = "deleted_at"
|
||||
domainsCreatedAtCol = "created_at"
|
||||
domainsUpdatedAtCol = "updated_at"
|
||||
domainsDeletedAtCol = "deleted_at"
|
||||
)
|
||||
|
||||
// DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository
|
||||
@@ -67,6 +68,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
|
||||
domainsUpdatedAtCol,
|
||||
).
|
||||
Values(instanceID, domain, true, false, now, now).
|
||||
Suffix("RETURNING " + domainsIDCol).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
@@ -74,12 +76,14 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal")
|
||||
}
|
||||
|
||||
_, err = r.client.ExecContext(ctx, stmt, args...)
|
||||
var id string
|
||||
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal")
|
||||
}
|
||||
|
||||
return &v2domain.Domain{
|
||||
ID: id,
|
||||
InstanceID: instanceID,
|
||||
OrganizationID: nil,
|
||||
Domain: domain,
|
||||
@@ -93,7 +97,7 @@ func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, do
|
||||
|
||||
// SetPrimary sets the primary domain for an instance
|
||||
func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error {
|
||||
return r.withTransaction(ctx, func(tx database.Tx) error {
|
||||
return r.withTransaction(ctx, func(tx TX) error {
|
||||
// First, unset any existing primary domain for this instance
|
||||
unsetQuery := squirrel.Update(domainsTable).
|
||||
Set(domainsIsPrimaryCol, false).
|
||||
@@ -201,6 +205,7 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID
|
||||
domainsUpdatedAtCol,
|
||||
).
|
||||
Values(instanceID, organizationID, domain, false, false, int(validationType), now, now).
|
||||
Suffix("RETURNING " + domainsIDCol).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
@@ -208,12 +213,14 @@ func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal")
|
||||
}
|
||||
|
||||
_, err = r.client.ExecContext(ctx, stmt, args...)
|
||||
var id string
|
||||
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal")
|
||||
}
|
||||
|
||||
return &v2domain.Domain{
|
||||
ID: id,
|
||||
InstanceID: instanceID,
|
||||
OrganizationID: &organizationID,
|
||||
Domain: domain,
|
||||
@@ -262,7 +269,7 @@ func (r *DomainRepository) SetOrganizationDomainVerified(ctx context.Context, in
|
||||
|
||||
// SetPrimary sets the primary domain for an organization
|
||||
func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error {
|
||||
return r.withTransaction(ctx, func(tx database.Tx) error {
|
||||
return r.withTransaction(ctx, func(tx TX) error {
|
||||
// First, unset any existing primary domain for this organization
|
||||
unsetQuery := squirrel.Update(domainsTable).
|
||||
Set(domainsIsPrimaryCol, false).
|
||||
@@ -443,6 +450,7 @@ func (r *DomainRepository) List(ctx context.Context, criteria v2domain.DomainSea
|
||||
|
||||
func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder {
|
||||
query := squirrel.Select(
|
||||
domainsIDCol,
|
||||
domainsInstanceIDCol,
|
||||
domainsOrgIDCol,
|
||||
domainsDomainCol,
|
||||
@@ -474,8 +482,7 @@ func (r *DomainRepository) applySearchCriteria(query squirrel.SelectBuilder, cri
|
||||
query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL"))
|
||||
|
||||
if criteria.ID != nil {
|
||||
// Note: Our table doesn't have an ID column. This might need to be reconsidered
|
||||
// For now, we'll ignore this criterion since the spec doesn't define where ID comes from
|
||||
query = query.Where(squirrel.Eq{domainsIDCol: *criteria.ID})
|
||||
}
|
||||
|
||||
if criteria.Domain != nil {
|
||||
@@ -536,20 +543,21 @@ func (r *DomainRepository) applyPagination(query squirrel.SelectBuilder, paginat
|
||||
}
|
||||
|
||||
func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) {
|
||||
var domain v2domain.Domain
|
||||
var domainRecord v2domain.Domain
|
||||
var orgID sql.NullString
|
||||
var validationType sql.NullInt32
|
||||
var deletedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&domain.InstanceID,
|
||||
&domainRecord.ID,
|
||||
&domainRecord.InstanceID,
|
||||
&orgID,
|
||||
&domain.Domain,
|
||||
&domain.IsVerified,
|
||||
&domain.IsPrimary,
|
||||
&domainRecord.Domain,
|
||||
&domainRecord.IsVerified,
|
||||
&domainRecord.IsPrimary,
|
||||
&validationType,
|
||||
&domain.CreatedAt,
|
||||
&domain.UpdatedAt,
|
||||
&domainRecord.CreatedAt,
|
||||
&domainRecord.UpdatedAt,
|
||||
&deletedAt,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -557,19 +565,19 @@ func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error)
|
||||
}
|
||||
|
||||
if orgID.Valid {
|
||||
domain.OrganizationID = &orgID.String
|
||||
domainRecord.OrganizationID = &orgID.String
|
||||
}
|
||||
|
||||
if validationType.Valid {
|
||||
validationTypeValue := domain.OrgDomainValidationType(validationType.Int32)
|
||||
domain.ValidationType = &validationTypeValue
|
||||
domainRecord.ValidationType = &validationTypeValue
|
||||
}
|
||||
|
||||
if deletedAt.Valid {
|
||||
domain.DeletedAt = &deletedAt.Time
|
||||
domainRecord.DeletedAt = &deletedAt.Time
|
||||
}
|
||||
|
||||
return &domain, nil
|
||||
return &domainRecord, nil
|
||||
}
|
||||
|
||||
func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error {
|
||||
|
@@ -25,14 +25,16 @@ func TestDomainRepository_AddInstanceDomain(t *testing.T) {
|
||||
|
||||
instanceID := "test-instance-id"
|
||||
domainName := "test.example.com"
|
||||
expectedID := "domain-id-123"
|
||||
|
||||
mock.ExpectExec(`INSERT INTO zitadel\.domains`).
|
||||
mock.ExpectQuery(`INSERT INTO zitadel\.domains`).
|
||||
WithArgs(instanceID, domainName, true, false, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(expectedID))
|
||||
|
||||
result, err := repo.AddInstanceDomain(context.Background(), instanceID, domainName)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedID, result.ID)
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Nil(t, result.OrganizationID)
|
||||
assert.Equal(t, domainName, result.Domain)
|
||||
@@ -54,14 +56,16 @@ func TestDomainRepository_AddOrganizationDomain(t *testing.T) {
|
||||
orgID := "test-org-id"
|
||||
domainName := "test.example.com"
|
||||
validationType := domain.OrgDomainValidationTypeHTTP
|
||||
expectedID := "domain-id-456"
|
||||
|
||||
mock.ExpectExec(`INSERT INTO zitadel\.domains`).
|
||||
mock.ExpectQuery(`INSERT INTO zitadel\.domains`).
|
||||
WithArgs(instanceID, orgID, domainName, false, false, int(validationType), sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(expectedID))
|
||||
|
||||
result, err := repo.AddOrganizationDomain(context.Background(), instanceID, orgID, domainName, validationType)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedID, result.ID)
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Equal(t, orgID, *result.OrganizationID)
|
||||
assert.Equal(t, domainName, result.Domain)
|
||||
@@ -120,8 +124,8 @@ func TestDomainRepository_Get(t *testing.T) {
|
||||
}
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
|
||||
}).AddRow(instanceID, nil, domainName, true, false, nil, now, now, nil)
|
||||
"id", "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
|
||||
}).AddRow("domain-123", instanceID, nil, domainName, true, false, nil, now, now, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains`).
|
||||
WithArgs(domainName, instanceID).
|
||||
@@ -130,6 +134,7 @@ func TestDomainRepository_Get(t *testing.T) {
|
||||
result, err := repo.Get(context.Background(), criteria)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "domain-123", result.ID)
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Nil(t, result.OrganizationID)
|
||||
assert.Equal(t, domainName, result.Domain)
|
||||
@@ -168,10 +173,10 @@ func TestDomainRepository_List(t *testing.T) {
|
||||
|
||||
// Mock data query
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
|
||||
"id", "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
|
||||
}).
|
||||
AddRow(instanceID, nil, "instance.example.com", true, true, nil, now, now, nil).
|
||||
AddRow(instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil)
|
||||
AddRow("domain-instance", instanceID, nil, "instance.example.com", true, true, nil, now, now, nil).
|
||||
AddRow("domain-org", instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains.*ORDER BY domain ASC.*LIMIT 10`).
|
||||
WithArgs(instanceID).
|
||||
@@ -184,6 +189,7 @@ func TestDomainRepository_List(t *testing.T) {
|
||||
assert.Len(t, result.Domains, 2)
|
||||
|
||||
// Check first domain (instance domain)
|
||||
assert.Equal(t, "domain-instance", result.Domains[0].ID)
|
||||
assert.Equal(t, instanceID, result.Domains[0].InstanceID)
|
||||
assert.Nil(t, result.Domains[0].OrganizationID)
|
||||
assert.Equal(t, "instance.example.com", result.Domains[0].Domain)
|
||||
@@ -191,6 +197,7 @@ func TestDomainRepository_List(t *testing.T) {
|
||||
assert.True(t, result.Domains[0].IsPrimary)
|
||||
|
||||
// Check second domain (org domain)
|
||||
assert.Equal(t, "domain-org", result.Domains[1].ID)
|
||||
assert.Equal(t, instanceID, result.Domains[1].InstanceID)
|
||||
assert.Equal(t, "org-id", *result.Domains[1].OrganizationID)
|
||||
assert.Equal(t, "org.example.com", result.Domains[1].Domain)
|
||||
|
Reference in New Issue
Block a user