implementation done

This commit is contained in:
adlerhurst
2025-07-22 19:09:56 +02:00
parent bb2d0aff3f
commit 9c348c0429
14 changed files with 536 additions and 179 deletions

View File

@@ -2,28 +2,35 @@ package domain
import "github.com/zitadel/zitadel/backend/v3/storage/database"
type DomainVerificationType string
type DomainValidationType string
const (
DomainVerificationTypeDNS DomainVerificationType = "dns"
DomainVerificationTypeHTTP DomainVerificationType = "http"
DomainValidationTypeDNS DomainValidationType = "dns"
DomainValidationTypeHTTP DomainValidationType = "http"
)
type domainColumns interface {
// InstanceIDColumn returns the column for the instance id field.
InstanceIDColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
InstanceIDColumn(qualified bool) database.Column
// DomainColumn returns the column for the domain field.
DomainColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
DomainColumn(qualified bool) database.Column
// IsVerifiedColumn returns the column for the is verified field.
IsVerifiedColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
IsVerifiedColumn(qualified bool) database.Column
// IsPrimaryColumn returns the column for the is primary field.
IsPrimaryColumn() database.Column
// VerificationTypeColumn returns the column for the verification type field.
VerificationTypeColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
IsPrimaryColumn(qualified bool) database.Column
// ValidationTypeColumn returns the column for the verification type field.
// `qualified` indicates if the column should be qualified with the table name.
ValidationTypeColumn(qualified bool) database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
CreatedAtColumn(qualified bool) database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
UpdatedAtColumn(qualified bool) database.Column
}
type domainConditions interface {
@@ -51,9 +58,9 @@ type domainChanges interface {
// - The domain is already primary.
// - No domain matches the condition.
SetPrimary() database.Change
// SetVerificationType sets the verification type column.
// SetValidationType sets the verification type column.
// If the domain is already verified, this is a no-op.
SetVerificationType(verificationType DomainVerificationType) database.Change
SetValidationType(verificationType DomainValidationType) database.Change
}
// import (

View File

@@ -18,6 +18,8 @@ type Instance struct {
DefaultLanguage string `json:"defaultLanguage,omitempty" db:"default_language"`
CreatedAt time.Time `json:"createdAt" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
Domains []*InstanceDomain `json:"domains,omitempty" db:"-"`
}
type instanceCacheIndex uint8
@@ -40,23 +42,32 @@ var _ cache.Entry[instanceCacheIndex, string] = (*Instance)(nil)
// instanceColumns define all the columns of the instance table.
type instanceColumns interface {
// IDColumn returns the column for the id field.
IDColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
IDColumn(qualified bool) database.Column
// NameColumn returns the column for the name field.
NameColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
NameColumn(qualified bool) database.Column
// DefaultOrgIDColumn returns the column for the default org id field
DefaultOrgIDColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
DefaultOrgIDColumn(qualified bool) database.Column
// IAMProjectIDColumn returns the column for the default IAM org id field
IAMProjectIDColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
IAMProjectIDColumn(qualified bool) database.Column
// ConsoleClientIDColumn returns the column for the default IAM org id field
ConsoleClientIDColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
ConsoleClientIDColumn(qualified bool) database.Column
// ConsoleAppIDColumn returns the column for the console client id field
ConsoleAppIDColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
ConsoleAppIDColumn(qualified bool) database.Column
// DefaultLanguageColumn returns the column for the default language field
DefaultLanguageColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
DefaultLanguageColumn(qualified bool) database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
CreatedAtColumn(qualified bool) database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
UpdatedAtColumn(qualified bool) database.Column
}
// instanceConditions define all the conditions for the instance table.
@@ -83,16 +94,27 @@ type InstanceRepository interface {
// Member returns the member repository which is a sub repository of the instance repository.
// Member() MemberRepository
Get(ctx context.Context, id string) (*Instance, error)
List(ctx context.Context, opts ...database.Condition) ([]*Instance, error)
Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error)
List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error)
Create(ctx context.Context, instance *Instance) error
Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id string) (int64, error)
Domains() InstanceDomainRepository
// Domains returns the domain sub repository for the instance.
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
Domains(shouldLoad bool) InstanceDomainRepository
}
type CreateInstance struct {
Name string `json:"name"`
}
type InstanceQueryOption func(*InstanceQueryOpts)
type InstanceQueryOpts struct {
database.QueryOpts
JoinDomains bool
}

View File

@@ -2,17 +2,23 @@ package domain
import (
"context"
"encoding/json"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type InstanceDomains struct {
domains []*InstanceDomain
Raw json.RawMessage
}
type InstanceDomain struct {
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"`
ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
CreatedAt string `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"`
@@ -23,7 +29,7 @@ type AddInstanceDomain struct {
Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"`
VerificationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
// CreatedAt is the time when the domain was added.
// It is set by the repository and should not be set by the caller.
@@ -36,7 +42,8 @@ type AddInstanceDomain struct {
type instanceDomainColumns interface {
domainColumns
// IsGeneratedColumn returns the column for the is generated field.
IsGeneratedColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
IsGeneratedColumn(qualified bool) database.Column
}
type instanceDomainConditions interface {

View File

@@ -19,8 +19,10 @@ type Organization struct {
Name string `json:"name,omitempty" db:"name"`
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
State OrgState `json:"state,omitempty" db:"state"`
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"`
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"`
}
// OrgIdentifierCondition is used to help specify a single Organization,
@@ -33,17 +35,23 @@ type OrgIdentifierCondition interface {
// organizationColumns define all the columns of the instance table.
type organizationColumns interface {
// IDColumn returns the column for the id field.
IDColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
IDColumn(qualified bool) database.Column
// NameColumn returns the column for the name field.
NameColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
NameColumn(qualified bool) database.Column
// InstanceIDColumn returns the column for the default org id field
InstanceIDColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
InstanceIDColumn(qualified bool) database.Column
// StateColumn returns the column for the name field.
StateColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
StateColumn(qualified bool) database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
CreatedAtColumn(qualified bool) database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
// `qualified` indicates if the column should be qualified with the table name.
UpdatedAtColumn(qualified bool) database.Column
}
// organizationConditions define all the conditions for the instance table.
@@ -72,15 +80,17 @@ type OrganizationRepository interface {
organizationConditions
organizationChanges
Get(ctx context.Context, id OrgIdentifierCondition, instance_id string, opts ...database.Condition) (*Organization, error)
List(ctx context.Context, conditions ...database.Condition) ([]*Organization, error)
Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error)
List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error)
Create(ctx context.Context, instance *Organization) error
Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
// Domains returns the domain sub repository for the organization.
Domains() OrganizationDomainRepository
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
Domains(shouldLoad bool) OrganizationDomainRepository
}
type CreateOrganization struct {

View File

@@ -13,7 +13,7 @@ type OrganizationDomain struct {
Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"`
ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
CreatedAt string `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"`
@@ -25,7 +25,7 @@ type AddOrganizationDomain struct {
Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
VerificationType DomainVerificationType `json:"verificationType,omitempty" db:"verification_type"`
ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
// CreatedAt is the time when the domain was added.
// It is set by the repository and should not be set by the caller.
@@ -38,7 +38,7 @@ type AddOrganizationDomain struct {
type organizationDomainColumns interface {
domainColumns
// OrgIDColumn returns the column for the org id field.
OrgIDColumn() database.Column
OrgIDColumn(qualified bool) database.Column
}
type organizationDomainConditions interface {

View File

@@ -113,6 +113,15 @@ func NewBooleanCondition[V Boolean](col Column, value V) Condition {
})
}
// NewColumnCondition creates a condition that compares two columns on equality.
func NewColumnCondition(col1, col2 Column) Condition {
return valueCondition(func(builder *StatementBuilder) {
col1.Write(builder)
builder.WriteString(" = ")
col2.Write(builder)
})
}
// Write implements [Condition].
func (c valueCondition) Write(builder *StatementBuilder) {
c(builder)

View File

@@ -1,9 +1,12 @@
package database
import (
"errors"
"fmt"
)
var NoChangesError = errors.New("Update must contain a change")
// NoRowFoundError is returned when QueryRow does not find any row.
// It wraps the dialect specific original error to provide more context.
type NoRowFoundError struct {

View File

@@ -30,6 +30,36 @@ func WithOffset(offset uint32) QueryOption {
}
}
// WithGroupBy sets the columns to group the results by.
func WithGroupBy(groupBy ...Column) QueryOption {
return func(opts *QueryOpts) {
opts.GroupBy = groupBy
}
}
// WithLeftJoin adds a LEFT JOIN to the query.
func WithLeftJoin(table string, columns Condition) QueryOption {
return func(opts *QueryOpts) {
opts.Joins = append(opts.Joins, join{
table: table,
typ: JoinTypeLeft,
columns: columns,
})
}
}
type joinType string
const (
JoinTypeLeft joinType = "LEFT"
)
type join struct {
table string
typ joinType
columns Condition
}
// QueryOpts holds the options for a query.
// It is used to build the SQL SELECT statement.
type QueryOpts struct {
@@ -45,10 +75,19 @@ type QueryOpts struct {
// Offset is the number of results to skip before returning the results.
// It is used to build the OFFSET clause of the SQL statement.
Offset uint32
// GroupBy is the columns to group the results by.
// It is used to build the GROUP BY clause of the SQL statement.
GroupBy Columns
// Joins is a list of joins to be applied to the query.
// It is used to build the JOIN clauses of the SQL statement.
Joins []join
}
func (opts *QueryOpts) Write(builder *StatementBuilder) {
opts.WriteLeftJoins(builder)
opts.WriteCondition(builder)
opts.WriteGroupBy(builder)
opts.WriteOrderBy(builder)
opts.WriteLimit(builder)
opts.WriteOffset(builder)
@@ -85,3 +124,25 @@ func (opts *QueryOpts) WriteOffset(builder *StatementBuilder) {
builder.WriteString(" OFFSET ")
builder.WriteArg(opts.Offset)
}
func (opts *QueryOpts) WriteGroupBy(builder *StatementBuilder) {
if len(opts.GroupBy) == 0 {
return
}
builder.WriteString(" GROUP BY ")
opts.GroupBy.Write(builder)
}
func (opts *QueryOpts) WriteLeftJoins(builder *StatementBuilder) {
if len(opts.Joins) == 0 {
return
}
for _, join := range opts.Joins {
builder.WriteString(" ")
builder.WriteString(string(join.typ))
builder.WriteString(" JOIN ")
builder.WriteString(join.table)
builder.WriteString(" ON ")
join.columns.Write(builder)
}
}

View File

@@ -2,7 +2,7 @@ package repository
import (
"context"
"errors"
"encoding/json"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
@@ -12,8 +12,8 @@ var _ domain.InstanceRepository = (*instance)(nil)
type instance struct {
repository
shouldJoinDomains bool
domainRepo domain.InstanceDomainRepository
shouldLoadDomains bool
domainRepo *instanceDomain
}
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
@@ -24,38 +24,71 @@ func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository
}
}
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at` +
` FROM zitadel.instances`
const (
queryInstanceStmt = `SELECT instances.id, instances.name, instances.default_org_id, instances.iam_project_id, instances.console_client_id, instances.console_app_id, instances.default_language, instances.created_at, instances.updated_at` +
` , CASE WHEN count(instance_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', instance_domains.domain, 'isVerified', instance_domains.is_verified, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'validationType', instance_domains.validation_type, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) ELSE NULL::JSONB END domains` +
` FROM zitadel.instances`
)
// Get implements [domain.InstanceRepository].
func (i *instance) Get(ctx context.Context, id string) (*domain.Instance, error) {
func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn(true)),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
idCondition := i.IDCondition(id)
writeCondition(&builder, idCondition)
options.Write(&builder)
return scanInstance(ctx, i.client, &builder)
}
// List implements [domain.InstanceRepository].
func (i *instance) List(ctx context.Context, conditions ...database.Condition) ([]*domain.Instance, error) {
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
if conditions != nil {
writeCondition(&builder, database.And(conditions...))
func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn(true)),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
options.Write(&builder)
return scanInstances(ctx, i.client, &builder)
}
func (i *instance) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 2)
columns = append(columns, database.NewColumnCondition(i.IDColumn(true), i.Domains(false).InstanceIDColumn(true)))
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !i.shouldLoadDomains {
columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn(true)))
}
return database.WithLeftJoin(
"zitadel.instance_domains",
database.And(columns...),
)
}
const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language)` +
` VALUES ($1, $2, $3, $4, $5, $6, $7)` +
` RETURNING created_at, updated_at`
@@ -72,8 +105,8 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
// Update implements [domain.InstanceRepository].
func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
if changes == nil {
return 0, errors.New("Update must contain a change")
if len(changes) == 0 {
return 0, database.NoChangesError
}
var builder database.StatementBuilder
@@ -107,7 +140,7 @@ func (i instance) Delete(ctx context.Context, id string) (int64, error) {
// SetName implements [domain.instanceChanges].
func (i instance) SetName(name string) database.Change {
return database.NewChange(i.NameColumn(), name)
return database.NewChange(i.NameColumn(false), name)
}
// -------------------------------------------------------------
@@ -116,12 +149,12 @@ func (i instance) SetName(name string) database.Change {
// IDCondition implements [domain.instanceConditions].
func (i instance) IDCondition(id string) database.Condition {
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
return database.NewTextCondition(i.IDColumn(true), database.TextOperationEqual, id)
}
// NameCondition implements [domain.instanceConditions].
func (i instance) NameCondition(op database.TextOperation, name string) database.Condition {
return database.NewTextCondition(i.NameColumn(), op, name)
return database.NewTextCondition(i.NameColumn(true), op, name)
}
// -------------------------------------------------------------
@@ -129,74 +162,122 @@ func (i instance) NameCondition(op database.TextOperation, name string) database
// -------------------------------------------------------------
// IDColumn implements [domain.instanceColumns].
func (instance) IDColumn() database.Column {
func (instance) IDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.id")
}
return database.NewColumn("id")
}
// NameColumn implements [domain.instanceColumns].
func (instance) NameColumn() database.Column {
func (instance) NameColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.name")
}
return database.NewColumn("name")
}
// CreatedAtColumn implements [domain.instanceColumns].
func (instance) CreatedAtColumn() database.Column {
func (instance) CreatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.created_at")
}
return database.NewColumn("created_at")
}
// DefaultOrgIdColumn implements [domain.instanceColumns].
func (instance) DefaultOrgIDColumn() database.Column {
func (instance) DefaultOrgIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.default_org_id")
}
return database.NewColumn("default_org_id")
}
// IAMProjectIDColumn implements [domain.instanceColumns].
func (instance) IAMProjectIDColumn() database.Column {
func (instance) IAMProjectIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.iam_project_id")
}
return database.NewColumn("iam_project_id")
}
// ConsoleClientIDColumn implements [domain.instanceColumns].
func (instance) ConsoleClientIDColumn() database.Column {
func (instance) ConsoleClientIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.console_client_id")
}
return database.NewColumn("console_client_id")
}
// ConsoleAppIDColumn implements [domain.instanceColumns].
func (instance) ConsoleAppIDColumn() database.Column {
func (instance) ConsoleAppIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.console_app_id")
}
return database.NewColumn("console_app_id")
}
// DefaultLanguageColumn implements [domain.instanceColumns].
func (instance) DefaultLanguageColumn() database.Column {
func (instance) DefaultLanguageColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.default_language")
}
return database.NewColumn("default_language")
}
// UpdatedAtColumn implements [domain.instanceColumns].
func (instance) UpdatedAtColumn() database.Column {
func (instance) UpdatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instances.updated_at")
}
return database.NewColumn("updated_at")
}
type rawInstance struct {
*domain.Instance
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
}
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
instance := new(domain.Instance)
if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil {
var instance rawInstance
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil {
return nil, err
}
return instance, nil
if len(instance.RawDomains) > 0 {
if err := json.Unmarshal(instance.RawDomains, &instance.Domains); err != nil {
return nil, err
}
}
return instance.Instance, nil
}
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) {
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
if err := rows.(database.CollectableRows).Collect(&instances); err != nil {
var rawInstances []*rawInstance
if err := rows.(database.CollectableRows).Collect(&rawInstances); err != nil {
return nil, err
}
instances := make([]*domain.Instance, len(rawInstances))
for i, instance := range rawInstances {
if len(instance.RawDomains) > 0 {
if err := json.Unmarshal(instance.RawDomains, &instance.Domains); err != nil {
return nil, err
}
}
instances[i] = instance.Instance
}
return instances, nil
}
@@ -205,8 +286,10 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
// -------------------------------------------------------------
// Domains implements [domain.InstanceRepository].
func (i *instance) Domains() domain.InstanceDomainRepository {
i.shouldJoinDomains = true
func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository {
if !i.shouldLoadDomains {
i.shouldLoadDomains = shouldLoad
}
if i.domainRepo != nil {
return i.domainRepo

View File

@@ -18,8 +18,8 @@ type instanceDomain struct {
// repository
// -------------------------------------------------------------
const queryInstanceDomainStmt = `SELECT instance_id, domain, is_verified, is_primary, verification_type, created_at, updated_at ` +
`FROM zitadel.instance_domains`
const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_domains.domain, instance_domains.is_verified, instance_domains.is_primary, instance_domains.validation_type, instance_domains.created_at, instance_domains.updated_at ` +
`FROM zitadel.instance_domains id`
// Get implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance.
@@ -55,7 +55,7 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption)
func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error {
var builder database.StatementBuilder
builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_verified, is_primary, verification_type) ` +
builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_verified, is_primary, validation_type) ` +
`VALUES ($1, $2, $3, $4, $5)` +
` RETURNING created_at, updated_at`)
@@ -91,19 +91,19 @@ func (i *instanceDomain) Update(ctx context.Context, condition database.Conditio
// changes
// -------------------------------------------------------------
// SetVerificationType implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetVerificationType(verificationType domain.DomainVerificationType) database.Change {
return database.NewChange(i.VerificationTypeColumn(), verificationType)
// SetValidationType implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change {
return database.NewChange(i.ValidationTypeColumn(false), verificationType)
}
// SetPrimary implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetPrimary() database.Change {
return database.NewChange(i.IsPrimaryColumn(), true)
return database.NewChange(i.IsPrimaryColumn(false), true)
}
// SetVerified implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetVerified() database.Change {
return database.NewChange(i.IsVerifiedColumn(), true)
return database.NewChange(i.IsVerifiedColumn(false), true)
}
// -------------------------------------------------------------
@@ -112,22 +112,22 @@ func (i instanceDomain) SetVerified() database.Change {
// DomainCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) DomainCondition(op database.TextOperation, domain string) database.Condition {
return database.NewTextCondition(i.DomainColumn(), op, domain)
return database.NewTextCondition(i.DomainColumn(true), op, domain)
}
// InstanceIDCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, instanceID)
return database.NewTextCondition(i.InstanceIDColumn(true), database.TextOperationEqual, instanceID)
}
// IsPrimaryCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) IsPrimaryCondition(isPrimary bool) database.Condition {
return database.NewBooleanCondition(i.IsPrimaryColumn(), isPrimary)
return database.NewBooleanCondition(i.IsPrimaryColumn(true), isPrimary)
}
// IsVerifiedCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) IsVerifiedCondition(isVerified bool) database.Condition {
return database.NewBooleanCondition(i.IsVerifiedColumn(), isVerified)
return database.NewBooleanCondition(i.IsVerifiedColumn(true), isVerified)
}
// -------------------------------------------------------------
@@ -136,43 +136,67 @@ func (i instanceDomain) IsVerifiedCondition(isVerified bool) database.Condition
// CreatedAtColumn implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance.
func (instanceDomain) CreatedAtColumn() database.Column {
func (instanceDomain) CreatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.created_at")
}
return database.NewColumn("created_at")
}
// DomainColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) DomainColumn() database.Column {
func (instanceDomain) DomainColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.domain")
}
return database.NewColumn("domain")
}
// InstanceIDColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) InstanceIDColumn() database.Column {
func (instanceDomain) InstanceIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.instance_id")
}
return database.NewColumn("instance_id")
}
// IsPrimaryColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsPrimaryColumn() database.Column {
func (instanceDomain) IsPrimaryColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.is_primary")
}
return database.NewColumn("is_primary")
}
// IsVerifiedColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsVerifiedColumn() database.Column {
func (instanceDomain) IsVerifiedColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.is_verified")
}
return database.NewColumn("is_verified")
}
// UpdatedAtColumn implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance.
func (instanceDomain) UpdatedAtColumn() database.Column {
func (instanceDomain) UpdatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.updated_at")
}
return database.NewColumn("updated_at")
}
// VerificationTypeColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) VerificationTypeColumn() database.Column {
return database.NewColumn("verification_type")
// ValidationTypeColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) ValidationTypeColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.validation_type")
}
return database.NewColumn("validation_type")
}
// IsGeneratedColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsGeneratedColumn() database.Column {
func (instanceDomain) IsGeneratedColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("instance_domains.is_generated")
}
return database.NewColumn("is_generated")
}

