implementation done

This commit is contained in:
adlerhurst
2025-07-22 19:09:56 +02:00
parent bb2d0aff3f
commit 9c348c0429
14 changed files with 536 additions and 179 deletions

View File

@@ -2,28 +2,35 @@ package domain
import "github.com/zitadel/zitadel/backend/v3/storage/database" import "github.com/zitadel/zitadel/backend/v3/storage/database"
type DomainVerificationType string type DomainValidationType string
const ( const (
DomainVerificationTypeDNS DomainVerificationType = "dns" DomainValidationTypeDNS DomainValidationType = "dns"
DomainVerificationTypeHTTP DomainVerificationType = "http" DomainValidationTypeHTTP DomainValidationType = "http"
) )
type domainColumns interface { type domainColumns interface {
// InstanceIDColumn returns the column for the instance id field. // InstanceIDColumn returns the column for the instance id field.
InstanceIDColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
InstanceIDColumn(qualified bool) database.Column
// DomainColumn returns the column for the domain field. // DomainColumn returns the column for the domain field.
DomainColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
DomainColumn(qualified bool) database.Column
// IsVerifiedColumn returns the column for the is verified field. // IsVerifiedColumn returns the column for the is verified field.
IsVerifiedColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
IsVerifiedColumn(qualified bool) database.Column
// IsPrimaryColumn returns the column for the is primary field. // IsPrimaryColumn returns the column for the is primary field.
IsPrimaryColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
// VerificationTypeColumn returns the column for the verification type field. IsPrimaryColumn(qualified bool) database.Column
VerificationTypeColumn() database.Column // ValidationTypeColumn returns the column for the verification type field.
// `qualified` indicates if the column should be qualified with the table name.
ValidationTypeColumn(qualified bool) database.Column
// CreatedAtColumn returns the column for the created at field. // CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
CreatedAtColumn(qualified bool) database.Column
// UpdatedAtColumn returns the column for the updated at field. // UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
UpdatedAtColumn(qualified bool) database.Column
} }
type domainConditions interface { type domainConditions interface {
@@ -51,9 +58,9 @@ type domainChanges interface {
// - The domain is already primary. // - The domain is already primary.
// - No domain matches the condition. // - No domain matches the condition.
SetPrimary() database.Change SetPrimary() database.Change
// SetVerificationType sets the verification type column. // SetValidationType sets the verification type column.
// If the domain is already verified, this is a no-op. // If the domain is already verified, this is a no-op.
SetVerificationType(verificationType DomainVerificationType) database.Change SetValidationType(verificationType DomainValidationType) database.Change
} }
// import ( // import (

View File

@@ -18,6 +18,8 @@ type Instance struct {
DefaultLanguage string `json:"defaultLanguage,omitempty" db:"default_language"` DefaultLanguage string `json:"defaultLanguage,omitempty" db:"default_language"`
CreatedAt time.Time `json:"createdAt" db:"created_at"` CreatedAt time.Time `json:"createdAt" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt" db:"updated_at"` UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
Domains []*InstanceDomain `json:"domains,omitempty" db:"-"`
} }
type instanceCacheIndex uint8 type instanceCacheIndex uint8
@@ -40,23 +42,32 @@ var _ cache.Entry[instanceCacheIndex, string] = (*Instance)(nil)
// instanceColumns define all the columns of the instance table. // instanceColumns define all the columns of the instance table.
type instanceColumns interface { type instanceColumns interface {
// IDColumn returns the column for the id field. // IDColumn returns the column for the id field.
IDColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
IDColumn(qualified bool) database.Column
// NameColumn returns the column for the name field. // NameColumn returns the column for the name field.
NameColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
NameColumn(qualified bool) database.Column
// DefaultOrgIDColumn returns the column for the default org id field // DefaultOrgIDColumn returns the column for the default org id field
DefaultOrgIDColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
DefaultOrgIDColumn(qualified bool) database.Column
// IAMProjectIDColumn returns the column for the default IAM org id field // IAMProjectIDColumn returns the column for the default IAM org id field
IAMProjectIDColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
IAMProjectIDColumn(qualified bool) database.Column
// ConsoleClientIDColumn returns the column for the default IAM org id field // ConsoleClientIDColumn returns the column for the default IAM org id field
ConsoleClientIDColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
ConsoleClientIDColumn(qualified bool) database.Column
// ConsoleAppIDColumn returns the column for the console client id field // ConsoleAppIDColumn returns the column for the console client id field
ConsoleAppIDColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
ConsoleAppIDColumn(qualified bool) database.Column
// DefaultLanguageColumn returns the column for the default language field // DefaultLanguageColumn returns the column for the default language field
DefaultLanguageColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
DefaultLanguageColumn(qualified bool) database.Column
// CreatedAtColumn returns the column for the created at field. // CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
CreatedAtColumn(qualified bool) database.Column
// UpdatedAtColumn returns the column for the updated at field. // UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
UpdatedAtColumn(qualified bool) database.Column
} }
// instanceConditions define all the conditions for the instance table. // instanceConditions define all the conditions for the instance table.
@@ -83,16 +94,27 @@ type InstanceRepository interface {
// Member returns the member repository which is a sub repository of the instance repository. // Member returns the member repository which is a sub repository of the instance repository.
// Member() MemberRepository // Member() MemberRepository
Get(ctx context.Context, id string) (*Instance, error) Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error)
List(ctx context.Context, opts ...database.Condition) ([]*Instance, error) List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error)
Create(ctx context.Context, instance *Instance) error Create(ctx context.Context, instance *Instance) error
Update(ctx context.Context, id string, changes ...database.Change) (int64, error) Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id string) (int64, error) Delete(ctx context.Context, id string) (int64, error)
Domains() InstanceDomainRepository // Domains returns the domain sub repository for the instance.
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
Domains(shouldLoad bool) InstanceDomainRepository
} }
type CreateInstance struct { type CreateInstance struct {
Name string `json:"name"` Name string `json:"name"`
} }
type InstanceQueryOption func(*InstanceQueryOpts)
type InstanceQueryOpts struct {
database.QueryOpts
JoinDomains bool
}

View File

