mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 13:19:21 +00:00
Implement unified domains table with migration, repository, and projection
Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com>
This commit is contained in:
40
cmd/setup/61.go
Normal file
40
cmd/setup/61.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed 61/*.sql
|
||||
createDomainsTable embed.FS
|
||||
)
|
||||
|
||||
type CreateDomainsTable struct {
|
||||
dbClient *database.DB
|
||||
}
|
||||
|
||||
func (mig *CreateDomainsTable) Execute(ctx context.Context, _ eventstore.Event) error {
|
||||
statements, err := readStatements(createDomainsTable, "61")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
|
||||
if _, err := mig.dbClient.ExecContext(ctx, stmt.query); err != nil {
|
||||
return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mig *CreateDomainsTable) String() string {
|
||||
return "61_create_domains_table"
|
||||
}
|
24
cmd/setup/61/01_create_domains_table.sql
Normal file
24
cmd/setup/61/01_create_domains_table.sql
Normal file
@@ -0,0 +1,24 @@
|
||||
CREATE TABLE IF NOT EXISTS zitadel.domains(
|
||||
instance_id TEXT NOT NULL
|
||||
, org_id TEXT
|
||||
, domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255)
|
||||
, is_verified BOOLEAN NOT NULL DEFAULT FALSE
|
||||
, is_primary BOOLEAN NOT NULL DEFAULT FALSE
|
||||
-- TODO make validation_type enum
|
||||
, validation_type SMALLINT CHECK (validation_type >= 0)
|
||||
|
||||
, created_at TIMESTAMP DEFAULT NOW()
|
||||
, updated_at TIMESTAMP DEFAULT NOW()
|
||||
, deleted_at TIMESTAMP DEFAULT NULL
|
||||
|
||||
, FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id) ON DELETE CASCADE
|
||||
, FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id) ON DELETE CASCADE
|
||||
|
||||
, CONSTRAINT domain_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, domain) WHERE deleted_at IS NULL
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS trigger_set_updated_at
|
||||
BEFORE UPDATE ON zitadel.domains
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
@@ -157,6 +157,7 @@ type Steps struct {
|
||||
s58ReplaceLoginNames3View *ReplaceLoginNames3View
|
||||
s59SetupWebkeys *SetupWebkeys
|
||||
s60GenerateSystemID *GenerateSystemID
|
||||
s61CreateDomainsTable *CreateDomainsTable
|
||||
}
|
||||
|
||||
func MustNewSteps(v *viper.Viper) *Steps {
|
||||
|
@@ -218,6 +218,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
||||
steps.s57CreateResourceCounts = &CreateResourceCounts{dbClient: dbClient}
|
||||
steps.s58ReplaceLoginNames3View = &ReplaceLoginNames3View{dbClient: dbClient}
|
||||
steps.s60GenerateSystemID = &GenerateSystemID{eventstore: eventstoreClient}
|
||||
steps.s61CreateDomainsTable = &CreateDomainsTable{dbClient: dbClient}
|
||||
|
||||
err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil)
|
||||
logging.OnError(err).Fatal("unable to start projections")
|
||||
@@ -266,6 +267,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
||||
steps.s57CreateResourceCounts,
|
||||
steps.s58ReplaceLoginNames3View,
|
||||
steps.s60GenerateSystemID,
|
||||
steps.s61CreateDomainsTable,
|
||||
} {
|
||||
setupErr = executeMigration(ctx, eventstoreClient, step, "migration failed")
|
||||
if setupErr != nil {
|
||||
|
340
internal/query/projection/domains.go
Normal file
340
internal/query/projection/domains.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package projection
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
old_handler "github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
DomainsTable = "zitadel.domains"
|
||||
|
||||
DomainsInstanceIDCol = "instance_id"
|
||||
DomainsOrgIDCol = "org_id"
|
||||
DomainsDomainCol = "domain"
|
||||
DomainsIsVerifiedCol = "is_verified"
|
||||
DomainsIsPrimaryCol = "is_primary"
|
||||
DomainsValidationTypeCol = "validation_type"
|
||||
DomainsCreatedAtCol = "created_at"
|
||||
DomainsUpdatedAtCol = "updated_at"
|
||||
DomainsDeletedAtCol = "deleted_at"
|
||||
)
|
||||
|
||||
type domainsProjection struct{}
|
||||
|
||||
func newDomainsProjection(ctx context.Context, config handler.Config) *handler.Handler {
|
||||
return handler.NewHandler(ctx, &config, new(domainsProjection))
|
||||
}
|
||||
|
||||
func (*domainsProjection) Name() string {
|
||||
return DomainsTable
|
||||
}
|
||||
|
||||
func (*domainsProjection) Init() *old_handler.Check {
|
||||
// The table is created by migration, so we just need to check it exists
|
||||
return handler.NewTableCheck(
|
||||
handler.NewTable([]*handler.InitColumn{
|
||||
handler.NewColumn(DomainsInstanceIDCol, handler.ColumnTypeText),
|
||||
handler.NewColumn(DomainsOrgIDCol, handler.ColumnTypeText),
|
||||
handler.NewColumn(DomainsDomainCol, handler.ColumnTypeText),
|
||||
handler.NewColumn(DomainsIsVerifiedCol, handler.ColumnTypeBool),
|
||||
handler.NewColumn(DomainsIsPrimaryCol, handler.ColumnTypeBool),
|
||||
handler.NewColumn(DomainsValidationTypeCol, handler.ColumnTypeEnum),
|
||||
handler.NewColumn(DomainsCreatedAtCol, handler.ColumnTypeTimestamp),
|
||||
handler.NewColumn(DomainsUpdatedAtCol, handler.ColumnTypeTimestamp),
|
||||
handler.NewColumn(DomainsDeletedAtCol, handler.ColumnTypeTimestamp),
|
||||
},
|
||||
// Note: We don't define primary key here since the table is created by migration
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *domainsProjection) Reducers() []handler.AggregateReducer {
|
||||
return []handler.AggregateReducer{
|
||||
{
|
||||
Aggregate: org.AggregateType,
|
||||
EventReducers: []handler.EventReducer{
|
||||
{
|
||||
Event: org.OrgDomainAddedEventType,
|
||||
Reduce: p.reduceOrgDomainAdded,
|
||||
},
|
||||
{
|
||||
Event: org.OrgDomainVerificationAddedEventType,
|
||||
Reduce: p.reduceOrgDomainVerificationAdded,
|
||||
},
|
||||
{
|
||||
Event: org.OrgDomainVerifiedEventType,
|
||||
Reduce: p.reduceOrgDomainVerified,
|
||||
},
|
||||
{
|
||||
Event: org.OrgDomainPrimarySetEventType,
|
||||
Reduce: p.reduceOrgPrimaryDomainSet,
|
||||
},
|
||||
{
|
||||
Event: org.OrgDomainRemovedEventType,
|
||||
Reduce: p.reduceOrgDomainRemoved,
|
||||
},
|
||||
{
|
||||
Event: org.OrgRemovedEventType,
|
||||
Reduce: p.reduceOrgRemoved,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Aggregate: instance.AggregateType,
|
||||
EventReducers: []handler.EventReducer{
|
||||
{
|
||||
Event: instance.InstanceDomainAddedEventType,
|
||||
Reduce: p.reduceInstanceDomainAdded,
|
||||
},
|
||||
{
|
||||
Event: instance.InstanceDomainPrimarySetEventType,
|
||||
Reduce: p.reduceInstancePrimaryDomainSet,
|
||||
},
|
||||
{
|
||||
Event: instance.InstanceDomainRemovedEventType,
|
||||
Reduce: p.reduceInstanceDomainRemoved,
|
||||
},
|
||||
{
|
||||
Event: instance.InstanceRemovedEventType,
|
||||
Reduce: p.reduceInstanceRemoved,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Organization domain event handlers
|
||||
|
||||
func (p *domainsProjection) reduceOrgDomainAdded(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*org.DomainAddedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-DM2DI", "reduce.wrong.event.type %s", org.OrgDomainAddedEventType)
|
||||
}
|
||||
return handler.NewCreateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsInstanceIDCol, e.Aggregate().InstanceID),
|
||||
handler.NewCol(DomainsOrgIDCol, e.Aggregate().ID),
|
||||
handler.NewCol(DomainsDomainCol, e.Domain),
|
||||
handler.NewCol(DomainsIsVerifiedCol, false),
|
||||
handler.NewCol(DomainsIsPrimaryCol, false),
|
||||
handler.NewCol(DomainsValidationTypeCol, domain.OrgDomainValidationTypeUnspecified),
|
||||
handler.NewCol(DomainsCreatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *domainsProjection) reduceOrgDomainVerificationAdded(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*org.DomainVerificationAddedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-EBzyu", "reduce.wrong.event.type %s", org.OrgDomainVerificationAddedEventType)
|
||||
}
|
||||
return handler.NewUpdateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsValidationTypeCol, e.ValidationType),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
|
||||
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsDomainCol, e.Domain),
|
||||
handler.NewCond(DomainsDeletedAtCol, nil),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *domainsProjection) reduceOrgDomainVerified(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*org.DomainVerifiedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-3Rvkr", "reduce.wrong.event.type %s", org.OrgDomainVerifiedEventType)
|
||||
}
|
||||
return handler.NewUpdateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsIsVerifiedCol, true),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
|
||||
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsDomainCol, e.Domain),
|
||||
handler.NewCond(DomainsDeletedAtCol, nil),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *domainsProjection) reduceOrgPrimaryDomainSet(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*org.DomainPrimarySetEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-aIuei", "reduce.wrong.event.type %s", org.OrgDomainPrimarySetEventType)
|
||||
}
|
||||
return handler.NewMultiStatement(
|
||||
e,
|
||||
handler.AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsIsPrimaryCol, false),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
|
||||
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsIsPrimaryCol, true),
|
||||
handler.NewCond(DomainsDeletedAtCol, nil),
|
||||
},
|
||||
),
|
||||
handler.AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsIsPrimaryCol, true),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
|
||||
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsDomainCol, e.Domain),
|
||||
handler.NewCond(DomainsDeletedAtCol, nil),
|
||||
},
|
||||
),
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *domainsProjection) reduceOrgDomainRemoved(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*org.DomainRemovedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-gh1Mx", "reduce.wrong.event.type %s", org.OrgDomainRemovedEventType)
|
||||
}
|
||||
return handler.NewUpdateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsDeletedAtCol, e.CreationDate()),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
|
||||
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsDomainCol, e.Domain),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *domainsProjection) reduceOrgRemoved(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*org.OrgRemovedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-dMUKJ", "reduce.wrong.event.type %s", org.OrgRemovedEventType)
|
||||
}
|
||||
|
||||
return handler.NewUpdateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsDeletedAtCol, e.CreationDate()),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().InstanceID),
|
||||
handler.NewCond(DomainsOrgIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsDeletedAtCol, nil),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
// Instance domain event handlers
|
||||
|
||||
func (p *domainsProjection) reduceInstanceDomainAdded(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*instance.DomainAddedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-38nNf", "reduce.wrong.event.type %s", instance.InstanceDomainAddedEventType)
|
||||
}
|
||||
return handler.NewCreateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsInstanceIDCol, e.Aggregate().ID),
|
||||
handler.NewCol(DomainsOrgIDCol, nil), // Instance domains have no org_id
|
||||
handler.NewCol(DomainsDomainCol, e.Domain),
|
||||
handler.NewCol(DomainsIsVerifiedCol, true), // Instance domains are always verified
|
||||
handler.NewCol(DomainsIsPrimaryCol, false),
|
||||
handler.NewCol(DomainsValidationTypeCol, nil), // Instance domains have no validation type
|
||||
handler.NewCol(DomainsCreatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *domainsProjection) reduceInstancePrimaryDomainSet(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*instance.DomainPrimarySetEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-f8nlw", "reduce.wrong.event.type %s", instance.InstanceDomainPrimarySetEventType)
|
||||
}
|
||||
return handler.NewMultiStatement(
|
||||
e,
|
||||
handler.AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsIsPrimaryCol, false),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsOrgIDCol, nil), // Instance domains
|
||||
handler.NewCond(DomainsIsPrimaryCol, true),
|
||||
handler.NewCond(DomainsDeletedAtCol, nil),
|
||||
},
|
||||
),
|
||||
handler.AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsIsPrimaryCol, true),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsOrgIDCol, nil), // Instance domains
|
||||
handler.NewCond(DomainsDomainCol, e.Domain),
|
||||
handler.NewCond(DomainsDeletedAtCol, nil),
|
||||
},
|
||||
),
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *domainsProjection) reduceInstanceDomainRemoved(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*instance.DomainRemovedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-388Nk", "reduce.wrong.event.type %s", instance.InstanceDomainRemovedEventType)
|
||||
}
|
||||
return handler.NewUpdateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsDeletedAtCol, e.CreationDate()),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsOrgIDCol, nil), // Instance domains
|
||||
handler.NewCond(DomainsDomainCol, e.Domain),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *domainsProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*instance.InstanceRemovedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-2n9f0", "reduce.wrong.event.type %s", instance.InstanceRemovedEventType)
|
||||
}
|
||||
|
||||
return handler.NewUpdateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(DomainsUpdatedAtCol, e.CreationDate()),
|
||||
handler.NewCol(DomainsDeletedAtCol, e.CreationDate()),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(DomainsInstanceIDCol, e.Aggregate().ID),
|
||||
handler.NewCond(DomainsDeletedAtCol, nil),
|
||||
},
|
||||
), nil
|
||||
}
|
@@ -37,6 +37,7 @@ var (
|
||||
ProjectGrantProjection *handler.Handler
|
||||
ProjectRoleProjection *handler.Handler
|
||||
OrgDomainProjection *handler.Handler
|
||||
DomainsProjection *handler.Handler // Unified domains table
|
||||
LoginPolicyProjection *handler.Handler
|
||||
IDPProjection *handler.Handler
|
||||
AppProjection *handler.Handler
|
||||
@@ -134,6 +135,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore,
|
||||
ProjectGrantProjection = newProjectGrantProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_grants"]))
|
||||
ProjectRoleProjection = newProjectRoleProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_roles"]))
|
||||
OrgDomainProjection = newOrgDomainProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["org_domains"]))
|
||||
DomainsProjection = newDomainsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["domains"]))
|
||||
LoginPolicyProjection = newLoginPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["login_policies"]))
|
||||
IDPProjection = newIDPProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["idps"]))
|
||||
AppProjection = newAppProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["apps"]))
|
||||
|
107
internal/v2/domain/repository.go
Normal file
107
internal/v2/domain/repository.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/v2/database"
|
||||
)
|
||||
|
||||
// Domain represents a unified domain that can belong to either an instance or an organization
|
||||
type Domain struct {
|
||||
ID string
|
||||
InstanceID string
|
||||
OrganizationID *string // nil for instance domains
|
||||
Domain string
|
||||
IsVerified bool
|
||||
IsPrimary bool
|
||||
ValidationType *domain.OrgDomainValidationType // nil for instance domains
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
// IsInstanceDomain returns true if this is an instance domain (OrganizationID is nil)
|
||||
func (d *Domain) IsInstanceDomain() bool {
|
||||
return d.OrganizationID == nil
|
||||
}
|
||||
|
||||
// IsOrganizationDomain returns true if this is an organization domain (OrganizationID is not nil)
|
||||
func (d *Domain) IsOrganizationDomain() bool {
|
||||
return d.OrganizationID != nil
|
||||
}
|
||||
|
||||
// DomainSearchCriteria defines the search criteria for domains
|
||||
type DomainSearchCriteria struct {
|
||||
ID *string
|
||||
Domain *string
|
||||
InstanceID *string
|
||||
OrganizationID *string
|
||||
IsVerified *bool
|
||||
IsPrimary *bool
|
||||
}
|
||||
|
||||
// DomainPagination defines pagination options for domain queries
|
||||
type DomainPagination struct {
|
||||
Limit uint32
|
||||
Offset uint32
|
||||
SortBy DomainSortField
|
||||
Order database.SortOrder
|
||||
}
|
||||
|
||||
// DomainSortField defines the fields available for sorting domain results
|
||||
type DomainSortField int
|
||||
|
||||
const (
|
||||
DomainSortFieldCreatedAt DomainSortField = iota
|
||||
DomainSortFieldUpdatedAt
|
||||
DomainSortFieldDomain
|
||||
)
|
||||
|
||||
// DomainList represents a paginated list of domains
|
||||
type DomainList struct {
|
||||
Domains []*Domain
|
||||
TotalCount uint64
|
||||
}
|
||||
|
||||
// InstanceDomainRepository defines the repository interface for instance domain operations
|
||||
type InstanceDomainRepository interface {
|
||||
// Add creates a new instance domain (always verified)
|
||||
Add(ctx context.Context, instanceID, domain string) (*Domain, error)
|
||||
|
||||
// SetPrimary sets the primary domain for an instance
|
||||
SetPrimary(ctx context.Context, instanceID, domain string) error
|
||||
|
||||
// Remove soft deletes an instance domain
|
||||
Remove(ctx context.Context, instanceID, domain string) error
|
||||
|
||||
// Get returns a single instance domain matching the criteria
|
||||
// Returns error if multiple domains are found
|
||||
Get(ctx context.Context, criteria DomainSearchCriteria) (*Domain, error)
|
||||
|
||||
// List returns a list of instance domains matching the criteria with pagination
|
||||
List(ctx context.Context, criteria DomainSearchCriteria, pagination DomainPagination) (*DomainList, error)
|
||||
}
|
||||
|
||||
// OrganizationDomainRepository defines the repository interface for organization domain operations
|
||||
type OrganizationDomainRepository interface {
|
||||
// Add creates a new organization domain
|
||||
Add(ctx context.Context, instanceID, organizationID, domain string, validationType domain.OrgDomainValidationType) (*Domain, error)
|
||||
|
||||
// SetVerified marks an organization domain as verified
|
||||
SetVerified(ctx context.Context, instanceID, organizationID, domain string) error
|
||||
|
||||
// SetPrimary sets the primary domain for an organization
|
||||
SetPrimary(ctx context.Context, instanceID, organizationID, domain string) error
|
||||
|
||||
// Remove soft deletes an organization domain
|
||||
Remove(ctx context.Context, instanceID, organizationID, domain string) error
|
||||
|
||||
// Get returns a single organization domain matching the criteria
|
||||
// Returns error if multiple domains are found
|
||||
Get(ctx context.Context, criteria DomainSearchCriteria) (*Domain, error)
|
||||
|
||||
// List returns a list of organization domains matching the criteria with pagination
|
||||
List(ctx context.Context, criteria DomainSearchCriteria, pagination DomainPagination) (*DomainList, error)
|
||||
}
|
600
internal/v2/readmodel/domain_repository.go
Normal file
600
internal/v2/readmodel/domain_repository.go
Normal file
@@ -0,0 +1,600 @@
|
||||
package readmodel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
v2domain "github.com/zitadel/zitadel/internal/v2/domain"
|
||||
"github.com/zitadel/zitadel/internal/v2/database"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// Database interfaces for the repository
|
||||
type DB interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type TX interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
Rollback() error
|
||||
Commit() error
|
||||
}
|
||||
|
||||
const (
|
||||
domainsTable = "zitadel.domains"
|
||||
domainsInstanceIDCol = "instance_id"
|
||||
domainsOrgIDCol = "org_id"
|
||||
domainsDomainCol = "domain"
|
||||
domainsIsVerifiedCol = "is_verified"
|
||||
domainsIsPrimaryCol = "is_primary"
|
||||
domainsValidationTypeCol = "validation_type"
|
||||
domainsCreatedAtCol = "created_at"
|
||||
domainsUpdatedAtCol = "updated_at"
|
||||
domainsDeletedAtCol = "deleted_at"
|
||||
)
|
||||
|
||||
// DomainRepository implements both InstanceDomainRepository and OrganizationDomainRepository
|
||||
type DomainRepository struct {
|
||||
client DB
|
||||
}
|
||||
|
||||
// NewDomainRepository creates a new DomainRepository
|
||||
func NewDomainRepository(client DB) *DomainRepository {
|
||||
return &DomainRepository{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Add creates a new instance domain (always verified)
|
||||
func (r *DomainRepository) AddInstanceDomain(ctx context.Context, instanceID, domain string) (*v2domain.Domain, error) {
|
||||
now := time.Now()
|
||||
query := squirrel.Insert(domainsTable).
|
||||
Columns(
|
||||
domainsInstanceIDCol,
|
||||
domainsDomainCol,
|
||||
domainsIsVerifiedCol,
|
||||
domainsIsPrimaryCol,
|
||||
domainsCreatedAtCol,
|
||||
domainsUpdatedAtCol,
|
||||
).
|
||||
Values(instanceID, domain, true, false, now, now).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-1n8fK", "Errors.Internal")
|
||||
}
|
||||
|
||||
_, err = r.client.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-3m9sL", "Errors.Internal")
|
||||
}
|
||||
|
||||
return &v2domain.Domain{
|
||||
InstanceID: instanceID,
|
||||
OrganizationID: nil,
|
||||
Domain: domain,
|
||||
IsVerified: true,
|
||||
IsPrimary: false,
|
||||
ValidationType: nil,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetPrimary sets the primary domain for an instance
|
||||
func (r *DomainRepository) SetInstanceDomainPrimary(ctx context.Context, instanceID, domain string) error {
|
||||
return r.withTransaction(ctx, func(tx database.Tx) error {
|
||||
// First, unset any existing primary domain for this instance
|
||||
unsetQuery := squirrel.Update(domainsTable).
|
||||
Set(domainsIsPrimaryCol, false).
|
||||
Set(domainsUpdatedAtCol, time.Now()).
|
||||
Where(squirrel.Eq{
|
||||
domainsInstanceIDCol: instanceID,
|
||||
domainsOrgIDCol: nil,
|
||||
domainsIsPrimaryCol: true,
|
||||
}).
|
||||
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := unsetQuery.ToSql()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-4k2nL", "Errors.Internal")
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-5n3mK", "Errors.Internal")
|
||||
}
|
||||
|
||||
// Then set the new primary domain
|
||||
setPrimaryQuery := squirrel.Update(domainsTable).
|
||||
Set(domainsIsPrimaryCol, true).
|
||||
Set(domainsUpdatedAtCol, time.Now()).
|
||||
Where(squirrel.Eq{
|
||||
domainsInstanceIDCol: instanceID,
|
||||
domainsOrgIDCol: nil,
|
||||
domainsDomainCol: domain,
|
||||
}).
|
||||
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err = setPrimaryQuery.ToSql()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-6o4pM", "Errors.Internal")
|
||||
}
|
||||
|
||||
result, err := tx.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-7p5qN", "Errors.Internal")
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-8q6rO", "Errors.Internal")
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return zerrors.ThrowNotFound(nil, "DOMAIN-9r7sP", "Errors.Domain.NotFound")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Remove soft deletes an instance domain
|
||||
func (r *DomainRepository) RemoveInstanceDomain(ctx context.Context, instanceID, domain string) error {
|
||||
query := squirrel.Update(domainsTable).
|
||||
Set(domainsDeletedAtCol, time.Now()).
|
||||
Set(domainsUpdatedAtCol, time.Now()).
|
||||
Where(squirrel.Eq{
|
||||
domainsInstanceIDCol: instanceID,
|
||||
domainsOrgIDCol: nil,
|
||||
domainsDomainCol: domain,
|
||||
}).
|
||||
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-As8tQ", "Errors.Internal")
|
||||
}
|
||||
|
||||
result, err := r.client.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-Bt9uR", "Errors.Internal")
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-Cu0vS", "Errors.Internal")
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return zerrors.ThrowNotFound(nil, "DOMAIN-Dv1wT", "Errors.Domain.NotFound")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add creates a new organization domain
|
||||
func (r *DomainRepository) AddOrganizationDomain(ctx context.Context, instanceID, organizationID, domain string, validationType domain.OrgDomainValidationType) (*v2domain.Domain, error) {
|
||||
now := time.Now()
|
||||
query := squirrel.Insert(domainsTable).
|
||||
Columns(
|
||||
domainsInstanceIDCol,
|
||||
domainsOrgIDCol,
|
||||
domainsDomainCol,
|
||||
domainsIsVerifiedCol,
|
||||
domainsIsPrimaryCol,
|
||||
domainsValidationTypeCol,
|
||||
domainsCreatedAtCol,
|
||||
domainsUpdatedAtCol,
|
||||
).
|
||||
Values(instanceID, organizationID, domain, false, false, int(validationType), now, now).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-Ew2xU", "Errors.Internal")
|
||||
}
|
||||
|
||||
_, err = r.client.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-Fx3yV", "Errors.Internal")
|
||||
}
|
||||
|
||||
return &v2domain.Domain{
|
||||
InstanceID: instanceID,
|
||||
OrganizationID: &organizationID,
|
||||
Domain: domain,
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: &validationType,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetVerified marks an organization domain as verified
|
||||
func (r *DomainRepository) SetOrganizationDomainVerified(ctx context.Context, instanceID, organizationID, domain string) error {
|
||||
query := squirrel.Update(domainsTable).
|
||||
Set(domainsIsVerifiedCol, true).
|
||||
Set(domainsUpdatedAtCol, time.Now()).
|
||||
Where(squirrel.Eq{
|
||||
domainsInstanceIDCol: instanceID,
|
||||
domainsOrgIDCol: organizationID,
|
||||
domainsDomainCol: domain,
|
||||
}).
|
||||
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-Gy4zW", "Errors.Internal")
|
||||
}
|
||||
|
||||
result, err := r.client.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-Hz5aX", "Errors.Internal")
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-I06bY", "Errors.Internal")
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return zerrors.ThrowNotFound(nil, "DOMAIN-J17cZ", "Errors.Domain.NotFound")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetPrimary sets the primary domain for an organization
|
||||
func (r *DomainRepository) SetOrganizationDomainPrimary(ctx context.Context, instanceID, organizationID, domain string) error {
|
||||
return r.withTransaction(ctx, func(tx database.Tx) error {
|
||||
// First, unset any existing primary domain for this organization
|
||||
unsetQuery := squirrel.Update(domainsTable).
|
||||
Set(domainsIsPrimaryCol, false).
|
||||
Set(domainsUpdatedAtCol, time.Now()).
|
||||
Where(squirrel.Eq{
|
||||
domainsInstanceIDCol: instanceID,
|
||||
domainsOrgIDCol: organizationID,
|
||||
domainsIsPrimaryCol: true,
|
||||
}).
|
||||
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := unsetQuery.ToSql()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-K28d0", "Errors.Internal")
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-L39e1", "Errors.Internal")
|
||||
}
|
||||
|
||||
// Then set the new primary domain
|
||||
setPrimaryQuery := squirrel.Update(domainsTable).
|
||||
Set(domainsIsPrimaryCol, true).
|
||||
Set(domainsUpdatedAtCol, time.Now()).
|
||||
Where(squirrel.Eq{
|
||||
domainsInstanceIDCol: instanceID,
|
||||
domainsOrgIDCol: organizationID,
|
||||
domainsDomainCol: domain,
|
||||
}).
|
||||
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err = setPrimaryQuery.ToSql()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-M40f2", "Errors.Internal")
|
||||
}
|
||||
|
||||
result, err := tx.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-N51g3", "Errors.Internal")
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-O62h4", "Errors.Internal")
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return zerrors.ThrowNotFound(nil, "DOMAIN-P73i5", "Errors.Domain.NotFound")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Remove soft deletes an organization domain
|
||||
func (r *DomainRepository) RemoveOrganizationDomain(ctx context.Context, instanceID, organizationID, domain string) error {
|
||||
query := squirrel.Update(domainsTable).
|
||||
Set(domainsDeletedAtCol, time.Now()).
|
||||
Set(domainsUpdatedAtCol, time.Now()).
|
||||
Where(squirrel.Eq{
|
||||
domainsInstanceIDCol: instanceID,
|
||||
domainsOrgIDCol: organizationID,
|
||||
domainsDomainCol: domain,
|
||||
}).
|
||||
Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL")).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-Q84j6", "Errors.Internal")
|
||||
}
|
||||
|
||||
result, err := r.client.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-R95k7", "Errors.Internal")
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-S06l8", "Errors.Internal")
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return zerrors.ThrowNotFound(nil, "DOMAIN-T17m9", "Errors.Domain.NotFound")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns a single domain matching the criteria
|
||||
func (r *DomainRepository) Get(ctx context.Context, criteria v2domain.DomainSearchCriteria) (*v2domain.Domain, error) {
|
||||
query := r.buildSelectQuery(criteria, v2domain.DomainPagination{Limit: 2}) // Limit to 2 to detect multiple results
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-U28n0", "Errors.Internal")
|
||||
}
|
||||
|
||||
rows, err := r.client.QueryContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-V39o1", "Errors.Internal")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var domains []*v2domain.Domain
|
||||
for rows.Next() {
|
||||
domain, err := r.scanDomain(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-W40p2", "Errors.Internal")
|
||||
}
|
||||
|
||||
if len(domains) == 0 {
|
||||
return nil, zerrors.ThrowNotFound(nil, "DOMAIN-X51q3", "Errors.Domain.NotFound")
|
||||
}
|
||||
|
||||
if len(domains) > 1 {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "DOMAIN-Y62r4", "Errors.Domain.MultipleFound")
|
||||
}
|
||||
|
||||
return domains[0], nil
|
||||
}
|
||||
|
||||
// List returns a list of domains matching the criteria with pagination
|
||||
func (r *DomainRepository) List(ctx context.Context, criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) (*v2domain.DomainList, error) {
|
||||
// First get the total count
|
||||
countQuery := r.buildCountQuery(criteria)
|
||||
stmt, args, err := countQuery.ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-Z73s5", "Errors.Internal")
|
||||
}
|
||||
|
||||
var totalCount uint64
|
||||
err = r.client.QueryRowContext(ctx, stmt, args...).Scan(&totalCount)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-A84t6", "Errors.Internal")
|
||||
}
|
||||
|
||||
// Then get the actual data
|
||||
query := r.buildSelectQuery(criteria, pagination)
|
||||
stmt, args, err = query.ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-B95u7", "Errors.Internal")
|
||||
}
|
||||
|
||||
rows, err := r.client.QueryContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-C06v8", "Errors.Internal")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var domains []*v2domain.Domain
|
||||
for rows.Next() {
|
||||
domain, err := r.scanDomain(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-D17w9", "Errors.Internal")
|
||||
}
|
||||
|
||||
return &v2domain.DomainList{
|
||||
Domains: domains,
|
||||
TotalCount: totalCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *DomainRepository) buildSelectQuery(criteria v2domain.DomainSearchCriteria, pagination v2domain.DomainPagination) squirrel.SelectBuilder {
|
||||
query := squirrel.Select(
|
||||
domainsInstanceIDCol,
|
||||
domainsOrgIDCol,
|
||||
domainsDomainCol,
|
||||
domainsIsVerifiedCol,
|
||||
domainsIsPrimaryCol,
|
||||
domainsValidationTypeCol,
|
||||
domainsCreatedAtCol,
|
||||
domainsUpdatedAtCol,
|
||||
domainsDeletedAtCol,
|
||||
).From(domainsTable).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
query = r.applySearchCriteria(query, criteria)
|
||||
query = r.applyPagination(query, pagination)
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (r *DomainRepository) buildCountQuery(criteria v2domain.DomainSearchCriteria) squirrel.SelectBuilder {
|
||||
query := squirrel.Select("COUNT(*)").
|
||||
From(domainsTable).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
return r.applySearchCriteria(query, criteria)
|
||||
}
|
||||
|
||||
func (r *DomainRepository) applySearchCriteria(query squirrel.SelectBuilder, criteria v2domain.DomainSearchCriteria) squirrel.SelectBuilder {
|
||||
// Always exclude soft-deleted domains
|
||||
query = query.Where(squirrel.Expr(domainsDeletedAtCol + " IS NULL"))
|
||||
|
||||
if criteria.ID != nil {
|
||||
// Note: Our table doesn't have an ID column. This might need to be reconsidered
|
||||
// For now, we'll ignore this criterion since the spec doesn't define where ID comes from
|
||||
}
|
||||
|
||||
if criteria.Domain != nil {
|
||||
query = query.Where(squirrel.Eq{domainsDomainCol: *criteria.Domain})
|
||||
}
|
||||
|
||||
if criteria.InstanceID != nil {
|
||||
query = query.Where(squirrel.Eq{domainsInstanceIDCol: *criteria.InstanceID})
|
||||
}
|
||||
|
||||
if criteria.OrganizationID != nil {
|
||||
query = query.Where(squirrel.Eq{domainsOrgIDCol: *criteria.OrganizationID})
|
||||
}
|
||||
|
||||
if criteria.IsVerified != nil {
|
||||
query = query.Where(squirrel.Eq{domainsIsVerifiedCol: *criteria.IsVerified})
|
||||
}
|
||||
|
||||
if criteria.IsPrimary != nil {
|
||||
query = query.Where(squirrel.Eq{domainsIsPrimaryCol: *criteria.IsPrimary})
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (r *DomainRepository) applyPagination(query squirrel.SelectBuilder, pagination v2domain.DomainPagination) squirrel.SelectBuilder {
|
||||
// Apply sorting
|
||||
var orderBy string
|
||||
switch pagination.SortBy {
|
||||
case v2domain.DomainSortFieldCreatedAt:
|
||||
orderBy = domainsCreatedAtCol
|
||||
case v2domain.DomainSortFieldUpdatedAt:
|
||||
orderBy = domainsUpdatedAtCol
|
||||
case v2domain.DomainSortFieldDomain:
|
||||
orderBy = domainsDomainCol
|
||||
default:
|
||||
orderBy = domainsCreatedAtCol
|
||||
}
|
||||
|
||||
if pagination.Order == database.SortOrderDesc {
|
||||
orderBy += " DESC"
|
||||
} else {
|
||||
orderBy += " ASC"
|
||||
}
|
||||
|
||||
query = query.OrderBy(orderBy)
|
||||
|
||||
// Apply pagination
|
||||
if pagination.Limit > 0 {
|
||||
query = query.Limit(uint64(pagination.Limit))
|
||||
}
|
||||
|
||||
if pagination.Offset > 0 {
|
||||
query = query.Offset(uint64(pagination.Offset))
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (r *DomainRepository) scanDomain(rows *sql.Rows) (*v2domain.Domain, error) {
|
||||
var domain v2domain.Domain
|
||||
var orgID sql.NullString
|
||||
var validationType sql.NullInt32
|
||||
var deletedAt sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&domain.InstanceID,
|
||||
&orgID,
|
||||
&domain.Domain,
|
||||
&domain.IsVerified,
|
||||
&domain.IsPrimary,
|
||||
&validationType,
|
||||
&domain.CreatedAt,
|
||||
&domain.UpdatedAt,
|
||||
&deletedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DOMAIN-E28x0", "Errors.Internal")
|
||||
}
|
||||
|
||||
if orgID.Valid {
|
||||
domain.OrganizationID = &orgID.String
|
||||
}
|
||||
|
||||
if validationType.Valid {
|
||||
validationTypeValue := domain.OrgDomainValidationType(validationType.Int32)
|
||||
domain.ValidationType = &validationTypeValue
|
||||
}
|
||||
|
||||
if deletedAt.Valid {
|
||||
domain.DeletedAt = &deletedAt.Time
|
||||
}
|
||||
|
||||
return &domain, nil
|
||||
}
|
||||
|
||||
func (r *DomainRepository) withTransaction(ctx context.Context, fn func(TX) error) error {
|
||||
tx, err := r.client.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-F39y1", "Errors.Internal")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
_ = tx.Rollback()
|
||||
panic(p)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return zerrors.ThrowInternal(rbErr, "DOMAIN-G40z2", "Errors.Internal")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return zerrors.ThrowInternal(err, "DOMAIN-H51a3", "Errors.Internal")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
201
internal/v2/readmodel/domain_repository_test.go
Normal file
201
internal/v2/readmodel/domain_repository_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package readmodel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
v2domain "github.com/zitadel/zitadel/internal/v2/domain"
|
||||
"github.com/zitadel/zitadel/internal/v2/database"
|
||||
)
|
||||
|
||||
func TestDomainRepository_AddInstanceDomain(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
repo := NewDomainRepository(db)
|
||||
|
||||
instanceID := "test-instance-id"
|
||||
domainName := "test.example.com"
|
||||
|
||||
mock.ExpectExec(`INSERT INTO zitadel\.domains`).
|
||||
WithArgs(instanceID, domainName, true, false, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
result, err := repo.AddInstanceDomain(context.Background(), instanceID, domainName)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Nil(t, result.OrganizationID)
|
||||
assert.Equal(t, domainName, result.Domain)
|
||||
assert.True(t, result.IsVerified)
|
||||
assert.False(t, result.IsPrimary)
|
||||
assert.Nil(t, result.ValidationType)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestDomainRepository_AddOrganizationDomain(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
repo := NewDomainRepository(db)
|
||||
|
||||
instanceID := "test-instance-id"
|
||||
orgID := "test-org-id"
|
||||
domainName := "test.example.com"
|
||||
validationType := domain.OrgDomainValidationTypeHTTP
|
||||
|
||||
mock.ExpectExec(`INSERT INTO zitadel\.domains`).
|
||||
WithArgs(instanceID, orgID, domainName, false, false, int(validationType), sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
result, err := repo.AddOrganizationDomain(context.Background(), instanceID, orgID, domainName, validationType)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Equal(t, orgID, *result.OrganizationID)
|
||||
assert.Equal(t, domainName, result.Domain)
|
||||
assert.False(t, result.IsVerified)
|
||||
assert.False(t, result.IsPrimary)
|
||||
assert.Equal(t, validationType, *result.ValidationType)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestDomainRepository_SetInstanceDomainPrimary(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
repo := NewDomainRepository(db)
|
||||
|
||||
instanceID := "test-instance-id"
|
||||
domainName := "test.example.com"
|
||||
|
||||
// Mock transaction begin
|
||||
mock.ExpectBegin()
|
||||
|
||||
// Mock unset existing primary
|
||||
mock.ExpectExec(`UPDATE zitadel\.domains SET.*is_primary.*=.*false`).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
// Mock set new primary
|
||||
mock.ExpectExec(`UPDATE zitadel\.domains SET.*is_primary.*=.*true`).
|
||||
WithArgs(sqlmock.AnyArg(), true, instanceID, nil, domainName).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Mock transaction commit
|
||||
mock.ExpectCommit()
|
||||
|
||||
err = repo.SetInstanceDomainPrimary(context.Background(), instanceID, domainName)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestDomainRepository_Get(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
repo := NewDomainRepository(db)
|
||||
|
||||
instanceID := "test-instance-id"
|
||||
domainName := "test.example.com"
|
||||
now := time.Now()
|
||||
|
||||
criteria := v2domain.DomainSearchCriteria{
|
||||
InstanceID: &instanceID,
|
||||
Domain: &domainName,
|
||||
}
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
|
||||
}).AddRow(instanceID, nil, domainName, true, false, nil, now, now, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains`).
|
||||
WithArgs(domainName, instanceID).
|
||||
WillReturnRows(rows)
|
||||
|
||||
result, err := repo.Get(context.Background(), criteria)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Nil(t, result.OrganizationID)
|
||||
assert.Equal(t, domainName, result.Domain)
|
||||
assert.True(t, result.IsVerified)
|
||||
assert.False(t, result.IsPrimary)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestDomainRepository_List(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
repo := NewDomainRepository(db)
|
||||
|
||||
instanceID := "test-instance-id"
|
||||
now := time.Now()
|
||||
|
||||
criteria := v2domain.DomainSearchCriteria{
|
||||
InstanceID: &instanceID,
|
||||
}
|
||||
|
||||
pagination := v2domain.DomainPagination{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
SortBy: v2domain.DomainSortFieldDomain,
|
||||
Order: database.SortOrderAsc,
|
||||
}
|
||||
|
||||
// Mock count query
|
||||
countRows := sqlmock.NewRows([]string{"count"}).AddRow(2)
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM zitadel\.domains`).
|
||||
WithArgs(instanceID).
|
||||
WillReturnRows(countRows)
|
||||
|
||||
// Mock data query
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"instance_id", "org_id", "domain", "is_verified", "is_primary", "validation_type", "created_at", "updated_at", "deleted_at",
|
||||
}).
|
||||
AddRow(instanceID, nil, "instance.example.com", true, true, nil, now, now, nil).
|
||||
AddRow(instanceID, "org-id", "org.example.com", false, false, int(domain.OrgDomainValidationTypeHTTP), now, now, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT .* FROM zitadel\.domains.*ORDER BY domain ASC.*LIMIT 10`).
|
||||
WithArgs(instanceID).
|
||||
WillReturnRows(rows)
|
||||
|
||||
result, err := repo.List(context.Background(), criteria, pagination)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint64(2), result.TotalCount)
|
||||
assert.Len(t, result.Domains, 2)
|
||||
|
||||
// Check first domain (instance domain)
|
||||
assert.Equal(t, instanceID, result.Domains[0].InstanceID)
|
||||
assert.Nil(t, result.Domains[0].OrganizationID)
|
||||
assert.Equal(t, "instance.example.com", result.Domains[0].Domain)
|
||||
assert.True(t, result.Domains[0].IsVerified)
|
||||
assert.True(t, result.Domains[0].IsPrimary)
|
||||
|
||||
// Check second domain (org domain)
|
||||
assert.Equal(t, instanceID, result.Domains[1].InstanceID)
|
||||
assert.Equal(t, "org-id", *result.Domains[1].OrganizationID)
|
||||
assert.Equal(t, "org.example.com", result.Domains[1].Domain)
|
||||
assert.False(t, result.Domains[1].IsVerified)
|
||||
assert.False(t, result.Domains[1].IsPrimary)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
Reference in New Issue
Block a user