Implement unified domains table with migration, repository, and projection

Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-07-14 20:05:09 +00:00
parent 11380c9cda
commit d3de8a2150
9 changed files with 1317 additions and 0 deletions

40
cmd/setup/61.go Normal file
View File

@@ -0,0 +1,40 @@
package setup
import (
"context"
"database/sql"
"embed"
"fmt"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed 61/*.sql
createDomainsTable embed.FS
)
type CreateDomainsTable struct {
dbClient *database.DB
}
func (mig *CreateDomainsTable) Execute(ctx context.Context, _ eventstore.Event) error {
statements, err := readStatements(createDomainsTable, "61")
if err != nil {
return err
}
for _, stmt := range statements {
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
if _, err := mig.dbClient.ExecContext(ctx, stmt.query); err != nil {
return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err)
}
}
return nil
}
func (mig *CreateDomainsTable) String() string {
return "61_create_domains_table"
}

View File

@@ -0,0 +1,24 @@
CREATE TABLE IF NOT EXISTS zitadel.domains(
instance_id TEXT NOT NULL
, org_id TEXT
, domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255)
, is_verified BOOLEAN NOT NULL DEFAULT FALSE
, is_primary BOOLEAN NOT NULL DEFAULT FALSE
-- TODO make validation_type enum
, validation_type SMALLINT CHECK (validation_type >= 0)
, created_at TIMESTAMP DEFAULT NOW()
, updated_at TIMESTAMP DEFAULT NOW()
, deleted_at TIMESTAMP DEFAULT NULL
, FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id) ON DELETE CASCADE
, FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id) ON DELETE CASCADE
, CONSTRAINT domain_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, domain) WHERE deleted_at IS NULL
);
CREATE TRIGGER IF NOT EXISTS trigger_set_updated_at
BEFORE UPDATE ON zitadel.domains
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -157,6 +157,7 @@ type Steps struct {
s58ReplaceLoginNames3View *ReplaceLoginNames3View
s59SetupWebkeys *SetupWebkeys
s60GenerateSystemID *GenerateSystemID
s61CreateDomainsTable *CreateDomainsTable
}
func MustNewSteps(v *viper.Viper) *Steps {

View File

@@ -218,6 +218,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s57CreateResourceCounts = &CreateResourceCounts{dbClient: dbClient}
steps.s58ReplaceLoginNames3View = &ReplaceLoginNames3View{dbClient: dbClient}
steps.s60GenerateSystemID = &GenerateSystemID{eventstore: eventstoreClient}
steps.s61CreateDomainsTable = &CreateDomainsTable{dbClient: dbClient}
err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil)
logging.OnError(err).Fatal("unable to start projections")
@@ -266,6 +267,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s57CreateResourceCounts,
steps.s58ReplaceLoginNames3View,
steps.s60GenerateSystemID,
steps.s61CreateDomainsTable,
} {
setupErr = executeMigration(ctx, eventstoreClient, step, "migration failed")
if setupErr != nil {

View File

@@ -0,0 +1,340 @@
package projection
import (
"context"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
old_handler "github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/zerrors"
)
const (
DomainsTable = "zitadel.domains"
DomainsInstanceIDCol = "instance_id"
DomainsOrgIDCol = "org_id"
DomainsDomainCol = "domain"
DomainsIsVerifiedCol = "is_verified"
DomainsIsPrimaryCol = "is_primary"
DomainsValidationTypeCol = "validation_type"
DomainsCreatedAtCol = "created_at"
DomainsUpdatedAtCol = "updated_at"
DomainsDeletedAtCol = "deleted_at"
)
type domainsProjection struct{}
func newDomainsProjection(ctx context.Context, config handler.Config) *handler.Handler {
return handler.NewHandler(ctx, &config, new(domainsProjection))
}
func (*domainsProjection) Name() string {
return DomainsTable
}
func (*domainsProjection) Init() *old_handler.Check {
// The table is created by migration, so we just need to check it exists
return handler.NewTableCheck(
handler.NewTable([]*handler.InitColumn{
handler.NewColumn(DomainsInstanceIDCol, handler.ColumnTypeText),
handler.NewColumn(DomainsOrgIDCol, handler.ColumnTypeText),
handler.NewColumn(DomainsDomainCol, handler.ColumnTypeText),
handler.NewColumn(DomainsIsVerifiedCol, handler.ColumnTypeBool),
handler.NewColumn(DomainsIsPrimaryCol, handler.ColumnTypeBool),
handler.NewColumn(DomainsValidationTypeCol, handler.ColumnTypeEnum),
handler.NewColumn(DomainsCreatedAtCol, handler.ColumnTypeTimestamp),
handler.NewColumn(DomainsUpdatedAtCol, handler.ColumnTypeTimestamp),
handler.NewColumn(DomainsDeletedAtCol, handler.ColumnTypeTimestamp),
},
// Note: We don't define primary key here since the table is created by migration
),
)
}
func (p *domainsProjection) Reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: org.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: org.OrgDomainAddedEventType,
Reduce: p.reduceOrgDomainAdded,
},
{
Event: org.OrgDomainVerificationAddedEventType,
Reduce: p.reduceOrgDomainVerificationAdded,
},
{
Event: org.OrgDomainVerifiedEventType,
Reduce: p.reduceOrgDomainVerified,
},
{
Event: org.OrgDomainPrimarySetEventType,
Reduce: p.reduceOrgPrimaryDomainSet,
},
{
Event: org.OrgDomainRemovedEventType,
Reduce: p.reduceOrgDomainRemoved,
},
{
Event: org.OrgRemovedEventType,
Reduce: p.reduceOrgRemoved,
},
},
},
{
Aggregate: instance.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: instance.InstanceDomainAddedEventType,
Reduce: p.reduceInstanceDomainAdded,
},
{
Event: instance.InstanceDomainPrimarySetEventType,
Reduce: p.reduceInstancePrimaryDomainSet,
},
{
Event: instance.InstanceDomainRemovedEventType,
Reduce: p.reduceInstanceDomainRemoved,
},
{
Event: instance.InstanceRemovedEventType,
Reduce: p.reduceInstanceRemoved,
},
},
},
}
}
// Organization domain event handlers
func (p *domainsProjection) reduceOrgDomainAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*org.DomainAddedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-DM2DI", "reduce.wrong.event.type %s", org.OrgDomainAddedEventType)
}
return handler.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(DomainsInstanceIDCol, e.Aggregate().InstanceID),
handler.NewCol(DomainsOrgIDCol, e.Aggregate().ID),
handler.NewCol(DomainsDomainCol, e.Domain),
handler.NewCol(DomainsIsVerifiedCol, false),
handler.NewCol(DomainsIsPrimaryCol, false),
handler.NewCol(DomainsValidationTypeCol, domain.OrgDomainValidationTypeUnspecified),
handler.NewCol(DomainsCreatedAtCol, e.CreationDate()),
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
},
), nil
}
func (p *domainsProjection) reduceOrgDomainVerificationAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*org.DomainVerificationAddedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-EBzyu", "reduce.wrong.event.type %s", org.OrgDomainVerificationAddedEventType)
}
return handler.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsValidationTypeCol, e.ValidationType),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
handler.NewCond(DomainsDomainCol, e.Domain),
handler.NewCond(DomainsDeletedAtCol, nil),
},
), nil
}
func (p *domainsProjection) reduceOrgDomainVerified(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*org.DomainVerifiedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-3Rvkr", "reduce.wrong.event.type %s", org.OrgDomainVerifiedEventType)
}
return handler.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsIsVerifiedCol, true),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
handler.NewCond(DomainsDomainCol, e.Domain),
handler.NewCond(DomainsDeletedAtCol, nil),
},
), nil
}
func (p *domainsProjection) reduceOrgPrimaryDomainSet(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*org.DomainPrimarySetEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-aIuei", "reduce.wrong.event.type %s", org.OrgDomainPrimarySetEventType)
}
return handler.NewMultiStatement(
e,
handler.AddUpdateStatement(
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsIsPrimaryCol, false),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
handler.NewCond(DomainsIsPrimaryCol, true),
handler.NewCond(DomainsDeletedAtCol, nil),
},
),
handler.AddUpdateStatement(
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsIsPrimaryCol, true),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
handler.NewCond(DomainsDomainCol, e.Domain),
handler.NewCond(DomainsDeletedAtCol, nil),
},
),
), nil
}
func (p *domainsProjection) reduceOrgDomainRemoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*org.DomainRemovedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-gh1Mx", "reduce.wrong.event.type %s", org.OrgDomainRemovedEventType)
}
return handler.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsDeletedAtCol, e.CreationDate()),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
handler.NewCond(DomainsDomainCol, e.Domain),
},
), nil
}
func (p *domainsProjection) reduceOrgRemoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*org.OrgRemovedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-dMUKJ", "reduce.wrong.event.type %s", org.OrgRemovedEventType)
}
return handler.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsDeletedAtCol, e.CreationDate()),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
handler.NewCond(DomainsDeletedAtCol, nil),
},
), nil
}
// Instance domain event handlers
func (p *domainsProjection) reduceInstanceDomainAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*instance.DomainAddedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-38nNf", "reduce.wrong.event.type %s", instance.InstanceDomainAddedEventType)
}
return handler.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(DomainsInstanceIDCol, e.Aggregate().ID),
handler.NewCol(DomainsOrgIDCol, nil), // Instance domains have no org_id
handler.NewCol(DomainsDomainCol, e.Domain),
handler.NewCol(DomainsIsVerifiedCol, true), // Instance domains are always verified
handler.NewCol(DomainsIsPrimaryCol, false),
handler.NewCol(DomainsValidationTypeCol, nil), // Instance domains have no validation type
handler.NewCol(DomainsCreatedAtCol, e.CreationDate()),
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
},
), nil
}
func (p *domainsProjection) reduceInstancePrimaryDomainSet(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*instance.DomainPrimarySetEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-f8nlw", "reduce.wrong.event.type %s", instance.InstanceDomainPrimarySetEventType)
}
return handler.NewMultiStatement(
e,
handler.AddUpdateStatement(
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsIsPrimaryCol, false),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID),
handler.NewCond(DomainsOrgIDCol, nil), // Instance domains
handler.NewCond(DomainsIsPrimaryCol, true),
handler.NewCond(DomainsDeletedAtCol, nil),
},
),
handler.AddUpdateStatement(
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsIsPrimaryCol, true),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID),
handler.NewCond(DomainsOrgIDCol, nil), // Instance domains
handler.NewCond(DomainsDomainCol, e.Domain),
handler.NewCond(DomainsDeletedAtCol, nil),
},
),
), nil
}
func (p *domainsProjection) reduceInstanceDomainRemoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*instance.DomainRemovedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-388Nk", "reduce.wrong.event.type %s", instance.InstanceDomainRemovedEventType)
}
return handler.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsDeletedAtCol, e.CreationDate()),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID),
handler.NewCond(DomainsOrgIDCol, nil), // Instance domains
handler.NewCond(DomainsDomainCol, e.Domain),
},
), nil
}
func (p *domainsProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*instance.InstanceRemovedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-2n9f0", "reduce.wrong.event.type %s", instance.InstanceRemovedEventType)
}
return handler.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
handler.NewCol(DomainsDeletedAtCol, e.CreationDate()),
},
[]handler.Condition{
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID),
handler.NewCond(DomainsDeletedAtCol, nil),
},
), nil
}

View File

@@ -37,6 +37,7 @@ var (
ProjectGrantProjection *handler.Handler
ProjectRoleProjection *handler.Handler
OrgDomainProjection *handler.Handler
DomainsProjection *handler.Handler // Unified domains table
LoginPolicyProjection *handler.Handler
IDPProjection *handler.Handler
AppProjection *handler.Handler
@@ -134,6 +135,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore,
ProjectGrantProjection = newProjectGrantProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_grants"]))
ProjectRoleProjection = newProjectRoleProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_roles"]))
OrgDomainProjection = newOrgDomainProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["org_domains"]))
DomainsProjection = newDomainsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["domains"]))
LoginPolicyProjection = newLoginPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["login_policies"]))
IDPProjection = newIDPProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["idps"]))
AppProjection = newAppProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["apps"]))

View File

@@ -0,0 +1,107 @@
package domain
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/v2/database"
)
// Domain represents a unified domain that can belong to either an instance or an organization
type Domain struct {
ID string
InstanceID string
OrganizationID *string // nil for instance domains
Domain string
IsVerified bool
IsPrimary bool
ValidationType *domain.OrgDomainValidationType // nil for instance domains
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
// IsInstanceDomain returns true if this is an instance domain (OrganizationID is nil)
func (d *Domain) IsInstanceDomain() bool {
return d.OrganizationID == nil
}
// IsOrganizationDomain returns true if this is an organization domain (OrganizationID is not nil)
func (d *Domain) IsOrganizationDomain() bool {
return d.OrganizationID != nil
}
// DomainSearchCriteria defines the search criteria for domains
type DomainSearchCriteria struct {
ID *string
Domain *string
InstanceID *string
OrganizationID *string
IsVerified *bool
IsPrimary *bool
}
// DomainPagination defines pagination options for domain queries
type DomainPagination struct {
Limit uint32
Offset uint32
SortBy DomainSortField
Order database.SortOrder
}
// DomainSortField defines the fields available for sorting domain results
type DomainSortField int
const (
DomainSortFieldCreatedAt DomainSortField = iota
DomainSortFieldUpdatedAt
DomainSortFieldDomain
)
// DomainList represents a paginated list of domains
type DomainList struct {
Domains []*Domain
TotalCount uint64
}
// InstanceDomainRepository defines the repository interface for instance domain operations
type InstanceDomainRepository interface {
// Add creates a new instance domain (always verified)
Add(ctx context.Context, instanceID, domain string) (*Domain, error)
// SetPrimary sets the primary domain for an instance
SetPrimary(ctx context.Context, instanceID, domain string) error
// Remove soft deletes an instance domain
Remove(ctx context.Context, instanceID, domain string) error
// Get returns a single instance domain matching the criteria
// Returns error if multiple domains are found
Get(ctx context.Context, criteria DomainSearchCriteria) (*Domain, error)
// List returns a list of instance domains matching the criteria with pagination
List(ctx context.Context, criteria DomainSearchCriteria, pagination DomainPagination) (*DomainList, error)
}
// OrganizationDomainRepository defines the repository interface for organization domain operations
type OrganizationDomainRepository interface {
// Add creates a new organization domain
Add(ctx context.Context, instanceID, organizationID, domain string, validationType domain.OrgDomainValidationType) (*Domain, error)
// SetVerified marks an organization domain as verified
SetVerified(ctx context.Context, instanceID, organizationID, domain string) error
// SetPrimary sets the primary domain for an organization
SetPrimary(ctx context.Context, instanceID, organizationID, domain string) error
// Remove soft deletes an organization domain
Remove(ctx context.Context, instanceID, organizationID, domain string) error
// Get returns a single organization domain matching the criteria
// Returns error if multiple domains are found
Get(ctx context.Context, criteria DomainSearchCriteria) (*Domain, error)
// List returns a list of organization domains matching the criteria with pagination
List(ctx context.Context, criteria DomainSearchCriteria, pagination DomainPagination) (*DomainList, error)
}

View File

@@ -0,0 +1,600 @@
package readmodel
import (
"context"
"database/sql"
"time"
"github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/domain"
v2domain "github.com/zitadel/zitadel/internal/v2/domain"
"github.com/zitadel/zitadel/internal/v2/database"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Database interfaces for the repository
type DB interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type TX interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
Rollback() error
Commit() error
}
const (
domainsTable = "zitadel.domains"
domainsInstanceIDCol = "instance_id"
domainsOrgIDCol = "org_id"
domainsDomainCol = "domain"
domainsIsVerifiedCol = "is_verified"
domainsIsPrimaryCol = "is_primary"
domainsValidationTypeCol = "validation_type"
domainsCreatedAtCol = "created_at"
domainsUpdatedAtCol = "updated_at"
domainsDeletedAtCol = "deleted_at"
)
// DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository
type DomainRepository struct {
client DB
}
// NewDomainRepository creates a new DomainRepository
func NewDomainRepository(client DB) *DomainRepository {
return &DomainRepository{
client: client,
}
}
// Add creates a new instance domain (always verified)
func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, domain string) (*v2domain.Domain, error) {
now := time.Now()
query := squirrel.Insert(domainsTable).
Columns(
domainsInstanceIDCol,
domainsDomainCol,
domainsIsVerifiedCol,
domainsIsPrimaryCol,
domainsCreatedAtCol,
domainsUpdatedAtCol,
).
Values(instanceID, domain, true, false, now, now).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal")
}
_, err = r.client.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal")
}
return &v2domain.Domain{
InstanceID: instanceID,
OrganizationID: nil,
Domain: domain,
IsVerified: true,
IsPrimary: false,
ValidationType: nil,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// SetPrimary sets the primary domain for an instance
func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error {
return r.withTransaction(ctx, func(tx database.Tx) error {
// First, unset any existing primary domain for this instance
unsetQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, false).
Set(domainsUpdatedAtCol, time.Now()).
Where(squirrel.Eq{
domainsInstanceIDCol: instanceID,
domainsOrgIDCol: nil,
domainsIsPrimaryCol: true,
}).
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := unsetQuery.ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-4k2nL", "Errors.Internal")
}
_, err = tx.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-5n3mK", "Errors.Internal")
}
// Then set the new primary domain
setPrimaryQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, true).
Set(domainsUpdatedAtCol, time.Now()).
Where(squirrel.Eq{
domainsInstanceIDCol: instanceID,
domainsOrgIDCol: nil,
domainsDomainCol: domain,
}).
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err = setPrimaryQuery.ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-6o4pM", "Errors.Internal")
}
result, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-7p5qN", "Errors.Internal")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-8q6rO", "Errors.Internal")
}
if rowsAffected == 0 {
return zerrors.ThrowNotFound(nil, "DOMAIN-9r7sP", "Errors.Domain.NotFound")
}
return nil
})
}
// Remove soft deletes an instance domain
func (r *DomainRepository) RemoveInstanceDomain(ctx context.Context, instanceID, domain string) error {
query := squirrel.Update(domainsTable).
Set(domainsDeletedAtCol, time.Now()).
Set(domainsUpdatedAtCol, time.Now()).
Where(squirrel.Eq{
domainsInstanceIDCol: instanceID,
domainsOrgIDCol: nil,
domainsDomainCol: domain,
}).
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-As8tQ", "Errors.Internal")
}
result, err := r.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-Bt9uR", "Errors.Internal")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-Cu0vS", "Errors.Internal")
}
if rowsAffected == 0 {
return zerrors.ThrowNotFound(nil, "DOMAIN-Dv1wT", "Errors.Domain.NotFound")
}
return nil
}
// Add creates a new organization domain
func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID, organizationID, domain string, validationType domain.OrgDomainValidationType) (*v2domain.Domain, error) {
now := time.Now()
query := squirrel.Insert(domainsTable).
Columns(
domainsInstanceIDCol,
domainsOrgIDCol,
domainsDomainCol,
domainsIsVerifiedCol,
domainsIsPrimaryCol,
domainsValidationTypeCol,
domainsCreatedAtCol,
domainsUpdatedAtCol,
).
Values(instanceID, organizationID, domain, false, false, int(validationType), now, now).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal")
}
_, err = r.client.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal")
}
return &v2domain.Domain{
InstanceID: instanceID,
OrganizationID: &organizationID,
Domain: domain,
IsVerified: false,
IsPrimary: false,
ValidationType: &validationType,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// SetVerified marks an organization domain as verified
func (r *DomainRepository) SetOrganizationDomainVerified(ctx context.Context, instanceID, organizationID, domain string) error {
query := squirrel.Update(domainsTable).
Set(domainsIsVerifiedCol, true).
Set(domainsUpdatedAtCol, time.Now()).
Where(squirrel.Eq{
domainsInstanceIDCol: instanceID,
domainsOrgIDCol: organizationID,
domainsDomainCol: domain,
}).
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-Gy4zW", "Errors.Internal")
}
result, err := r.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-Hz5aX", "Errors.Internal")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-I06bY", "Errors.Internal")
}
if rowsAffected == 0 {
return zerrors.ThrowNotFound(nil, "DOMAIN-J17cZ", "Errors.Domain.NotFound")
}
return nil
}
// SetPrimary sets the primary domain for an organization
func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error {
return r.withTransaction(ctx, func(tx database.Tx) error {
// First, unset any existing primary domain for this organization
unsetQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, false).
Set(domainsUpdatedAtCol, time.Now()).
Where(squirrel.Eq{
domainsInstanceIDCol: instanceID,
domainsOrgIDCol: organizationID,
domainsIsPrimaryCol: true,
}).
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := unsetQuery.ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-K28d0", "Errors.Internal")
}
_, err = tx.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-L39e1", "Errors.Internal")
}
// Then set the new primary domain
setPrimaryQuery := squirrel.Update(domainsTable).
Set(domainsIsPrimaryCol, true).
Set(domainsUpdatedAtCol, time.Now()).
Where(squirrel.Eq{
domainsInstanceIDCol: instanceID,
domainsOrgIDCol: organizationID,
domainsDomainCol: domain,
}).
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err = setPrimaryQuery.ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-M40f2", "Errors.Internal")
}
result, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-N51g3", "Errors.Internal")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-O62h4", "Errors.Internal")
}
if rowsAffected == 0 {
return zerrors.ThrowNotFound(nil, "DOMAIN-P73i5", "Errors.Domain.NotFound")
}
return nil
})
}
// Remove soft deletes an organization domain
func (r *DomainRepository) RemoveOrganizationDomain(ctx context.Context, instanceID, organizationID, domain string) error {
query := squirrel.Update(domainsTable).
Set(domainsDeletedAtCol, time.Now()).
Set(domainsUpdatedAtCol, time.Now()).
Where(squirrel.Eq{
domainsInstanceIDCol: instanceID,
domainsOrgIDCol: organizationID,
domainsDomainCol: domain,
}).
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := query.ToSql()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-Q84j6", "Errors.Internal")
}
result, err := r.client.ExecContext(ctx, stmt, args...)
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-R95k7", "Errors.Internal")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-S06l8", "Errors.Internal")
}
if rowsAffected == 0 {
return zerrors.ThrowNotFound(nil, "DOMAIN-T17m9", "Errors.Domain.NotFound")
}
return nil
}
// Get returns a single domain matching the criteria
func (r *DomainRepository) Get(ctx context.Context, criteria v2domain.DomainSearchCriteria) (*v2domain.Domain, error) {
query := r.buildSelectQuery(criteria, v2domain.DomainPagination{Limit: 2}) // Limit to 2 to detect multiple results
stmt, args, err := query.ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-U28n0", "Errors.Internal")
}
rows, err := r.client.QueryContext(ctx, stmt, args...)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-V39o1", "Errors.Internal")
}
defer rows.Close()
var domains []*v2domain.Domain
for rows.Next() {
domain, err := r.scanDomain(rows)
if err != nil {
return nil, err
}
domains = append(domains, domain)
}
if err := rows.Err(); err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-W40p2", "Errors.Internal")
}
if len(domains) == 0 {
return nil, zerrors.ThrowNotFound(nil, "DOMAIN-X51q3", "Errors.Domain.NotFound")
}
if len(domains) > 1 {
return nil, zerrors.ThrowInvalidArgument(nil, "DOMAIN-Y62r4", "Errors.Domain.MultipleFound")
}
return domains[0], nil
}
// List returns a list of domains matching the criteria with pagination
func (r *DomainRepository) List(ctx context.Context, criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) (*v2domain.DomainList, error) {
// First get the total count
countQuery := r.buildCountQuery(criteria)
stmt, args, err := countQuery.ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-Z73s5", "Errors.Internal")
}
var totalCount uint64
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&totalCount)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-A84t6", "Errors.Internal")
}
// Then get the actual data
query := r.buildSelectQuery(criteria, pagination)
stmt, args, err = query.ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-B95u7", "Errors.Internal")
}
rows, err := r.client.QueryContext(ctx, stmt, args...)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-C06v8", "Errors.Internal")
}
defer rows.Close()
var domains []*v2domain.Domain
for rows.Next() {
domain, err := r.scanDomain(rows)
if err != nil {
return nil, err
}
domains = append(domains, domain)
}
if err := rows.Err(); err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-D17w9", "Errors.Internal")
}
return &v2domain.DomainList{
Domains: domains,
TotalCount: totalCount,
}, nil
}
func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder {
query := squirrel.Select(
domainsInstanceIDCol,
domainsOrgIDCol,
domainsDomainCol,
domainsIsVerifiedCol,
domainsIsPrimaryCol,
domainsValidationTypeCol,
domainsCreatedAtCol,
domainsUpdatedAtCol,
domainsDeletedAtCol,
).From(domainsTable).
PlaceholderFormat(squirrel.Dollar)
query = r.applySearchCriteria(query, criteria)
query = r.applyPagination(query, pagination)
return query
}
func (r *DomainRepository) buildCountQuery(criteria v2domain.DomainSearchCriteria) squirrel.SelectBuilder {
query := squirrel.Select("COUNT(*)").
From(domainsTable).
PlaceholderFormat(squirrel.Dollar)
return r.applySearchCriteria(query, criteria)
}
func (r *DomainRepository) applySearchCriteria(query squirrel.SelectBuilder, criteria v2domain.DomainSearchCriteria) squirrel.SelectBuilder {
// Always exclude soft-deleted domains
query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL"))
if criteria.ID != nil {
// Note: Our table doesn't have an ID column. This might need to be reconsidered
// For now, we'll ignore this criterion since the spec doesn't define where ID comes from
}
if criteria.Domain != nil {
query = query.Where(squirrel.Eq{domainsDomainCol: *criteria.Domain})
}
if criteria.InstanceID != nil {
query = query.Where(squirrel.Eq{domainsInstanceIDCol: *criteria.InstanceID})
}
if criteria.OrganizationID != nil {
query = query.Where(squirrel.Eq{domainsOrgIDCol: *criteria.OrganizationID})
}
if criteria.IsVerified != nil {
query = query.Where(squirrel.Eq{domainsIsVerifiedCol: *criteria.IsVerified})
}
if criteria.IsPrimary != nil {
query = query.Where(squirrel.Eq{domainsIsPrimaryCol: *criteria.IsPrimary})
}
return query
}
func (r *DomainRepository) applyPagination(query squirrel.SelectBuilder, pagination v2domain.DomainPagination) squirrel.SelectBuilder {
// Apply sorting
var orderBy string
switch pagination.SortBy {
case v2domain.DomainSortFieldCreatedAt:
orderBy = domainsCreatedAtCol
case v2domain.DomainSortFieldUpdatedAt:
orderBy = domainsUpdatedAtCol
case v2domain.DomainSortFieldDomain:
orderBy = domainsDomainCol
default:
orderBy = domainsCreatedAtCol
}
if pagination.Order == database.SortOrderDesc {
orderBy += " DESC"
} else {
orderBy += " ASC"
}
query = query.OrderBy(orderBy)
// Apply pagination
if pagination.Limit > 0 {
query = query.Limit(uint64(pagination.Limit))
}
if pagination.Offset > 0 {
query = query.Offset(uint64(pagination.Offset))
}
return query
}
func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) {
var domain v2domain.Domain
var orgID sql.NullString
var validationType sql.NullInt32
var deletedAt sql.NullTime
err := rows.Scan(
&domain.InstanceID,
&orgID,
&domain.Domain,
&domain.IsVerified,
&domain.IsPrimary,
&validationType,
&domain.CreatedAt,
&domain.UpdatedAt,
&deletedAt,
)
if err != nil {
return nil, zerrors.ThrowInternal(err, "DOMAIN-E28x0", "Errors.Internal")
}
if orgID.Valid {
domain.OrganizationID = &orgID.String
}
if validationType.Valid {
validationTypeValue := domain.OrgDomainValidationType(validationType.Int32)
domain.ValidationType = &validationTypeValue
}
if deletedAt.Valid {
domain.DeletedAt = &deletedAt.Time
}
return &domain, nil
}
func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error {
tx, err := r.client.BeginTx(ctx, nil)
if err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-F39y1", "Errors.Internal")
}
defer func() {
if p := recover(); p != nil {
_ = tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return zerrors.ThrowInternal(rbErr, "DOMAIN-G40z2", "Errors.Internal")
}
return err
}
if err := tx.Commit(); err != nil {
return zerrors.ThrowInternal(err, "DOMAIN-H51a3", "Errors.Internal")
}
return nil
}

View File

@@ -0,0 +1,201 @@
package readmodel
import (
"context"
"database/sql"
"testing"
"time"
"github.com/Masterminds/squirrel"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/domain"
v2domain "github.com/zitadel/zitadel/internal/v2/domain"
"github.com/zitadel/zitadel/internal/v2/database"
)
func TestDomainRepository_AddInstanceDomain(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
repo := NewDomainRepository(db)
instanceID := "test-instance-id"
domainName := "test.example.com"
mock.ExpectExec(`INSERT INTO zitadel\.domains`).
WithArgs(instanceID, domainName, true, false, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
result, err := repo.AddInstanceDomain(context.Background(), instanceID, domainName)
require.NoError(t, err)
assert.Equal(t, instanceID, result.InstanceID)
assert.Nil(t, result.OrganizationID)
assert.Equal(t, domainName, result.Domain)
assert.True(t, result.IsVerified)
assert.False(t, result.IsPrimary)
assert.Nil(t, result.ValidationType)
assert.NoError(t, mock.ExpectationsWereMet())
}
func TestDomainRepository_AddOrganizationDomain(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
repo := NewDomainRepository(db)
instanceID := "test-instance-id"
orgID := "test-org-id"
domainName := "test.example.com"
validationType := domain.OrgDomainValidationTypeHTTP
mock.ExpectExec(`INSERT INTO zitadel\.domains`).
WithArgs(instanceID, orgID, domainName, false, false, int(validationType), sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
result, err := repo.AddOrganizationDomain(context.Background(), instanceID, orgID, domainName, validationType)
require.NoError(t, err)
assert.Equal(t, instanceID, result.InstanceID)
assert.Equal(t, orgID, *result.OrganizationID)
assert.Equal(t, domainName, result.Domain)
assert.False(t, result.IsVerified)
assert.False(t, result.IsPrimary)
assert.Equal(t, validationType, *result.ValidationType)
assert.NoError(t, mock.ExpectationsWereMet())
}
func TestDomainRepository_SetInstanceDomainPrimary(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
repo := NewDomainRepository(db)
instanceID := "test-instance-id"
domainName := "test.example.com"
// Mock transaction begin
mock.ExpectBegin()
// Mock unset existing primary
mock.ExpectExec(`UPDATE zitadel\.domains SET.*is_primary.*=.*false`).
WillReturnResult(sqlmock.NewResult(0, 0))
// Mock set new primary
mock.ExpectExec(`UPDATE zitadel\.domains SET.*is_primary.*=.*true`).
WithArgs(sqlmock.AnyArg(), true, instanceID, nil, domainName).
WillReturnResult(sqlmock.NewResult(0, 1))
// Mock transaction commit
mock.ExpectCommit()
err = repo.SetInstanceDomainPrimary(context.Background(), instanceID, domainName)
require.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
func TestDomainRepository_Get(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
repo := NewDomainRepository(db)
instanceID := "test-instance-id"
domainName := "test.example.com"
now := time.Now()
criteria := v2domain.DomainSearchCriteria{
InstanceID: &instanceID,
Domain: &domainName,
}
rows := sqlmock.NewRows([]string{
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
}).AddRow(instanceID, nil, domainName, true, false, nil, now, now, nil)
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains`).
WithArgs(domainName, instanceID).
WillReturnRows(rows)
result, err := repo.Get(context.Background(), criteria)
require.NoError(t, err)
assert.Equal(t, instanceID, result.InstanceID)
assert.Nil(t, result.OrganizationID)
assert.Equal(t, domainName, result.Domain)
assert.True(t, result.IsVerified)
assert.False(t, result.IsPrimary)
assert.NoError(t, mock.ExpectationsWereMet())
}
func TestDomainRepository_List(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
repo := NewDomainRepository(db)
instanceID := "test-instance-id"
now := time.Now()
criteria := v2domain.DomainSearchCriteria{
InstanceID: &instanceID,
}
pagination := v2domain.DomainPagination{
Limit: 10,
Offset: 0,
SortBy: v2domain.DomainSortFieldDomain,
Order: database.SortOrderAsc,
}
// Mock count query
countRows := sqlmock.NewRows([]string{"count"}).AddRow(2)
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM zitadel\.domains`).
WithArgs(instanceID).
WillReturnRows(countRows)
// Mock data query
rows := sqlmock.NewRows([]string{
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
}).
AddRow(instanceID, nil, "instance.example.com", true, true, nil, now, now, nil).
AddRow(instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil)
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains.*ORDER BY domain ASC.*LIMIT 10`).
WithArgs(instanceID).
WillReturnRows(rows)
result, err := repo.List(context.Background(), criteria, pagination)
require.NoError(t, err)
assert.Equal(t, uint64(2), result.TotalCount)
assert.Len(t, result.Domains, 2)
// Check first domain (instance domain)
assert.Equal(t, instanceID, result.Domains[0].InstanceID)
assert.Nil(t, result.Domains[0].OrganizationID)
assert.Equal(t, "instance.example.com", result.Domains[0].Domain)
assert.True(t, result.Domains[0].IsVerified)
assert.True(t, result.Domains[0].IsPrimary)
// Check second domain (org domain)
assert.Equal(t, instanceID, result.Domains[1].InstanceID)
assert.Equal(t, "org-id", *result.Domains[1].OrganizationID)
assert.Equal(t, "org.example.com", result.Domains[1].Domain)
assert.False(t, result.Domains[1].IsVerified)
assert.False(t, result.Domains[1].IsPrimary)
assert.NoError(t, mock.ExpectationsWereMet())
}