@@ -2,17 +2,23 @@ package domain
import ( import (
"context" "context"
"encoding/json"
"time" "time"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
) )
type InstanceDomains struct {
domains []*InstanceDomain
Raw json.RawMessage
}
type InstanceDomain struct { type InstanceDomain struct {
InstanceID string `json:"instanceId,omitempty" db:"instance_id"` InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
Domain string `json:"domain,omitempty" db:"domain"` Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"` ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
CreatedAt string `json:"createdAt,omitempty" db:"created_at"` CreatedAt string `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"` UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"`
@@ -23,7 +29,7 @@ type AddInstanceDomain struct {
Domain string `json:"domain,omitempty" db:"domain"` Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"` VerificationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
// CreatedAt is the time when the domain was added. // CreatedAt is the time when the domain was added.
// It is set by the repository and should not be set by the caller. // It is set by the repository and should not be set by the caller.
@@ -36,7 +42,8 @@ type AddInstanceDomain struct {
type instanceDomainColumns interface { type instanceDomainColumns interface {
domainColumns domainColumns
// IsGeneratedColumn returns the column for the is generated field. // IsGeneratedColumn returns the column for the is generated field.
IsGeneratedColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
IsGeneratedColumn(qualified bool) database.Column
} }
type instanceDomainConditions interface { type instanceDomainConditions interface {

View File

@@ -19,8 +19,10 @@ type Organization struct {
Name string `json:"name,omitempty" db:"name"` Name string `json:"name,omitempty" db:"name"`
InstanceID string `json:"instanceId,omitempty" db:"instance_id"` InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
State OrgState `json:"state,omitempty" db:"state"` State OrgState `json:"state,omitempty" db:"state"`
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"` CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"` UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"`
} }
// OrgIdentifierCondition is used to help specify a single Organization, // OrgIdentifierCondition is used to help specify a single Organization,
@@ -33,17 +35,23 @@ type OrgIdentifierCondition interface {
// organizationColumns define all the columns of the instance table. // organizationColumns define all the columns of the instance table.
type organizationColumns interface { type organizationColumns interface {
// IDColumn returns the column for the id field. // IDColumn returns the column for the id field.
IDColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
IDColumn(qualified bool) database.Column
// NameColumn returns the column for the name field. // NameColumn returns the column for the name field.
NameColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
NameColumn(qualified bool) database.Column
// InstanceIDColumn returns the column for the default org id field // InstanceIDColumn returns the column for the default org id field
InstanceIDColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
InstanceIDColumn(qualified bool) database.Column
// StateColumn returns the column for the name field. // StateColumn returns the column for the name field.
StateColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
StateColumn(qualified bool) database.Column
// CreatedAtColumn returns the column for the created at field. // CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
CreatedAtColumn(qualified bool) database.Column
// UpdatedAtColumn returns the column for the updated at field. // UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column // `qualified` indicates if the column should be qualified with the table name.
UpdatedAtColumn(qualified bool) database.Column
} }
// organizationConditions define all the conditions for the instance table. // organizationConditions define all the conditions for the instance table.
@@ -72,15 +80,17 @@ type OrganizationRepository interface {
organizationConditions organizationConditions
organizationChanges organizationChanges
Get(ctx context.Context, id OrgIdentifierCondition, instance_id string, opts ...database.Condition) (*Organization, error) Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error)
List(ctx context.Context, conditions ...database.Condition) ([]*Organization, error) List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error)
Create(ctx context.Context, instance *Organization) error Create(ctx context.Context, instance *Organization) error
Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error) Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error) Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
// Domains returns the domain sub repository for the organization. // Domains returns the domain sub repository for the organization.
Domains() OrganizationDomainRepository // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
Domains(shouldLoad bool) OrganizationDomainRepository
} }
type CreateOrganization struct { type CreateOrganization struct {

View File

@@ -13,7 +13,7 @@ type OrganizationDomain struct {
Domain string `json:"domain,omitempty" db:"domain"` Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"` ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
CreatedAt string `json:"createdAt,omitempty" db:"created_at"` CreatedAt string `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"` UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"`
@@ -25,7 +25,7 @@ type AddOrganizationDomain struct {
Domain string `json:"domain,omitempty" db:"domain"` Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"` ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
// CreatedAt is the time when the domain was added. // CreatedAt is the time when the domain was added.
// It is set by the repository and should not be set by the caller. // It is set by the repository and should not be set by the caller.
@@ -38,7 +38,7 @@ type AddOrganizationDomain struct {
type organizationDomainColumns interface { type organizationDomainColumns interface {
domainColumns domainColumns
// OrgIDColumn returns the column for the org id field. // OrgIDColumn returns the column for the org id field.
OrgIDColumn() database.Column OrgIDColumn(qualified bool) database.Column
} }
type organizationDomainConditions interface { type organizationDomainConditions interface {

View File

@@ -113,6 +113,15 @@ func NewBooleanCondition[V Boolean](col Column, value V) Condition {
}) })
} }
// NewColumnCondition creates a condition that compares two columns on equality.
func NewColumnCondition(col1, col2 Column) Condition {
return valueCondition(func(builder *StatementBuilder) {
col1.Write(builder)
builder.WriteString(" = ")
col2.Write(builder)
})
}
// Write implements [Condition]. // Write implements [Condition].
func (c valueCondition) Write(builder *StatementBuilder) { func (c valueCondition) Write(builder *StatementBuilder) {
c(builder) c(builder)

View File

@@ -1,9 +1,12 @@
package database package database
import ( import (
"errors"
"fmt" "fmt"
) )
var NoChangesError = errors.New("Update must contain a change")
// NoRowFoundError is returned when QueryRow does not find any row. // NoRowFoundError is returned when QueryRow does not find any row.
// It wraps the dialect specific original error to provide more context. // It wraps the dialect specific original error to provide more context.
type NoRowFoundError struct { type NoRowFoundError struct {

View File

@@ -30,6 +30,36 @@ func WithOffset(offset uint32) QueryOption {
} }
} }
// WithGroupBy sets the columns to group the results by.
func WithGroupBy(groupBy ...Column) QueryOption {
return func(opts *QueryOpts) {
opts.GroupBy = groupBy
}
}
// WithLeftJoin adds a LEFT JOIN to the query.
func WithLeftJoin(table string, columns Condition) QueryOption {
return func(opts *QueryOpts) {
opts.Joins = append(opts.Joins, join{
table: table,
typ: JoinTypeLeft,
columns: columns,
})
}
}
type joinType string
const (
JoinTypeLeft joinType = "LEFT"
)
type join struct {
table string
typ joinType
columns Condition
}
// QueryOpts holds the options for a query. // QueryOpts holds the options for a query.
// It is used to build the SQL SELECT statement. // It is used to build the SQL SELECT statement.
type QueryOpts struct { type QueryOpts struct {
@@ -45,10 +75,19 @@ type QueryOpts struct {
// Offset is the number of results to skip before returning the results. // Offset is the number of results to skip before returning the results.
// It is used to build the OFFSET clause of the SQL statement. // It is used to build the OFFSET clause of the SQL statement.
Offset uint32 Offset uint32
// GroupBy is the columns to group the results by.
// It is used to build the GROUP BY clause of the SQL statement.
GroupBy Columns
// Joins is a list of joins to be applied to the query.
// It is used to build the JOIN clauses of the SQL statement.
Joins []join
} }
func (opts *QueryOpts) Write(builder *StatementBuilder) { func (opts *QueryOpts) Write(builder *StatementBuilder) {
opts.WriteLeftJoins(builder)
opts.WriteCondition(builder) opts.WriteCondition(builder)
opts.WriteGroupBy(builder)
opts.WriteOrderBy(builder) opts.WriteOrderBy(builder)
opts.WriteLimit(builder) opts.WriteLimit(builder)
opts.WriteOffset(builder) opts.WriteOffset(builder)
@@ -85,3 +124,25 @@ func (opts *QueryOpts) WriteOffset(builder *StatementBuilder) {
builder.WriteString(" OFFSET ") builder.WriteString(" OFFSET ")
builder.WriteArg(opts.Offset) builder.WriteArg(opts.Offset)
} }
func (opts *QueryOpts) WriteGroupBy(builder *StatementBuilder) {
if len(opts.GroupBy) == 0 {
return
}
builder.WriteString(" GROUP BY ")
opts.GroupBy.Write(builder)
}
func (opts *QueryOpts) WriteLeftJoins(builder *StatementBuilder) {
if len(opts.Joins) == 0 {
return
}
for _, join := range opts.Joins {
builder.WriteString(" ")
builder.WriteString(string(join.typ))
builder.WriteString(" JOIN ")
builder.WriteString(join.table)
builder.WriteString(" ON ")
join.columns.Write(builder)
}
}

View File

@@ -2,7 +2,7 @@ package repository
import ( import (
"context" "context"
"errors" "encoding/json"
"github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
@@ -12,8 +12,8 @@ var _ domain.InstanceRepository = (*instance)(nil)
type instance struct { type instance struct {
repository repository
shouldJoinDomains bool shouldLoadDomains bool
domainRepo domain.InstanceDomainRepository domainRepo *instanceDomain
} }
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository { func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
@@ -24,38 +24,71 @@ func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository
} }
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
// repository // repository
// ------------------------------------------------------------- // -------------------------------------------------------------
const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at` + const (
` FROM zitadel.instances` queryInstanceStmt = `SELECT instances.id, instances.name, instances.default_org_id, instances.iam_project_id, instances.console_client_id, instances.console_app_id, instances.default_language, instances.created_at, instances.updated_at` +
` , CASE WHEN count(instance_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', instance_domains.domain, 'isVerified', instance_domains.is_verified, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'validationType', instance_domains.validation_type, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) ELSE NULL::JSONB END domains` +
` FROM zitadel.instances`
)
// Get implements [domain.InstanceRepository]. // Get implements [domain.InstanceRepository].
func (i *instance) Get(ctx context.Context, id string) (*domain.Instance, error) { func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn(true)),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt) builder.WriteString(queryInstanceStmt)
options.Write(&builder)
idCondition := i.IDCondition(id)
writeCondition(&builder, idCondition)
return scanInstance(ctx, i.client, &builder) return scanInstance(ctx, i.client, &builder)
} }
// List implements [domain.InstanceRepository]. // List implements [domain.InstanceRepository].
func (i *instance) List(ctx context.Context, conditions ...database.Condition) ([]*domain.Instance, error) { func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) {
var builder database.StatementBuilder opts = append(opts,
i.joinDomains(),
builder.WriteString(queryInstanceStmt) database.WithGroupBy(i.IDColumn(true)),
)
if conditions != nil {
writeCondition(&builder, database.And(conditions...)) options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
} }
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
options.Write(&builder)
return scanInstances(ctx, i.client, &builder) return scanInstances(ctx, i.client, &builder)
} }
func (i *instance) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 2)
columns = append(columns, database.NewColumnCondition(i.IDColumn(true), i.Domains(false).InstanceIDColumn(true)))
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !i.shouldLoadDomains {
columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn(true)))
}
return database.WithLeftJoin(
"zitadel.instance_domains",
database.And(columns...),
)
}
const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language)` + const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language)` +
` VALUES ($1, $2, $3, $4, $5, $6, $7)` + ` VALUES ($1, $2, $3, $4, $5, $6, $7)` +
` RETURNING created_at, updated_at` ` RETURNING created_at, updated_at`
@@ -72,8 +105,8 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
// Update implements [domain.InstanceRepository]. // Update implements [domain.InstanceRepository].
func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) { func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
if changes == nil { if len(changes) == 0 {
return 0, errors.New("Update must contain a change") return 0, database.NoChangesError
} }
var builder database.StatementBuilder var builder database.StatementBuilder
@@ -107,7 +140,7 @@ func (i instance) Delete(ctx context.Context, id string) (int64, error) {
// SetName implements [domain.instanceChanges]. // SetName implements [domain.instanceChanges].
func (i instance) SetName(name string) database.Change { func (i instance) SetName(name string) database.Change {
return database.NewChange(i.NameColumn(), name) return database.NewChange(i.NameColumn(false), name)
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
@@ -116,12 +149,12 @@ func (i instance) SetName(name string) database.Change {
// IDCondition implements [domain.instanceConditions]. // IDCondition implements [domain.instanceConditions].
func (i instance) IDCondition(id string) database.Condition { func (i instance) IDCondition(id string) database.Condition {
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id) return database.NewTextCondition(i.IDColumn(true), database.TextOperationEqual, id)
} }
// NameCondition implements [domain.instanceConditions]. // NameCondition implements [domain.instanceConditions].
func (i instance) NameCondition(op database.TextOperation, name string) database.Condition { func (i instance) NameCondition(op database.TextOperation, name string) database.Condition {
return database.NewTextCondition(i.NameColumn(), op, name) return database.NewTextCondition(i.NameColumn(true), op, name)
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
@@ -129,74 +162,122 @@ func (i instance) NameCondition(op database.TextOperation, name string) database
// ------------------------------------------------------------- // -------------------------------------------------------------
// IDColumn implements [domain.instanceColumns]. // IDColumn implements [domain.instanceColumns].
func (instance) IDColumn() database.Column { func (instance) IDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.id")
}
return database.NewColumn("id") return database.NewColumn("id")
} }
// NameColumn implements [domain.instanceColumns]. // NameColumn implements [domain.instanceColumns].
func (instance) NameColumn() database.Column { func (instance) NameColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.name")
}
return database.NewColumn("name") return database.NewColumn("name")
} }
// CreatedAtColumn implements [domain.instanceColumns]. // CreatedAtColumn implements [domain.instanceColumns].
func (instance) CreatedAtColumn() database.Column { func (instance) CreatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.created_at")
}
return database.NewColumn("created_at") return database.NewColumn("created_at")
} }
// DefaultOrgIdColumn implements [domain.instanceColumns]. // DefaultOrgIdColumn implements [domain.instanceColumns].
func (instance) DefaultOrgIDColumn() database.Column { func (instance) DefaultOrgIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.default_org_id")
}
return database.NewColumn("default_org_id") return database.NewColumn("default_org_id")
} }
// IAMProjectIDColumn implements [domain.instanceColumns]. // IAMProjectIDColumn implements [domain.instanceColumns].
func (instance) IAMProjectIDColumn() database.Column { func (instance) IAMProjectIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.iam_project_id")
}
return database.NewColumn("iam_project_id") return database.NewColumn("iam_project_id")
} }
// ConsoleClientIDColumn implements [domain.instanceColumns]. // ConsoleClientIDColumn implements [domain.instanceColumns].
func (instance) ConsoleClientIDColumn() database.Column { func (instance) ConsoleClientIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.console_client_id")
}
return database.NewColumn("console_client_id") return database.NewColumn("console_client_id")
} }
// ConsoleAppIDColumn implements [domain.instanceColumns]. // ConsoleAppIDColumn implements [domain.instanceColumns].
func (instance) ConsoleAppIDColumn() database.Column { func (instance) ConsoleAppIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.console_app_id")
}
return database.NewColumn("console_app_id") return database.NewColumn("console_app_id")
} }
// DefaultLanguageColumn implements [domain.instanceColumns]. // DefaultLanguageColumn implements [domain.instanceColumns].
func (instance) DefaultLanguageColumn() database.Column { func (instance) DefaultLanguageColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.default_language")
}
return database.NewColumn("default_language") return database.NewColumn("default_language")
} }
// UpdatedAtColumn implements [domain.instanceColumns]. // UpdatedAtColumn implements [domain.instanceColumns].
func (instance) UpdatedAtColumn() database.Column { func (instance) UpdatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.updated_at")
}
return database.NewColumn("updated_at") return database.NewColumn("updated_at")
} }
type rawInstance struct {
*domain.Instance
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
}
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) { func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...) rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
instance := new(domain.Instance) var instance rawInstance
if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil { if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil {
return nil, err return nil, err
} }
return instance, nil if len(instance.RawDomains) > 0 {
if err := json.Unmarshal(instance.RawDomains, &instance.Domains); err != nil {
return nil, err
}
}
return instance.Instance, nil
} }
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) { func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...) rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := rows.(database.CollectableRows).Collect(&instances); err != nil { var rawInstances []*rawInstance
if err := rows.(database.CollectableRows).Collect(&rawInstances); err != nil {
return nil, err return nil, err
} }
instances := make([]*domain.Instance, len(rawInstances))
for i, instance := range rawInstances {
if len(instance.RawDomains) > 0 {
if err := json.Unmarshal(instance.RawDomains, &instance.Domains); err != nil {
return nil, err
}
}
instances[i] = instance.Instance
}
return instances, nil return instances, nil
} }
@@ -205,8 +286,10 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
// ------------------------------------------------------------- // -------------------------------------------------------------
// Domains implements [domain.InstanceRepository]. // Domains implements [domain.InstanceRepository].
func (i *instance) Domains() domain.InstanceDomainRepository { func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository {
i.shouldJoinDomains = true if !i.shouldLoadDomains {
i.shouldLoadDomains = shouldLoad
}
if i.domainRepo != nil { if i.domainRepo != nil {
return i.domainRepo return i.domainRepo

View File

@@ -18,8 +18,8 @@ type instanceDomain struct {
// repository // repository
// ------------------------------------------------------------- // -------------------------------------------------------------
const queryInstanceDomainStmt = `SELECT instance_id, domain, is_verified, is_primary, verification_type, created_at, updated_at ` + const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_domains.domain, instance_domains.is_verified, instance_domains.is_primary, instance_domains.validation_type, instance_domains.created_at, instance_domains.updated_at ` +
`FROM zitadel.instance_domains` `FROM zitadel.instance_domains id`
// Get implements [domain.InstanceDomainRepository]. // Get implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance. // Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance.
@@ -55,7 +55,7 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption)
func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error { func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error {
var builder database.StatementBuilder var builder database.StatementBuilder
builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_verified, is_primary, verification_type) ` + builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_verified, is_primary, validation_type) ` +
`VALUES ($1, $2, $3, $4, $5)` + `VALUES ($1, $2, $3, $4, $5)` +
` RETURNING created_at, updated_at`) ` RETURNING created_at, updated_at`)
@@ -91,19 +91,19 @@ func (i *instanceDomain) Update(ctx context.Context, condition database.Conditio
// changes // changes
// ------------------------------------------------------------- // -------------------------------------------------------------
// SetVerificationType implements [domain.InstanceDomainRepository]. // SetValidationType implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetVerificationType(verificationType domain.DomainVerificationType) database.Change { func (i instanceDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change {
return database.NewChange(i.VerificationTypeColumn(), verificationType) return database.NewChange(i.ValidationTypeColumn(false), verificationType)
} }
// SetPrimary implements [domain.InstanceDomainRepository]. // SetPrimary implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetPrimary() database.Change { func (i instanceDomain) SetPrimary() database.Change {
return database.NewChange(i.IsPrimaryColumn(), true) return database.NewChange(i.IsPrimaryColumn(false), true)
} }
// SetVerified implements [domain.InstanceDomainRepository]. // SetVerified implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetVerified() database.Change { func (i instanceDomain) SetVerified() database.Change {
return database.NewChange(i.IsVerifiedColumn(), true) return database.NewChange(i.IsVerifiedColumn(false), true)
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
@@ -112,22 +112,22 @@ func (i instanceDomain) SetVerified() database.Change {
// DomainCondition implements [domain.InstanceDomainRepository]. // DomainCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) DomainCondition(op database.TextOperation, domain string) database.Condition { func (i instanceDomain) DomainCondition(op database.TextOperation, domain string) database.Condition {
return database.NewTextCondition(i.DomainColumn(), op, domain) return database.NewTextCondition(i.DomainColumn(true), op, domain)
} }
// InstanceIDCondition implements [domain.InstanceDomainRepository]. // InstanceIDCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) InstanceIDCondition(instanceID string) database.Condition { func (i instanceDomain) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, instanceID) return database.NewTextCondition(i.InstanceIDColumn(true), database.TextOperationEqual, instanceID)
} }
// IsPrimaryCondition implements [domain.InstanceDomainRepository]. // IsPrimaryCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) IsPrimaryCondition(isPrimary bool) database.Condition { func (i instanceDomain) IsPrimaryCondition(isPrimary bool) database.Condition {
return database.NewBooleanCondition(i.IsPrimaryColumn(), isPrimary) return database.NewBooleanCondition(i.IsPrimaryColumn(true), isPrimary)
} }
// IsVerifiedCondition implements [domain.InstanceDomainRepository]. // IsVerifiedCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) IsVerifiedCondition(isVerified bool) database.Condition { func (i instanceDomain) IsVerifiedCondition(isVerified bool) database.Condition {
return database.NewBooleanCondition(i.IsVerifiedColumn(), isVerified) return database.NewBooleanCondition(i.IsVerifiedColumn(true), isVerified)
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
@@ -136,43 +136,67 @@ func (i instanceDomain) IsVerifiedCondition(isVerified bool) database.Condition
// CreatedAtColumn implements [domain.InstanceDomainRepository]. // CreatedAtColumn implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance. // Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance.
func (instanceDomain) CreatedAtColumn() database.Column { func (instanceDomain) CreatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.created_at")
}
return database.NewColumn("created_at") return database.NewColumn("created_at")
} }
// DomainColumn implements [domain.InstanceDomainRepository]. // DomainColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) DomainColumn() database.Column { func (instanceDomain) DomainColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.domain")
}
return database.NewColumn("domain") return database.NewColumn("domain")
} }
// InstanceIDColumn implements [domain.InstanceDomainRepository]. // InstanceIDColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) InstanceIDColumn() database.Column { func (instanceDomain) InstanceIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.instance_id")
}
return database.NewColumn("instance_id") return database.NewColumn("instance_id")
} }
// IsPrimaryColumn implements [domain.InstanceDomainRepository]. // IsPrimaryColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsPrimaryColumn() database.Column { func (instanceDomain) IsPrimaryColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.is_primary")
}
return database.NewColumn("is_primary") return database.NewColumn("is_primary")
} }
// IsVerifiedColumn implements [domain.InstanceDomainRepository]. // IsVerifiedColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsVerifiedColumn() database.Column { func (instanceDomain) IsVerifiedColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.is_verified")
}
return database.NewColumn("is_verified") return database.NewColumn("is_verified")
} }
// UpdatedAtColumn implements [domain.InstanceDomainRepository]. // UpdatedAtColumn implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance. // Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance.
func (instanceDomain) UpdatedAtColumn() database.Column { func (instanceDomain) UpdatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.updated_at")
}
return database.NewColumn("updated_at") return database.NewColumn("updated_at")
} }
// VerificationTypeColumn implements [domain.InstanceDomainRepository]. // ValidationTypeColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) VerificationTypeColumn() database.Column { func (instanceDomain) ValidationTypeColumn(qualified bool) database.Column {
return database.NewColumn("verification_type") if qualified {
return database.NewColumn("instance_domains.validation_type")
}
return database.NewColumn("validation_type")
} }
// IsGeneratedColumn implements [domain.InstanceDomainRepository]. // IsGeneratedColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsGeneratedColumn() database.Column { func (instanceDomain) IsGeneratedColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.is_generated")
}
return database.NewColumn("is_generated") return database.NewColumn("is_generated")
} }

View File

@@ -171,7 +171,9 @@ func TestCreateInstance(t *testing.T) {
// check instance values // check instance values
instance, err = instanceRepo.Get(ctx, instance, err = instanceRepo.Get(ctx,
instance.ID, database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
) )
require.NoError(t, err) require.NoError(t, err)
@@ -290,7 +292,9 @@ func TestUpdateInstance(t *testing.T) {
// check instance values // check instance values
instance, err = instanceRepo.Get(ctx, instance, err = instanceRepo.Get(ctx,
instance.ID, database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
) )
require.Equal(t, tt.getErr, err) require.Equal(t, tt.getErr, err)
@@ -356,7 +360,9 @@ func TestGetInstance(t *testing.T) {
// check instance values // check instance values
returnedInstance, err := instanceRepo.Get(ctx, returnedInstance, err := instanceRepo.Get(ctx,
instance.ID, database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
) )
if tt.err != nil { if tt.err != nil {
require.ErrorIs(t, err, tt.err) require.ErrorIs(t, err, tt.err)
@@ -524,9 +530,15 @@ func TestListInstance(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool) instanceRepo := repository.InstanceRepository(pool)
var condition database.Condition
if len(tt.conditionClauses) > 0 {
condition = database.And(tt.conditionClauses...)
}
// check instance values // check instance values
returnedInstances, err := instanceRepo.List(ctx, returnedInstances, err := instanceRepo.List(ctx,
tt.conditionClauses..., database.WithCondition(condition),
database.WithOrderBy(instanceRepo.CreatedAtColumn(true)),
) )
require.NoError(t, err) require.NoError(t, err)
if tt.noInstanceReturned { if tt.noInstanceReturned {
@@ -652,7 +664,9 @@ func TestDeleteInstance(t *testing.T) {
// check instance was deleted // check instance was deleted
instance, err := instanceRepo.Get(ctx, instance, err := instanceRepo.Get(ctx,
tt.instanceID, database.WithCondition(
instanceRepo.IDCondition(tt.instanceID),
),
) )
require.ErrorIs(t, err, new(database.NoRowFoundError)) require.ErrorIs(t, err, new(database.NoRowFoundError))
assert.Nil(t, instance) assert.Nil(t, instance)

View File

@@ -2,7 +2,7 @@ package repository
import ( import (
"context" "context"
"errors" "encoding/json"
"github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
@@ -15,9 +15,9 @@ import (
var _ domain.OrganizationRepository = (*org)(nil) var _ domain.OrganizationRepository = (*org)(nil)
type org struct { type org struct {
shouldJoinDomains bool
repository repository
domainRepo domain.OrganizationDomainRepository shouldLoadDomains bool
domainRepo domain.OrganizationDomainRepository
} }
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository { func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
@@ -28,39 +28,68 @@ func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRe
} }
} }
const queryOrganizationStmt = `SELECT id, name, instance_id, state, created_at, updated_at` + const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` +
` FROM zitadel.organizations` ` FROM zitadel.organizations`
// Get implements [domain.OrganizationRepository]. // Get implements [domain.OrganizationRepository].
func (o *org) Get(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, conditions ...database.Condition) (*domain.Organization, error) { func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
builder := database.StatementBuilder{} opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(true), o.IDColumn(true)),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt) builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
instanceIDCondition := o.InstanceIDCondition(instanceID)
conditions = append(conditions, id, instanceIDCondition)
writeCondition(&builder, database.And(conditions...))
return scanOrganization(ctx, o.client, &builder) return scanOrganization(ctx, o.client, &builder)
} }
// List implements [domain.OrganizationRepository]. // List implements [domain.OrganizationRepository].
func (o *org) List(ctx context.Context, conditions ...database.Condition) ([]*domain.Organization, error) { func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
builder := database.StatementBuilder{} opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(true), o.IDColumn(true)),
)
builder.WriteString(queryOrganizationStmt) options := new(database.QueryOpts)
for _, opt := range opts {
if conditions != nil { opt(options)
writeCondition(&builder, database.And(conditions...))
} }
orderBy := database.OrderBy(o.CreatedAtColumn()) var builder database.StatementBuilder
orderBy.Write(&builder) builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
return scanOrganizations(ctx, o.client, &builder) return scanOrganizations(ctx, o.client, &builder)
} }
func (o *org) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 3)
columns = append(columns,
database.NewColumnCondition(o.InstanceIDColumn(true), o.Domains(false).InstanceIDColumn(true)),
database.NewColumnCondition(o.IDColumn(true), o.Domains(false).OrgIDColumn(true)),
)
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !o.shouldLoadDomains {
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn(true)))
}
return database.WithLeftJoin(
"zitadel.org_domains",
database.And(columns...),
)
}
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` + const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
` VALUES ($1, $2, $3, $4)` + ` VALUES ($1, $2, $3, $4)` +
` RETURNING created_at, updated_at` ` RETURNING created_at, updated_at`
@@ -77,8 +106,8 @@ func (o *org) Create(ctx context.Context, organization *domain.Organization) err
// Update implements [domain.OrganizationRepository]. // Update implements [domain.OrganizationRepository].
func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) { func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
if changes == nil { if len(changes) == 0 {
return 0, errors.New("Update must contain a condition") // (otherwise ALL organizations will be updated) return 0, database.NoChangesError
} }
builder := database.StatementBuilder{} builder := database.StatementBuilder{}
builder.WriteString(`UPDATE zitadel.organizations SET `) builder.WriteString(`UPDATE zitadel.organizations SET `)
@@ -115,12 +144,12 @@ func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, inst
// SetName implements [domain.organizationChanges]. // SetName implements [domain.organizationChanges].
func (o org) SetName(name string) database.Change { func (o org) SetName(name string) database.Change {
return database.NewChange(o.NameColumn(), name) return database.NewChange(o.NameColumn(false), name)
} }
// SetState implements [domain.organizationChanges]. // SetState implements [domain.organizationChanges].
func (o org) SetState(state domain.OrgState) database.Change { func (o org) SetState(state domain.OrgState) database.Change {
return database.NewChange(o.StateColumn(), state) return database.NewChange(o.StateColumn(false), state)
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
@@ -129,22 +158,22 @@ func (o org) SetState(state domain.OrgState) database.Change {
// IDCondition implements [domain.organizationConditions]. // IDCondition implements [domain.organizationConditions].
func (o org) IDCondition(id string) domain.OrgIdentifierCondition { func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id) return database.NewTextCondition(o.IDColumn(true), database.TextOperationEqual, id)
} }
// NameCondition implements [domain.organizationConditions]. // NameCondition implements [domain.organizationConditions].
func (o org) NameCondition(name string) domain.OrgIdentifierCondition { func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name) return database.NewTextCondition(o.NameColumn(true), database.TextOperationEqual, name)
} }
// InstanceIDCondition implements [domain.organizationConditions]. // InstanceIDCondition implements [domain.organizationConditions].
func (o org) InstanceIDCondition(instanceID string) database.Condition { func (o org) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID) return database.NewTextCondition(o.InstanceIDColumn(true), database.TextOperationEqual, instanceID)
} }
// StateCondition implements [domain.organizationConditions]. // StateCondition implements [domain.organizationConditions].
func (o org) StateCondition(state domain.OrgState) database.Condition { func (o org) StateCondition(state domain.OrgState) database.Condition {
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state) return database.NewTextCondition(o.StateColumn(true), database.TextOperationEqual, state)
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
@@ -152,32 +181,50 @@ func (o org) StateCondition(state domain.OrgState) database.Condition {
// ------------------------------------------------------------- // -------------------------------------------------------------
// IDColumn implements [domain.organizationColumns]. // IDColumn implements [domain.organizationColumns].
func (org) IDColumn() database.Column { func (org) IDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.id")
}
return database.NewColumn("id") return database.NewColumn("id")
} }
// NameColumn implements [domain.organizationColumns]. // NameColumn implements [domain.organizationColumns].
func (org) NameColumn() database.Column { func (org) NameColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.name")
}
return database.NewColumn("name") return database.NewColumn("name")
} }
// InstanceIDColumn implements [domain.organizationColumns]. // InstanceIDColumn implements [domain.organizationColumns].
func (org) InstanceIDColumn() database.Column { func (org) InstanceIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.instance_id")
}
return database.NewColumn("instance_id") return database.NewColumn("instance_id")
} }
// StateColumn implements [domain.organizationColumns]. // StateColumn implements [domain.organizationColumns].
func (org) StateColumn() database.Column { func (org) StateColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.state")
}
return database.NewColumn("state") return database.NewColumn("state")
} }
// CreatedAtColumn implements [domain.organizationColumns]. // CreatedAtColumn implements [domain.organizationColumns].
func (org) CreatedAtColumn() database.Column { func (org) CreatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.created_at")
}
return database.NewColumn("created_at") return database.NewColumn("created_at")
} }
// UpdatedAtColumn implements [domain.organizationColumns]. // UpdatedAtColumn implements [domain.organizationColumns].
func (org) UpdatedAtColumn() database.Column { func (org) UpdatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.updated_at")
}
return database.NewColumn("updated_at") return database.NewColumn("updated_at")
} }
@@ -185,18 +232,28 @@ func (org) UpdatedAtColumn() database.Column {
// scanners // scanners
// ------------------------------------------------------------- // -------------------------------------------------------------
type rawOrganization struct {
*domain.Organization
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
}
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) { func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...) rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
organization := &domain.Organization{} var org rawOrganization
if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil { if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil {
return nil, err return nil, err
} }
if len(org.RawDomains) > 0 {
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
return nil, err
}
}
return organization, nil return org.Organization, nil
} }
func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) { func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) {
@@ -205,10 +262,20 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
return nil, err return nil, err
} }
organizations := []*domain.Organization{} var rawOrgs []*rawOrganization
if err := rows.(database.CollectableRows).Collect(&organizations); err != nil { if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil {
return nil, err return nil, err
} }
organizations := make([]*domain.Organization, len(rawOrgs))
for i, org := range rawOrgs {
if len(org.RawDomains) > 0 {
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
return nil, err
}
}
organizations[i] = org.Organization
}
return organizations, nil return organizations, nil
} }
@@ -216,8 +283,11 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
// sub repositories // sub repositories
// ------------------------------------------------------------- // -------------------------------------------------------------
func (o *org) Domains() domain.OrganizationDomainRepository { // Domains implements [domain.OrganizationRepository].
o.shouldJoinDomains = true func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
if !o.shouldLoadDomains {
o.shouldLoadDomains = shouldLoad
}
if o.domainRepo != nil { if o.domainRepo != nil {
return o.domainRepo return o.domainRepo

View File

@@ -61,7 +61,7 @@ func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomai
`VALUES ($1, $2, $3, $4, $5, $6)` + `VALUES ($1, $2, $3, $4, $5, $6)` +
` RETURNING created_at, updated_at`) ` RETURNING created_at, updated_at`)
builder.AppendArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.VerificationType) builder.AppendArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType)
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
} }
@@ -95,17 +95,17 @@ func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (i
// SetPrimary implements [domain.OrganizationDomainRepository]. // SetPrimary implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetPrimary() database.Change { func (o orgDomain) SetPrimary() database.Change {
return database.NewChange(o.IsPrimaryColumn(), true) return database.NewChange(o.IsPrimaryColumn(false), true)
} }
// SetVerificationType implements [domain.OrganizationDomainRepository]. // SetValidationType implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetVerificationType(verificationType domain.DomainVerificationType) database.Change { func (o orgDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change {
return database.NewChange(o.VerificationTypeColumn(), verificationType) return database.NewChange(o.ValidationTypeColumn(false), verificationType)
} }
// SetVerified implements [domain.OrganizationDomainRepository]. // SetVerified implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetVerified() database.Change { func (o orgDomain) SetVerified() database.Change {
return database.NewChange(o.IsVerifiedColumn(), true) return database.NewChange(o.IsVerifiedColumn(false), true)
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
@@ -114,28 +114,28 @@ func (o orgDomain) SetVerified() database.Change {
// DomainCondition implements [domain.OrganizationDomainRepository]. // DomainCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) DomainCondition(op database.TextOperation, domain string) database.Condition { func (o orgDomain) DomainCondition(op database.TextOperation, domain string) database.Condition {
return database.NewTextCondition(o.DomainColumn(), op, domain) return database.NewTextCondition(o.DomainColumn(true), op, domain)
} }
// InstanceIDCondition implements [domain.OrganizationDomainRepository]. // InstanceIDCondition implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDCondition of orgDomain.org. // Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDCondition of orgDomain.org.
func (o orgDomain) InstanceIDCondition(instanceID string) database.Condition { func (o orgDomain) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID) return database.NewTextCondition(o.InstanceIDColumn(true), database.TextOperationEqual, instanceID)
} }
// IsPrimaryCondition implements [domain.OrganizationDomainRepository]. // IsPrimaryCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) IsPrimaryCondition(isPrimary bool) database.Condition { func (o orgDomain) IsPrimaryCondition(isPrimary bool) database.Condition {
return database.NewBooleanCondition(o.IsPrimaryColumn(), isPrimary) return database.NewBooleanCondition(o.IsPrimaryColumn(true), isPrimary)
} }
// IsVerifiedCondition implements [domain.OrganizationDomainRepository]. // IsVerifiedCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) IsVerifiedCondition(isVerified bool) database.Condition { func (o orgDomain) IsVerifiedCondition(isVerified bool) database.Condition {
return database.NewBooleanCondition(o.IsVerifiedColumn(), isVerified) return database.NewBooleanCondition(o.IsVerifiedColumn(true), isVerified)
} }
// OrgIDCondition implements [domain.OrganizationDomainRepository]. // OrgIDCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) OrgIDCondition(orgID string) database.Condition { func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
return database.NewTextCondition(o.OrgIDColumn(), database.TextOperationEqual, orgID) return database.NewTextCondition(o.OrgIDColumn(true), database.TextOperationEqual, orgID)
} }
// ------------------------------------------------------------- // -------------------------------------------------------------
@@ -144,45 +144,69 @@ func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
// CreatedAtColumn implements [domain.OrganizationDomainRepository]. // CreatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org. // Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
func (orgDomain) CreatedAtColumn() database.Column { func (orgDomain) CreatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.created_at")
}
return database.NewColumn("created_at") return database.NewColumn("created_at")
} }
// DomainColumn implements [domain.OrganizationDomainRepository]. // DomainColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) DomainColumn() database.Column { func (orgDomain) DomainColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.domain")
}
return database.NewColumn("domain") return database.NewColumn("domain")
} }
// InstanceIDColumn implements [domain.OrganizationDomainRepository]. // InstanceIDColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org. // Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org.
func (orgDomain) InstanceIDColumn() database.Column { func (orgDomain) InstanceIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.instance_id")
}
return database.NewColumn("instance_id") return database.NewColumn("instance_id")
} }
// IsPrimaryColumn implements [domain.OrganizationDomainRepository]. // IsPrimaryColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) IsPrimaryColumn() database.Column { func (orgDomain) IsPrimaryColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.is_primary")
}
return database.NewColumn("is_primary") return database.NewColumn("is_primary")
} }
// IsVerifiedColumn implements [domain.OrganizationDomainRepository]. // IsVerifiedColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) IsVerifiedColumn() database.Column { func (orgDomain) IsVerifiedColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.is_verified")
}
return database.NewColumn("is_verified") return database.NewColumn("is_verified")
} }
// OrgIDColumn implements [domain.OrganizationDomainRepository]. // OrgIDColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) OrgIDColumn() database.Column { func (orgDomain) OrgIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.org_id")
}
return database.NewColumn("org_id") return database.NewColumn("org_id")
} }
// UpdatedAtColumn implements [domain.OrganizationDomainRepository]. // UpdatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org. // Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org.
func (orgDomain) UpdatedAtColumn() database.Column { func (orgDomain) UpdatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.updated_at")
}
return database.NewColumn("updated_at") return database.NewColumn("updated_at")
} }
// VerificationTypeColumn implements [domain.OrganizationDomainRepository]. // ValidationTypeColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) VerificationTypeColumn() database.Column { func (orgDomain) ValidationTypeColumn(qualified bool) database.Column {
return database.NewColumn("verification_type") if qualified {
return database.NewColumn("org_domains.validation_type")
}
return database.NewColumn("validation_type")
} }
// ------------------------------------------------------------- // -------------------------------------------------------------

View File

@@ -235,8 +235,12 @@ func TestCreateOrganization(t *testing.T) {
// check organization values // check organization values
organization, err = organizationRepo.Get(ctx, organization, err = organizationRepo.Get(ctx,
organizationRepo.IDCondition(organization.ID), database.WithCondition(
organization.InstanceID, database.And(
organizationRepo.IDCondition(organization.ID),
organizationRepo.InstanceIDCondition(organization.InstanceID),
),
),
) )
require.NoError(t, err) require.NoError(t, err)
@@ -389,8 +393,12 @@ func TestUpdateOrganization(t *testing.T) {
// check organization values // check organization values
organization, err := organizationRepo.Get(ctx, organization, err := organizationRepo.Get(ctx,
organizationRepo.IDCondition(createdOrg.ID), database.WithCondition(
createdOrg.InstanceID, database.And(
organizationRepo.IDCondition(createdOrg.ID),
organizationRepo.InstanceIDCondition(createdOrg.InstanceID),
),
),
) )
require.NoError(t, err) require.NoError(t, err)
@@ -511,13 +519,18 @@ func TestGetOrganization(t *testing.T) {
// get org values // get org values
returnedOrg, err := orgRepo.Get(ctx, returnedOrg, err := orgRepo.Get(ctx,
tt.orgIdentifierCondition, database.WithCondition(
org.InstanceID, database.And(
tt.orgIdentifierCondition,
orgRepo.InstanceIDCondition(org.InstanceID),
),
),
) )
if tt.err != nil { if tt.err != nil {
require.ErrorIs(t, tt.err, err) require.ErrorIs(t, tt.err, err)
return return
} }
require.NoError(t, err)
if org.Name == "non existent org" { if org.Name == "non existent org" {
assert.Nil(t, returnedOrg) assert.Nil(t, returnedOrg)
@@ -764,9 +777,15 @@ func TestListOrganization(t *testing.T) {
organizations := tt.testFunc(ctx, t) organizations := tt.testFunc(ctx, t)
var condition database.Condition
if len(tt.conditionClauses) > 0 {
condition = database.And(tt.conditionClauses...)
}
// check organization values // check organization values
returnedOrgs, err := organizationRepo.List(ctx, returnedOrgs, err := organizationRepo.List(ctx,
tt.conditionClauses..., database.WithCondition(condition),
database.WithOrderBy(organizationRepo.CreatedAtColumn(true)),
) )
require.NoError(t, err) require.NoError(t, err)
if tt.noOrganizationReturned { if tt.noOrganizationReturned {
@@ -929,8 +948,12 @@ func TestDeleteOrganization(t *testing.T) {
// check organization was deleted // check organization was deleted
organization, err := organizationRepo.Get(ctx, organization, err := organizationRepo.Get(ctx,
tt.orgIdentifierCondition, database.WithCondition(
instanceId, database.And(
tt.orgIdentifierCondition,
organizationRepo.InstanceIDCondition(instanceId),
),
),
) )
require.ErrorIs(t, err, new(database.NoRowFoundError)) require.ErrorIs(t, err, new(database.NoRowFoundError))
assert.Nil(t, organization) assert.Nil(t, organization)