From d3de8a2150044514e1cf1f270935f96f3be97bae Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:05:09 +0000 Subject: [PATCH] Implement unified domains table with migration, repository, and projection Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com> --- cmd/setup/61.go | 40 ++ cmd/setup/61/01_create_domains_table.sql | 24 + cmd/setup/config.go | 1 + cmd/setup/setup.go | 2 + internal/query/projection/domains.go | 340 ++++++++++ internal/query/projection/projection.go | 2 + internal/v2/domain/repository.go | 107 ++++ internal/v2/readmodel/domain_repository.go | 600 ++++++++++++++++++ .../v2/readmodel/domain_repository_test.go | 201 ++++++ 9 files changed, 1317 insertions(+) create mode 100644 cmd/setup/61.go create mode 100644 cmd/setup/61/01_create_domains_table.sql create mode 100644 internal/query/projection/domains.go create mode 100644 internal/v2/domain/repository.go create mode 100644 internal/v2/readmodel/domain_repository.go create mode 100644 internal/v2/readmodel/domain_repository_test.go diff --git a/cmd/setup/61.go b/cmd/setup/61.go new file mode 100644 index 0000000000..804805d24e --- /dev/null +++ b/cmd/setup/61.go @@ -0,0 +1,40 @@ +package setup + +import ( + "context" + "database/sql" + "embed" + "fmt" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" +) + +var ( + //go:embed 61/*.sql + createDomainsTable embed.FS +) + +type CreateDomainsTable struct { + dbClient *database.DB +} + +func (mig *CreateDomainsTable) Execute(ctx context.Context, _ eventstore.Event) error { + statements, err := readStatements(createDomainsTable, "61") + if err != nil { + return err + } + for _, stmt := range statements { + logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement") + if _, err := mig.dbClient.ExecContext(ctx, stmt.query); err != nil { + return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err) + } + } + return nil +} + +func (mig *CreateDomainsTable) String() string { + return "61_create_domains_table" +} \ No newline at end of file diff --git a/cmd/setup/61/01_create_domains_table.sql b/cmd/setup/61/01_create_domains_table.sql new file mode 100644 index 0000000000..ffbd3c25f4 --- /dev/null +++ b/cmd/setup/61/01_create_domains_table.sql @@ -0,0 +1,24 @@ +CREATE TABLE IF NOT EXISTS zitadel.domains( + instance_id TEXT NOT NULL + , org_id TEXT + , domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255) + , is_verified BOOLEAN NOT NULL DEFAULT FALSE + , is_primary BOOLEAN NOT NULL DEFAULT FALSE + -- TODO make validation_type enum + , validation_type SMALLINT CHECK (validation_type >= 0) + + , created_at TIMESTAMP DEFAULT NOW() + , updated_at TIMESTAMP DEFAULT NOW() + , deleted_at TIMESTAMP DEFAULT NULL + + , FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id) ON DELETE CASCADE + , FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id) ON DELETE CASCADE + + , CONSTRAINT domain_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, domain) WHERE deleted_at IS NULL +); + +CREATE TRIGGER IF NOT EXISTS trigger_set_updated_at +BEFORE UPDATE ON zitadel.domains +FOR EACH ROW +WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) +EXECUTE FUNCTION zitadel.set_updated_at(); \ No newline at end of file diff --git a/cmd/setup/config.go b/cmd/setup/config.go index bac73b0ae5..6bb0d50312 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -157,6 +157,7 @@ type Steps struct { s58ReplaceLoginNames3View *ReplaceLoginNames3View s59SetupWebkeys *SetupWebkeys s60GenerateSystemID *GenerateSystemID + s61CreateDomainsTable *CreateDomainsTable } func MustNewSteps(v *viper.Viper) *Steps { diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index 15236a73e9..7deefff485 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -218,6 +218,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s57CreateResourceCounts = &CreateResourceCounts{dbClient: dbClient} steps.s58ReplaceLoginNames3View = &ReplaceLoginNames3View{dbClient: dbClient} steps.s60GenerateSystemID = &GenerateSystemID{eventstore: eventstoreClient} + steps.s61CreateDomainsTable = &CreateDomainsTable{dbClient: dbClient} err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil) logging.OnError(err).Fatal("unable to start projections") @@ -266,6 +267,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s57CreateResourceCounts, steps.s58ReplaceLoginNames3View, steps.s60GenerateSystemID, + steps.s61CreateDomainsTable, } { setupErr = executeMigration(ctx, eventstoreClient, step, "migration failed") if setupErr != nil { diff --git a/internal/query/projection/domains.go b/internal/query/projection/domains.go new file mode 100644 index 0000000000..5a014ecf0f --- /dev/null +++ b/internal/query/projection/domains.go @@ -0,0 +1,340 @@ +package projection + +import ( + "context" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" + old_handler "github.com/zitadel/zitadel/internal/eventstore/handler" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/org" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const ( + DomainsTable = "zitadel.domains" + + DomainsInstanceIDCol = "instance_id" + DomainsOrgIDCol = "org_id" + DomainsDomainCol = "domain" + DomainsIsVerifiedCol = "is_verified" + DomainsIsPrimaryCol = "is_primary" + DomainsValidationTypeCol = "validation_type" + DomainsCreatedAtCol = "created_at" + DomainsUpdatedAtCol = "updated_at" + DomainsDeletedAtCol = "deleted_at" +) + +type domainsProjection struct{} + +func newDomainsProjection(ctx context.Context, config handler.Config) *handler.Handler { + return handler.NewHandler(ctx, &config, new(domainsProjection)) +} + +func (*domainsProjection) Name() string { + return DomainsTable +} + +func (*domainsProjection) Init() *old_handler.Check { + // The table is created by migration, so we just need to check it exists + return handler.NewTableCheck( + handler.NewTable([]*handler.InitColumn{ + handler.NewColumn(DomainsInstanceIDCol, handler.ColumnTypeText), + handler.NewColumn(DomainsOrgIDCol, handler.ColumnTypeText), + handler.NewColumn(DomainsDomainCol, handler.ColumnTypeText), + handler.NewColumn(DomainsIsVerifiedCol, handler.ColumnTypeBool), + handler.NewColumn(DomainsIsPrimaryCol, handler.ColumnTypeBool), + handler.NewColumn(DomainsValidationTypeCol, handler.ColumnTypeEnum), + handler.NewColumn(DomainsCreatedAtCol, handler.ColumnTypeTimestamp), + handler.NewColumn(DomainsUpdatedAtCol, handler.ColumnTypeTimestamp), + handler.NewColumn(DomainsDeletedAtCol, handler.ColumnTypeTimestamp), + }, + // Note: We don't define primary key here since the table is created by migration + ), + ) +} + +func (p *domainsProjection) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: org.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: org.OrgDomainAddedEventType, + Reduce: p.reduceOrgDomainAdded, + }, + { + Event: org.OrgDomainVerificationAddedEventType, + Reduce: p.reduceOrgDomainVerificationAdded, + }, + { + Event: org.OrgDomainVerifiedEventType, + Reduce: p.reduceOrgDomainVerified, + }, + { + Event: org.OrgDomainPrimarySetEventType, + Reduce: p.reduceOrgPrimaryDomainSet, + }, + { + Event: org.OrgDomainRemovedEventType, + Reduce: p.reduceOrgDomainRemoved, + }, + { + Event: org.OrgRemovedEventType, + Reduce: p.reduceOrgRemoved, + }, + }, + }, + { + Aggregate: instance.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: instance.InstanceDomainAddedEventType, + Reduce: p.reduceInstanceDomainAdded, + }, + { + Event: instance.InstanceDomainPrimarySetEventType, + Reduce: p.reduceInstancePrimaryDomainSet, + }, + { + Event: instance.InstanceDomainRemovedEventType, + Reduce: p.reduceInstanceDomainRemoved, + }, + { + Event: instance.InstanceRemovedEventType, + Reduce: p.reduceInstanceRemoved, + }, + }, + }, + } +} + +// Organization domain event handlers + +func (p *domainsProjection) reduceOrgDomainAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-DM2DI", "reduce.wrong.event.type %s", org.OrgDomainAddedEventType) + } + return handler.NewCreateStatement( + e, + []handler.Column{ + handler.NewCol(DomainsInstanceIDCol, e.Aggregate().InstanceID), + handler.NewCol(DomainsOrgIDCol, e.Aggregate().ID), + handler.NewCol(DomainsDomainCol, e.Domain), + handler.NewCol(DomainsIsVerifiedCol, false), + handler.NewCol(DomainsIsPrimaryCol, false), + handler.NewCol(DomainsValidationTypeCol, domain.OrgDomainValidationTypeUnspecified), + handler.NewCol(DomainsCreatedAtCol, e.CreationDate()), + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + }, + ), nil +} + +func (p *domainsProjection) reduceOrgDomainVerificationAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainVerificationAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-EBzyu", "reduce.wrong.event.type %s", org.OrgDomainVerificationAddedEventType) + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsValidationTypeCol, e.ValidationType), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID), + handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID), + handler.NewCond(DomainsDomainCol, e.Domain), + handler.NewCond(DomainsDeletedAtCol, nil), + }, + ), nil +} + +func (p *domainsProjection) reduceOrgDomainVerified(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainVerifiedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-3Rvkr", "reduce.wrong.event.type %s", org.OrgDomainVerifiedEventType) + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsIsVerifiedCol, true), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID), + handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID), + handler.NewCond(DomainsDomainCol, e.Domain), + handler.NewCond(DomainsDeletedAtCol, nil), + }, + ), nil +} + +func (p *domainsProjection) reduceOrgPrimaryDomainSet(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainPrimarySetEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-aIuei", "reduce.wrong.event.type %s", org.OrgDomainPrimarySetEventType) + } + return handler.NewMultiStatement( + e, + handler.AddUpdateStatement( + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsIsPrimaryCol, false), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID), + handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID), + handler.NewCond(DomainsIsPrimaryCol, true), + handler.NewCond(DomainsDeletedAtCol, nil), + }, + ), + handler.AddUpdateStatement( + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsIsPrimaryCol, true), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID), + handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID), + handler.NewCond(DomainsDomainCol, e.Domain), + handler.NewCond(DomainsDeletedAtCol, nil), + }, + ), + ), nil +} + +func (p *domainsProjection) reduceOrgDomainRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-gh1Mx", "reduce.wrong.event.type %s", org.OrgDomainRemovedEventType) + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsDeletedAtCol, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID), + handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID), + handler.NewCond(DomainsDomainCol, e.Domain), + }, + ), nil +} + +func (p *domainsProjection) reduceOrgRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-dMUKJ", "reduce.wrong.event.type %s", org.OrgRemovedEventType) + } + + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsDeletedAtCol, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID), + handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID), + handler.NewCond(DomainsDeletedAtCol, nil), + }, + ), nil +} + +// Instance domain event handlers + +func (p *domainsProjection) reduceInstanceDomainAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.DomainAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-38nNf", "reduce.wrong.event.type %s", instance.InstanceDomainAddedEventType) + } + return handler.NewCreateStatement( + e, + []handler.Column{ + handler.NewCol(DomainsInstanceIDCol, e.Aggregate().ID), + handler.NewCol(DomainsOrgIDCol, nil), // Instance domains have no org_id + handler.NewCol(DomainsDomainCol, e.Domain), + handler.NewCol(DomainsIsVerifiedCol, true), // Instance domains are always verified + handler.NewCol(DomainsIsPrimaryCol, false), + handler.NewCol(DomainsValidationTypeCol, nil), // Instance domains have no validation type + handler.NewCol(DomainsCreatedAtCol, e.CreationDate()), + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + }, + ), nil +} + +func (p *domainsProjection) reduceInstancePrimaryDomainSet(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.DomainPrimarySetEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-f8nlw", "reduce.wrong.event.type %s", instance.InstanceDomainPrimarySetEventType) + } + return handler.NewMultiStatement( + e, + handler.AddUpdateStatement( + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsIsPrimaryCol, false), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID), + handler.NewCond(DomainsOrgIDCol, nil), // Instance domains + handler.NewCond(DomainsIsPrimaryCol, true), + handler.NewCond(DomainsDeletedAtCol, nil), + }, + ), + handler.AddUpdateStatement( + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsIsPrimaryCol, true), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID), + handler.NewCond(DomainsOrgIDCol, nil), // Instance domains + handler.NewCond(DomainsDomainCol, e.Domain), + handler.NewCond(DomainsDeletedAtCol, nil), + }, + ), + ), nil +} + +func (p *domainsProjection) reduceInstanceDomainRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.DomainRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-388Nk", "reduce.wrong.event.type %s", instance.InstanceDomainRemovedEventType) + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsDeletedAtCol, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID), + handler.NewCond(DomainsOrgIDCol, nil), // Instance domains + handler.NewCond(DomainsDomainCol, e.Domain), + }, + ), nil +} + +func (p *domainsProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.InstanceRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-2n9f0", "reduce.wrong.event.type %s", instance.InstanceRemovedEventType) + } + + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()), + handler.NewCol(DomainsDeletedAtCol, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID), + handler.NewCond(DomainsDeletedAtCol, nil), + }, + ), nil +} \ No newline at end of file diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index 5ad62380ea..aed1432f34 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -37,6 +37,7 @@ var ( ProjectGrantProjection *handler.Handler ProjectRoleProjection *handler.Handler OrgDomainProjection *handler.Handler + DomainsProjection *handler.Handler // Unified domains table LoginPolicyProjection *handler.Handler IDPProjection *handler.Handler AppProjection *handler.Handler @@ -134,6 +135,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore, ProjectGrantProjection = newProjectGrantProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_grants"])) ProjectRoleProjection = newProjectRoleProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_roles"])) OrgDomainProjection = newOrgDomainProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["org_domains"])) + DomainsProjection = newDomainsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["domains"])) LoginPolicyProjection = newLoginPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["login_policies"])) IDPProjection = newIDPProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["idps"])) AppProjection = newAppProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["apps"])) diff --git a/internal/v2/domain/repository.go b/internal/v2/domain/repository.go new file mode 100644 index 0000000000..c833ce8e58 --- /dev/null +++ b/internal/v2/domain/repository.go @@ -0,0 +1,107 @@ +package domain + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/v2/database" +) + +// Domain represents a unified domain that can belong to either an instance or an organization +type Domain struct { + ID string + InstanceID string + OrganizationID *string // nil for instance domains + Domain string + IsVerified bool + IsPrimary bool + ValidationType *domain.OrgDomainValidationType // nil for instance domains + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} + +// IsInstanceDomain returns true if this is an instance domain (OrganizationID is nil) +func (d *Domain) IsInstanceDomain() bool { + return d.OrganizationID == nil +} + +// IsOrganizationDomain returns true if this is an organization domain (OrganizationID is not nil) +func (d *Domain) IsOrganizationDomain() bool { + return d.OrganizationID != nil +} + +// DomainSearchCriteria defines the search criteria for domains +type DomainSearchCriteria struct { + ID *string + Domain *string + InstanceID *string + OrganizationID *string + IsVerified *bool + IsPrimary *bool +} + +// DomainPagination defines pagination options for domain queries +type DomainPagination struct { + Limit uint32 + Offset uint32 + SortBy DomainSortField + Order database.SortOrder +} + +// DomainSortField defines the fields available for sorting domain results +type DomainSortField int + +const ( + DomainSortFieldCreatedAt DomainSortField = iota + DomainSortFieldUpdatedAt + DomainSortFieldDomain +) + +// DomainList represents a paginated list of domains +type DomainList struct { + Domains []*Domain + TotalCount uint64 +} + +// InstanceDomainRepository defines the repository interface for instance domain operations +type InstanceDomainRepository interface { + // Add creates a new instance domain (always verified) + Add(ctx context.Context, instanceID, domain string) (*Domain, error) + + // SetPrimary sets the primary domain for an instance + SetPrimary(ctx context.Context, instanceID, domain string) error + + // Remove soft deletes an instance domain + Remove(ctx context.Context, instanceID, domain string) error + + // Get returns a single instance domain matching the criteria + // Returns error if multiple domains are found + Get(ctx context.Context, criteria DomainSearchCriteria) (*Domain, error) + + // List returns a list of instance domains matching the criteria with pagination + List(ctx context.Context, criteria DomainSearchCriteria, pagination DomainPagination) (*DomainList, error) +} + +// OrganizationDomainRepository defines the repository interface for organization domain operations +type OrganizationDomainRepository interface { + // Add creates a new organization domain + Add(ctx context.Context, instanceID, organizationID, domain string, validationType domain.OrgDomainValidationType) (*Domain, error) + + // SetVerified marks an organization domain as verified + SetVerified(ctx context.Context, instanceID, organizationID, domain string) error + + // SetPrimary sets the primary domain for an organization + SetPrimary(ctx context.Context, instanceID, organizationID, domain string) error + + // Remove soft deletes an organization domain + Remove(ctx context.Context, instanceID, organizationID, domain string) error + + // Get returns a single organization domain matching the criteria + // Returns error if multiple domains are found + Get(ctx context.Context, criteria DomainSearchCriteria) (*Domain, error) + + // List returns a list of organization domains matching the criteria with pagination + List(ctx context.Context, criteria DomainSearchCriteria, pagination DomainPagination) (*DomainList, error) +} \ No newline at end of file diff --git a/internal/v2/readmodel/domain_repository.go b/internal/v2/readmodel/domain_repository.go new file mode 100644 index 0000000000..6acdef87a8 --- /dev/null +++ b/internal/v2/readmodel/domain_repository.go @@ -0,0 +1,600 @@ +package readmodel + +import ( + "context" + "database/sql" + "time" + + "github.com/Masterminds/squirrel" + + "github.com/zitadel/zitadel/internal/domain" + v2domain "github.com/zitadel/zitadel/internal/v2/domain" + "github.com/zitadel/zitadel/internal/v2/database" + "github.com/zitadel/zitadel/internal/zerrors" +) + +// Database interfaces for the repository +type DB interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +type TX interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + Rollback() error + Commit() error +} + +const ( + domainsTable = "zitadel.domains" + domainsInstanceIDCol = "instance_id" + domainsOrgIDCol = "org_id" + domainsDomainCol = "domain" + domainsIsVerifiedCol = "is_verified" + domainsIsPrimaryCol = "is_primary" + domainsValidationTypeCol = "validation_type" + domainsCreatedAtCol = "created_at" + domainsUpdatedAtCol = "updated_at" + domainsDeletedAtCol = "deleted_at" +) + +// DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository +type DomainRepository struct { + client DB +} + +// NewDomainRepository creates a new DomainRepository +func NewDomainRepository(client DB) *DomainRepository { + return &DomainRepository{ + client: client, + } +} + +// Add creates a new instance domain (always verified) +func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, domain string) (*v2domain.Domain, error) { + now := time.Now() + query := squirrel.Insert(domainsTable). + Columns( + domainsInstanceIDCol, + domainsDomainCol, + domainsIsVerifiedCol, + domainsIsPrimaryCol, + domainsCreatedAtCol, + domainsUpdatedAtCol, + ). + Values(instanceID, domain, true, false, now, now). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := query.ToSql() + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal") + } + + _, err = r.client.ExecContext(ctx, stmt, args...) + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal") + } + + return &v2domain.Domain{ + InstanceID: instanceID, + OrganizationID: nil, + Domain: domain, + IsVerified: true, + IsPrimary: false, + ValidationType: nil, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +// SetPrimary sets the primary domain for an instance +func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error { + return r.withTransaction(ctx, func(tx database.Tx) error { + // First, unset any existing primary domain for this instance + unsetQuery := squirrel.Update(domainsTable). + Set(domainsIsPrimaryCol, false). + Set(domainsUpdatedAtCol, time.Now()). + Where(squirrel.Eq{ + domainsInstanceIDCol: instanceID, + domainsOrgIDCol: nil, + domainsIsPrimaryCol: true, + }). + Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := unsetQuery.ToSql() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-4k2nL", "Errors.Internal") + } + + _, err = tx.ExecContext(ctx, stmt, args...) + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-5n3mK", "Errors.Internal") + } + + // Then set the new primary domain + setPrimaryQuery := squirrel.Update(domainsTable). + Set(domainsIsPrimaryCol, true). + Set(domainsUpdatedAtCol, time.Now()). + Where(squirrel.Eq{ + domainsInstanceIDCol: instanceID, + domainsOrgIDCol: nil, + domainsDomainCol: domain, + }). + Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err = setPrimaryQuery.ToSql() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-6o4pM", "Errors.Internal") + } + + result, err := tx.ExecContext(ctx, stmt, args...) + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-7p5qN", "Errors.Internal") + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-8q6rO", "Errors.Internal") + } + + if rowsAffected == 0 { + return zerrors.ThrowNotFound(nil, "DOMAIN-9r7sP", "Errors.Domain.NotFound") + } + + return nil + }) +} + +// Remove soft deletes an instance domain +func (r *DomainRepository) RemoveInstanceDomain(ctx context.Context, instanceID, domain string) error { + query := squirrel.Update(domainsTable). + Set(domainsDeletedAtCol, time.Now()). + Set(domainsUpdatedAtCol, time.Now()). + Where(squirrel.Eq{ + domainsInstanceIDCol: instanceID, + domainsOrgIDCol: nil, + domainsDomainCol: domain, + }). + Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := query.ToSql() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-As8tQ", "Errors.Internal") + } + + result, err := r.client.ExecContext(ctx, stmt, args...) + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-Bt9uR", "Errors.Internal") + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-Cu0vS", "Errors.Internal") + } + + if rowsAffected == 0 { + return zerrors.ThrowNotFound(nil, "DOMAIN-Dv1wT", "Errors.Domain.NotFound") + } + + return nil +} + +// Add creates a new organization domain +func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID, organizationID, domain string, validationType domain.OrgDomainValidationType) (*v2domain.Domain, error) { + now := time.Now() + query := squirrel.Insert(domainsTable). + Columns( + domainsInstanceIDCol, + domainsOrgIDCol, + domainsDomainCol, + domainsIsVerifiedCol, + domainsIsPrimaryCol, + domainsValidationTypeCol, + domainsCreatedAtCol, + domainsUpdatedAtCol, + ). + Values(instanceID, organizationID, domain, false, false, int(validationType), now, now). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := query.ToSql() + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal") + } + + _, err = r.client.ExecContext(ctx, stmt, args...) + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal") + } + + return &v2domain.Domain{ + InstanceID: instanceID, + OrganizationID: &organizationID, + Domain: domain, + IsVerified: false, + IsPrimary: false, + ValidationType: &validationType, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +// SetVerified marks an organization domain as verified +func (r *DomainRepository) SetOrganizationDomainVerified(ctx context.Context, instanceID, organizationID, domain string) error { + query := squirrel.Update(domainsTable). + Set(domainsIsVerifiedCol, true). + Set(domainsUpdatedAtCol, time.Now()). + Where(squirrel.Eq{ + domainsInstanceIDCol: instanceID, + domainsOrgIDCol: organizationID, + domainsDomainCol: domain, + }). + Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := query.ToSql() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-Gy4zW", "Errors.Internal") + } + + result, err := r.client.ExecContext(ctx, stmt, args...) + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-Hz5aX", "Errors.Internal") + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-I06bY", "Errors.Internal") + } + + if rowsAffected == 0 { + return zerrors.ThrowNotFound(nil, "DOMAIN-J17cZ", "Errors.Domain.NotFound") + } + + return nil +} + +// SetPrimary sets the primary domain for an organization +func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error { + return r.withTransaction(ctx, func(tx database.Tx) error { + // First, unset any existing primary domain for this organization + unsetQuery := squirrel.Update(domainsTable). + Set(domainsIsPrimaryCol, false). + Set(domainsUpdatedAtCol, time.Now()). + Where(squirrel.Eq{ + domainsInstanceIDCol: instanceID, + domainsOrgIDCol: organizationID, + domainsIsPrimaryCol: true, + }). + Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := unsetQuery.ToSql() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-K28d0", "Errors.Internal") + } + + _, err = tx.ExecContext(ctx, stmt, args...) + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-L39e1", "Errors.Internal") + } + + // Then set the new primary domain + setPrimaryQuery := squirrel.Update(domainsTable). + Set(domainsIsPrimaryCol, true). + Set(domainsUpdatedAtCol, time.Now()). + Where(squirrel.Eq{ + domainsInstanceIDCol: instanceID, + domainsOrgIDCol: organizationID, + domainsDomainCol: domain, + }). + Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err = setPrimaryQuery.ToSql() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-M40f2", "Errors.Internal") + } + + result, err := tx.ExecContext(ctx, stmt, args...) + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-N51g3", "Errors.Internal") + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-O62h4", "Errors.Internal") + } + + if rowsAffected == 0 { + return zerrors.ThrowNotFound(nil, "DOMAIN-P73i5", "Errors.Domain.NotFound") + } + + return nil + }) +} + +// Remove soft deletes an organization domain +func (r *DomainRepository) RemoveOrganizationDomain(ctx context.Context, instanceID, organizationID, domain string) error { + query := squirrel.Update(domainsTable). + Set(domainsDeletedAtCol, time.Now()). + Set(domainsUpdatedAtCol, time.Now()). + Where(squirrel.Eq{ + domainsInstanceIDCol: instanceID, + domainsOrgIDCol: organizationID, + domainsDomainCol: domain, + }). + Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := query.ToSql() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-Q84j6", "Errors.Internal") + } + + result, err := r.client.ExecContext(ctx, stmt, args...) + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-R95k7", "Errors.Internal") + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-S06l8", "Errors.Internal") + } + + if rowsAffected == 0 { + return zerrors.ThrowNotFound(nil, "DOMAIN-T17m9", "Errors.Domain.NotFound") + } + + return nil +} + +// Get returns a single domain matching the criteria +func (r *DomainRepository) Get(ctx context.Context, criteria v2domain.DomainSearchCriteria) (*v2domain.Domain, error) { + query := r.buildSelectQuery(criteria, v2domain.DomainPagination{Limit: 2}) // Limit to 2 to detect multiple results + + stmt, args, err := query.ToSql() + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-U28n0", "Errors.Internal") + } + + rows, err := r.client.QueryContext(ctx, stmt, args...) + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-V39o1", "Errors.Internal") + } + defer rows.Close() + + var domains []*v2domain.Domain + for rows.Next() { + domain, err := r.scanDomain(rows) + if err != nil { + return nil, err + } + domains = append(domains, domain) + } + + if err := rows.Err(); err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-W40p2", "Errors.Internal") + } + + if len(domains) == 0 { + return nil, zerrors.ThrowNotFound(nil, "DOMAIN-X51q3", "Errors.Domain.NotFound") + } + + if len(domains) > 1 { + return nil, zerrors.ThrowInvalidArgument(nil, "DOMAIN-Y62r4", "Errors.Domain.MultipleFound") + } + + return domains[0], nil +} + +// List returns a list of domains matching the criteria with pagination +func (r *DomainRepository) List(ctx context.Context, criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) (*v2domain.DomainList, error) { + // First get the total count + countQuery := r.buildCountQuery(criteria) + stmt, args, err := countQuery.ToSql() + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-Z73s5", "Errors.Internal") + } + + var totalCount uint64 + err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&totalCount) + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-A84t6", "Errors.Internal") + } + + // Then get the actual data + query := r.buildSelectQuery(criteria, pagination) + stmt, args, err = query.ToSql() + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-B95u7", "Errors.Internal") + } + + rows, err := r.client.QueryContext(ctx, stmt, args...) + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-C06v8", "Errors.Internal") + } + defer rows.Close() + + var domains []*v2domain.Domain + for rows.Next() { + domain, err := r.scanDomain(rows) + if err != nil { + return nil, err + } + domains = append(domains, domain) + } + + if err := rows.Err(); err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-D17w9", "Errors.Internal") + } + + return &v2domain.DomainList{ + Domains: domains, + TotalCount: totalCount, + }, nil +} + +func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder { + query := squirrel.Select( + domainsInstanceIDCol, + domainsOrgIDCol, + domainsDomainCol, + domainsIsVerifiedCol, + domainsIsPrimaryCol, + domainsValidationTypeCol, + domainsCreatedAtCol, + domainsUpdatedAtCol, + domainsDeletedAtCol, + ).From(domainsTable). + PlaceholderFormat(squirrel.Dollar) + + query = r.applySearchCriteria(query, criteria) + query = r.applyPagination(query, pagination) + + return query +} + +func (r *DomainRepository) buildCountQuery(criteria v2domain.DomainSearchCriteria) squirrel.SelectBuilder { + query := squirrel.Select("COUNT(*)"). + From(domainsTable). + PlaceholderFormat(squirrel.Dollar) + + return r.applySearchCriteria(query, criteria) +} + +func (r *DomainRepository) applySearchCriteria(query squirrel.SelectBuilder, criteria v2domain.DomainSearchCriteria) squirrel.SelectBuilder { + // Always exclude soft-deleted domains + query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")) + + if criteria.ID != nil { + // Note: Our table doesn't have an ID column. This might need to be reconsidered + // For now, we'll ignore this criterion since the spec doesn't define where ID comes from + } + + if criteria.Domain != nil { + query = query.Where(squirrel.Eq{domainsDomainCol: *criteria.Domain}) + } + + if criteria.InstanceID != nil { + query = query.Where(squirrel.Eq{domainsInstanceIDCol: *criteria.InstanceID}) + } + + if criteria.OrganizationID != nil { + query = query.Where(squirrel.Eq{domainsOrgIDCol: *criteria.OrganizationID}) + } + + if criteria.IsVerified != nil { + query = query.Where(squirrel.Eq{domainsIsVerifiedCol: *criteria.IsVerified}) + } + + if criteria.IsPrimary != nil { + query = query.Where(squirrel.Eq{domainsIsPrimaryCol: *criteria.IsPrimary}) + } + + return query +} + +func (r *DomainRepository) applyPagination(query squirrel.SelectBuilder, pagination v2domain.DomainPagination) squirrel.SelectBuilder { + // Apply sorting + var orderBy string + switch pagination.SortBy { + case v2domain.DomainSortFieldCreatedAt: + orderBy = domainsCreatedAtCol + case v2domain.DomainSortFieldUpdatedAt: + orderBy = domainsUpdatedAtCol + case v2domain.DomainSortFieldDomain: + orderBy = domainsDomainCol + default: + orderBy = domainsCreatedAtCol + } + + if pagination.Order == database.SortOrderDesc { + orderBy += " DESC" + } else { + orderBy += " ASC" + } + + query = query.OrderBy(orderBy) + + // Apply pagination + if pagination.Limit > 0 { + query = query.Limit(uint64(pagination.Limit)) + } + + if pagination.Offset > 0 { + query = query.Offset(uint64(pagination.Offset)) + } + + return query +} + +func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) { + var domain v2domain.Domain + var orgID sql.NullString + var validationType sql.NullInt32 + var deletedAt sql.NullTime + + err := rows.Scan( + &domain.InstanceID, + &orgID, + &domain.Domain, + &domain.IsVerified, + &domain.IsPrimary, + &validationType, + &domain.CreatedAt, + &domain.UpdatedAt, + &deletedAt, + ) + if err != nil { + return nil, zerrors.ThrowInternal(err, "DOMAIN-E28x0", "Errors.Internal") + } + + if orgID.Valid { + domain.OrganizationID = &orgID.String + } + + if validationType.Valid { + validationTypeValue := domain.OrgDomainValidationType(validationType.Int32) + domain.ValidationType = &validationTypeValue + } + + if deletedAt.Valid { + domain.DeletedAt = &deletedAt.Time + } + + return &domain, nil +} + +func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error { + tx, err := r.client.BeginTx(ctx, nil) + if err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-F39y1", "Errors.Internal") + } + + defer func() { + if p := recover(); p != nil { + _ = tx.Rollback() + panic(p) + } + }() + + if err := fn(tx); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return zerrors.ThrowInternal(rbErr, "DOMAIN-G40z2", "Errors.Internal") + } + return err + } + + if err := tx.Commit(); err != nil { + return zerrors.ThrowInternal(err, "DOMAIN-H51a3", "Errors.Internal") + } + + return nil +} \ No newline at end of file diff --git a/internal/v2/readmodel/domain_repository_test.go b/internal/v2/readmodel/domain_repository_test.go new file mode 100644 index 0000000000..e1c2194a43 --- /dev/null +++ b/internal/v2/readmodel/domain_repository_test.go @@ -0,0 +1,201 @@ +package readmodel + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/Masterminds/squirrel" + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/domain" + v2domain "github.com/zitadel/zitadel/internal/v2/domain" + "github.com/zitadel/zitadel/internal/v2/database" +) + +func TestDomainRepository_AddInstanceDomain(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + repo := NewDomainRepository(db) + + instanceID := "test-instance-id" + domainName := "test.example.com" + + mock.ExpectExec(`INSERT INTO zitadel\.domains`). + WithArgs(instanceID, domainName, true, false, sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + result, err := repo.AddInstanceDomain(context.Background(), instanceID, domainName) + + require.NoError(t, err) + assert.Equal(t, instanceID, result.InstanceID) + assert.Nil(t, result.OrganizationID) + assert.Equal(t, domainName, result.Domain) + assert.True(t, result.IsVerified) + assert.False(t, result.IsPrimary) + assert.Nil(t, result.ValidationType) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestDomainRepository_AddOrganizationDomain(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + repo := NewDomainRepository(db) + + instanceID := "test-instance-id" + orgID := "test-org-id" + domainName := "test.example.com" + validationType := domain.OrgDomainValidationTypeHTTP + + mock.ExpectExec(`INSERT INTO zitadel\.domains`). + WithArgs(instanceID, orgID, domainName, false, false, int(validationType), sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + result, err := repo.AddOrganizationDomain(context.Background(), instanceID, orgID, domainName, validationType) + + require.NoError(t, err) + assert.Equal(t, instanceID, result.InstanceID) + assert.Equal(t, orgID, *result.OrganizationID) + assert.Equal(t, domainName, result.Domain) + assert.False(t, result.IsVerified) + assert.False(t, result.IsPrimary) + assert.Equal(t, validationType, *result.ValidationType) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestDomainRepository_SetInstanceDomainPrimary(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + repo := NewDomainRepository(db) + + instanceID := "test-instance-id" + domainName := "test.example.com" + + // Mock transaction begin + mock.ExpectBegin() + + // Mock unset existing primary + mock.ExpectExec(`UPDATE zitadel\.domains SET.*is_primary.*=.*false`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + // Mock set new primary + mock.ExpectExec(`UPDATE zitadel\.domains SET.*is_primary.*=.*true`). + WithArgs(sqlmock.AnyArg(), true, instanceID, nil, domainName). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // Mock transaction commit + mock.ExpectCommit() + + err = repo.SetInstanceDomainPrimary(context.Background(), instanceID, domainName) + + require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestDomainRepository_Get(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + repo := NewDomainRepository(db) + + instanceID := "test-instance-id" + domainName := "test.example.com" + now := time.Now() + + criteria := v2domain.DomainSearchCriteria{ + InstanceID: &instanceID, + Domain: &domainName, + } + + rows := sqlmock.NewRows([]string{ + "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at", + }).AddRow(instanceID, nil, domainName, true, false, nil, now, now, nil) + + mock.ExpectQuery(`SELECT .* FROM zitadel\.domains`). + WithArgs(domainName, instanceID). + WillReturnRows(rows) + + result, err := repo.Get(context.Background(), criteria) + + require.NoError(t, err) + assert.Equal(t, instanceID, result.InstanceID) + assert.Nil(t, result.OrganizationID) + assert.Equal(t, domainName, result.Domain) + assert.True(t, result.IsVerified) + assert.False(t, result.IsPrimary) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestDomainRepository_List(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + repo := NewDomainRepository(db) + + instanceID := "test-instance-id" + now := time.Now() + + criteria := v2domain.DomainSearchCriteria{ + InstanceID: &instanceID, + } + + pagination := v2domain.DomainPagination{ + Limit: 10, + Offset: 0, + SortBy: v2domain.DomainSortFieldDomain, + Order: database.SortOrderAsc, + } + + // Mock count query + countRows := sqlmock.NewRows([]string{"count"}).AddRow(2) + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM zitadel\.domains`). + WithArgs(instanceID). + WillReturnRows(countRows) + + // Mock data query + rows := sqlmock.NewRows([]string{ + "instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at", + }). + AddRow(instanceID, nil, "instance.example.com", true, true, nil, now, now, nil). + AddRow(instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil) + + mock.ExpectQuery(`SELECT .* FROM zitadel\.domains.*ORDER BY domain ASC.*LIMIT 10`). + WithArgs(instanceID). + WillReturnRows(rows) + + result, err := repo.List(context.Background(), criteria, pagination) + + require.NoError(t, err) + assert.Equal(t, uint64(2), result.TotalCount) + assert.Len(t, result.Domains, 2) + + // Check first domain (instance domain) + assert.Equal(t, instanceID, result.Domains[0].InstanceID) + assert.Nil(t, result.Domains[0].OrganizationID) + assert.Equal(t, "instance.example.com", result.Domains[0].Domain) + assert.True(t, result.Domains[0].IsVerified) + assert.True(t, result.Domains[0].IsPrimary) + + // Check second domain (org domain) + assert.Equal(t, instanceID, result.Domains[1].InstanceID) + assert.Equal(t, "org-id", *result.Domains[1].OrganizationID) + assert.Equal(t, "org.example.com", result.Domains[1].Domain) + assert.False(t, result.Domains[1].IsVerified) + assert.False(t, result.Domains[1].IsPrimary) + + assert.NoError(t, mock.ExpectationsWereMet()) +} \ No newline at end of file