View File

@@ -171,7 +171,9 @@ func TestCreateInstance(t *testing.T) {
// check instance values
instance, err = instanceRepo.Get(ctx,
instance.ID,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
)
require.NoError(t, err)
@@ -290,7 +292,9 @@ func TestUpdateInstance(t *testing.T) {
// check instance values
instance, err = instanceRepo.Get(ctx,
instance.ID,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
)
require.Equal(t, tt.getErr, err)
@@ -356,7 +360,9 @@ func TestGetInstance(t *testing.T) {
// check instance values
returnedInstance, err := instanceRepo.Get(ctx,
instance.ID,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
)
if tt.err != nil {
require.ErrorIs(t, err, tt.err)
@@ -524,9 +530,15 @@ func TestListInstance(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
var condition database.Condition
if len(tt.conditionClauses) > 0 {
condition = database.And(tt.conditionClauses...)
}
// check instance values
returnedInstances, err := instanceRepo.List(ctx,
tt.conditionClauses...,
database.WithCondition(condition),
database.WithOrderBy(instanceRepo.CreatedAtColumn(true)),
)
require.NoError(t, err)
if tt.noInstanceReturned {
@@ -652,7 +664,9 @@ func TestDeleteInstance(t *testing.T) {
// check instance was deleted
instance, err := instanceRepo.Get(ctx,
tt.instanceID,
database.WithCondition(
instanceRepo.IDCondition(tt.instanceID),
),
)
require.ErrorIs(t, err, new(database.NoRowFoundError))
assert.Nil(t, instance)

View File

@@ -2,7 +2,7 @@ package repository
import (
"context"
"errors"
"encoding/json"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
@@ -15,9 +15,9 @@ import (
var _ domain.OrganizationRepository = (*org)(nil)
type org struct {
shouldJoinDomains bool
repository
domainRepo domain.OrganizationDomainRepository
shouldLoadDomains bool
domainRepo domain.OrganizationDomainRepository
}
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
@@ -28,39 +28,68 @@ func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRe
}
}
const queryOrganizationStmt = `SELECT id, name, instance_id, state, created_at, updated_at` +
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` +
` FROM zitadel.organizations`
// Get implements [domain.OrganizationRepository].
func (o *org) Get(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, conditions ...database.Condition) (*domain.Organization, error) {
builder := database.StatementBuilder{}
func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(true), o.IDColumn(true)),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt)
instanceIDCondition := o.InstanceIDCondition(instanceID)
conditions = append(conditions, id, instanceIDCondition)
writeCondition(&builder, database.And(conditions...))
options.Write(&builder)
return scanOrganization(ctx, o.client, &builder)
}
// List implements [domain.OrganizationRepository].
func (o *org) List(ctx context.Context, conditions ...database.Condition) ([]*domain.Organization, error) {
builder := database.StatementBuilder{}
func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(true), o.IDColumn(true)),
)
builder.WriteString(queryOrganizationStmt)
if conditions != nil {
writeCondition(&builder, database.And(conditions...))
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
orderBy := database.OrderBy(o.CreatedAtColumn())
orderBy.Write(&builder)
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
return scanOrganizations(ctx, o.client, &builder)
}
func (o *org) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 3)
columns = append(columns,
database.NewColumnCondition(o.InstanceIDColumn(true), o.Domains(false).InstanceIDColumn(true)),
database.NewColumnCondition(o.IDColumn(true), o.Domains(false).OrgIDColumn(true)),
)
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !o.shouldLoadDomains {
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn(true)))
}
return database.WithLeftJoin(
"zitadel.org_domains",
database.And(columns...),
)
}
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
` VALUES ($1, $2, $3, $4)` +
` RETURNING created_at, updated_at`
@@ -77,8 +106,8 @@ func (o *org) Create(ctx context.Context, organization *domain.Organization) err
// Update implements [domain.OrganizationRepository].
func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
if changes == nil {
return 0, errors.New("Update must contain a condition") // (otherwise ALL organizations will be updated)
if len(changes) == 0 {
return 0, database.NoChangesError
}
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE zitadel.organizations SET `)
@@ -115,12 +144,12 @@ func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, inst
// SetName implements [domain.organizationChanges].
func (o org) SetName(name string) database.Change {
return database.NewChange(o.NameColumn(), name)
return database.NewChange(o.NameColumn(false), name)
}
// SetState implements [domain.organizationChanges].
func (o org) SetState(state domain.OrgState) database.Change {
return database.NewChange(o.StateColumn(), state)
return database.NewChange(o.StateColumn(false), state)
}
// -------------------------------------------------------------
@@ -129,22 +158,22 @@ func (o org) SetState(state domain.OrgState) database.Change {
// IDCondition implements [domain.organizationConditions].
func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
return database.NewTextCondition(o.IDColumn(true), database.TextOperationEqual, id)
}
// NameCondition implements [domain.organizationConditions].
func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
return database.NewTextCondition(o.NameColumn(true), database.TextOperationEqual, name)
}
// InstanceIDCondition implements [domain.organizationConditions].
func (o org) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID)
return database.NewTextCondition(o.InstanceIDColumn(true), database.TextOperationEqual, instanceID)
}
// StateCondition implements [domain.organizationConditions].
func (o org) StateCondition(state domain.OrgState) database.Condition {
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state)
return database.NewTextCondition(o.StateColumn(true), database.TextOperationEqual, state)
}
// -------------------------------------------------------------
@@ -152,32 +181,50 @@ func (o org) StateCondition(state domain.OrgState) database.Condition {
// -------------------------------------------------------------
// IDColumn implements [domain.organizationColumns].
func (org) IDColumn() database.Column {
func (org) IDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.id")
}
return database.NewColumn("id")
}
// NameColumn implements [domain.organizationColumns].
func (org) NameColumn() database.Column {
func (org) NameColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.name")
}
return database.NewColumn("name")
}
// InstanceIDColumn implements [domain.organizationColumns].
func (org) InstanceIDColumn() database.Column {
func (org) InstanceIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.instance_id")
}
return database.NewColumn("instance_id")
}
// StateColumn implements [domain.organizationColumns].
func (org) StateColumn() database.Column {
func (org) StateColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.state")
}
return database.NewColumn("state")
}
// CreatedAtColumn implements [domain.organizationColumns].
func (org) CreatedAtColumn() database.Column {
func (org) CreatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.created_at")
}
return database.NewColumn("created_at")
}
// UpdatedAtColumn implements [domain.organizationColumns].
func (org) UpdatedAtColumn() database.Column {
func (org) UpdatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("organizations.updated_at")
}
return database.NewColumn("updated_at")
}
@@ -185,18 +232,28 @@ func (org) UpdatedAtColumn() database.Column {
// scanners
// -------------------------------------------------------------
type rawOrganization struct {
*domain.Organization
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
}
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
organization := &domain.Organization{}
if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil {
var org rawOrganization
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil {
return nil, err
}
if len(org.RawDomains) > 0 {
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
return nil, err
}
}
return organization, nil
return org.Organization, nil
}
func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) {
@@ -205,10 +262,20 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
return nil, err
}
organizations := []*domain.Organization{}
if err := rows.(database.CollectableRows).Collect(&organizations); err != nil {
var rawOrgs []*rawOrganization
if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil {
return nil, err
}
organizations := make([]*domain.Organization, len(rawOrgs))
for i, org := range rawOrgs {
if len(org.RawDomains) > 0 {
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
return nil, err
}
}
organizations[i] = org.Organization
}
return organizations, nil
}
@@ -216,8 +283,11 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
// sub repositories
// -------------------------------------------------------------
func (o *org) Domains() domain.OrganizationDomainRepository {
o.shouldJoinDomains = true
// Domains implements [domain.OrganizationRepository].
func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
if !o.shouldLoadDomains {
o.shouldLoadDomains = shouldLoad
}
if o.domainRepo != nil {
return o.domainRepo

View File

@@ -61,7 +61,7 @@ func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomai
`VALUES ($1, $2, $3, $4, $5, $6)` +
` RETURNING created_at, updated_at`)
builder.AppendArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.VerificationType)
builder.AppendArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType)
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
}
@@ -95,17 +95,17 @@ func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (i
// SetPrimary implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetPrimary() database.Change {
return database.NewChange(o.IsPrimaryColumn(), true)
return database.NewChange(o.IsPrimaryColumn(false), true)
}
// SetVerificationType implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetVerificationType(verificationType domain.DomainVerificationType) database.Change {
return database.NewChange(o.VerificationTypeColumn(), verificationType)
// SetValidationType implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change {
return database.NewChange(o.ValidationTypeColumn(false), verificationType)
}
// SetVerified implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetVerified() database.Change {
return database.NewChange(o.IsVerifiedColumn(), true)
return database.NewChange(o.IsVerifiedColumn(false), true)
}
// -------------------------------------------------------------
@@ -114,28 +114,28 @@ func (o orgDomain) SetVerified() database.Change {
// DomainCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) DomainCondition(op database.TextOperation, domain string) database.Condition {
return database.NewTextCondition(o.DomainColumn(), op, domain)
return database.NewTextCondition(o.DomainColumn(true), op, domain)
}
// InstanceIDCondition implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDCondition of orgDomain.org.
func (o orgDomain) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID)
return database.NewTextCondition(o.InstanceIDColumn(true), database.TextOperationEqual, instanceID)
}
// IsPrimaryCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) IsPrimaryCondition(isPrimary bool) database.Condition {
return database.NewBooleanCondition(o.IsPrimaryColumn(), isPrimary)
return database.NewBooleanCondition(o.IsPrimaryColumn(true), isPrimary)
}
// IsVerifiedCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) IsVerifiedCondition(isVerified bool) database.Condition {
return database.NewBooleanCondition(o.IsVerifiedColumn(), isVerified)
return database.NewBooleanCondition(o.IsVerifiedColumn(true), isVerified)
}
// OrgIDCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
return database.NewTextCondition(o.OrgIDColumn(), database.TextOperationEqual, orgID)
return database.NewTextCondition(o.OrgIDColumn(true), database.TextOperationEqual, orgID)
}
// -------------------------------------------------------------
@@ -144,45 +144,69 @@ func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
// CreatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
func (orgDomain) CreatedAtColumn() database.Column {
func (orgDomain) CreatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.created_at")
}
return database.NewColumn("created_at")
}
// DomainColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) DomainColumn() database.Column {
func (orgDomain) DomainColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.domain")
}
return database.NewColumn("domain")
}
// InstanceIDColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org.
func (orgDomain) InstanceIDColumn() database.Column {
func (orgDomain) InstanceIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.instance_id")
}
return database.NewColumn("instance_id")
}
// IsPrimaryColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) IsPrimaryColumn() database.Column {
func (orgDomain) IsPrimaryColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.is_primary")
}
return database.NewColumn("is_primary")
}
// IsVerifiedColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) IsVerifiedColumn() database.Column {
func (orgDomain) IsVerifiedColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.is_verified")
}
return database.NewColumn("is_verified")
}
// OrgIDColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) OrgIDColumn() database.Column {
func (orgDomain) OrgIDColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.org_id")
}
return database.NewColumn("org_id")
}
// UpdatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org.
func (orgDomain) UpdatedAtColumn() database.Column {
func (orgDomain) UpdatedAtColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.updated_at")
}
return database.NewColumn("updated_at")
}
// VerificationTypeColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) VerificationTypeColumn() database.Column {
return database.NewColumn("verification_type")
// ValidationTypeColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) ValidationTypeColumn(qualified bool) database.Column {
if qualified {
return database.NewColumn("org_domains.validation_type")
}
return database.NewColumn("validation_type")
}
// -------------------------------------------------------------

View File

@@ -235,8 +235,12 @@ func TestCreateOrganization(t *testing.T) {
// check organization values
organization, err = organizationRepo.Get(ctx,
organizationRepo.IDCondition(organization.ID),
organization.InstanceID,
database.WithCondition(
database.And(
organizationRepo.IDCondition(organization.ID),
organizationRepo.InstanceIDCondition(organization.InstanceID),
),
),
)
require.NoError(t, err)
@@ -389,8 +393,12 @@ func TestUpdateOrganization(t *testing.T) {
// check organization values
organization, err := organizationRepo.Get(ctx,
organizationRepo.IDCondition(createdOrg.ID),
createdOrg.InstanceID,
database.WithCondition(
database.And(
organizationRepo.IDCondition(createdOrg.ID),
organizationRepo.InstanceIDCondition(createdOrg.InstanceID),
),
),
)
require.NoError(t, err)
@@ -511,13 +519,18 @@ func TestGetOrganization(t *testing.T) {
// get org values
returnedOrg, err := orgRepo.Get(ctx,
tt.orgIdentifierCondition,
org.InstanceID,
database.WithCondition(
database.And(
tt.orgIdentifierCondition,
orgRepo.InstanceIDCondition(org.InstanceID),
),
),
)
if tt.err != nil {
require.ErrorIs(t, tt.err, err)
return
}
require.NoError(t, err)
if org.Name == "non existent org" {
assert.Nil(t, returnedOrg)
@@ -764,9 +777,15 @@ func TestListOrganization(t *testing.T) {
organizations := tt.testFunc(ctx, t)
var condition database.Condition
if len(tt.conditionClauses) > 0 {
condition = database.And(tt.conditionClauses...)
}
// check organization values
returnedOrgs, err := organizationRepo.List(ctx,
tt.conditionClauses...,
database.WithCondition(condition),
database.WithOrderBy(organizationRepo.CreatedAtColumn(true)),
)
require.NoError(t, err)
if tt.noOrganizationReturned {
@@ -929,8 +948,12 @@ func TestDeleteOrganization(t *testing.T) {
// check organization was deleted
organization, err := organizationRepo.Get(ctx,
tt.orgIdentifierCondition,
instanceId,
database.WithCondition(
database.And(
tt.orgIdentifierCondition,
organizationRepo.InstanceIDCondition(instanceId),
),
),
)
require.ErrorIs(t, err, new(database.NoRowFoundError))
assert.Nil(t, organization)