mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-23 23:46:42 +00:00
This pull request introduces a significant refactoring of the database interaction layer, focusing on improving explicitness, transactional control, and error handling. The core change is the removal of the stateful `QueryExecutor` from repository instances. Instead, it is now passed as an argument to each method that interacts with the database. This change makes transaction management more explicit and flexible, as the same repository instance can be used with a database pool or a specific transaction without needing to be re-instantiated. ### Key Changes - **Explicit `QueryExecutor` Passing:** - All repository methods (`Get`, `List`, `Create`, `Update`, `Delete`, etc.) in `InstanceRepository`, `OrganizationRepository`, `UserRepository`, and their sub-repositories now require a `database.QueryExecutor` (e.g., a `*pgxpool.Pool` or `pgx.Tx`) as the first argument. - Repository constructors no longer accept a `QueryExecutor`. For example, `repository.InstanceRepository(pool)` is now `repository.InstanceRepository()`. - **Enhanced Error Handling:** - A new `database.MissingConditionError` is introduced to enforce required query conditions, such as ensuring an `instance_id` is always present in `UPDATE` and `DELETE` operations. - The database error wrapper in the `postgres` package now correctly identifies and wraps `pgx.ErrTooManyRows` and similar errors from the `scany` library into a `database.MultipleRowsFoundError`. - **Improved Database Conditions:** - The `database.Condition` interface now includes a `ContainsColumn(Column) bool` method. This allows for runtime checks to ensure that critical filters (like `instance_id`) are included in a query, preventing accidental cross-tenant data modification. - A new `database.Exists()` condition has been added to support `EXISTS` subqueries, enabling more complex filtering logic, such as finding an organization that has a specific domain. 
- **Repository and Interface Refactoring:** - The method for loading related entities (e.g., domains for an organization) has been changed from a boolean flag (`Domains(true)`) to a more explicit, chainable method (`LoadDomains()`). This returns a new repository instance configured to load the sub-resource, promoting immutability. - The custom `OrgIdentifierCondition` has been removed in favor of using the standard `database.Condition` interface, simplifying the API. - **Code Cleanup and Test Updates:** - Unnecessary struct embeddings and metadata have been removed. - All integration and repository tests have been updated to reflect the new method signatures, passing the database pool or transaction object explicitly. - New tests have been added to cover the new `ExistsDomain` functionality and other enhancements. These changes make the data access layer more robust, predictable, and easier to work with, especially in the context of database transactions.
321 lines
11 KiB
Go
321 lines
11 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"golang.org/x/text/language"
|
|
|
|
"github.com/zitadel/zitadel/backend/v3/domain"
|
|
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
|
)
|
|
|
|
// Compile-time assertion that *instance satisfies domain.InstanceRepository.
var (
	_ domain.InstanceRepository = (*instance)(nil)
)
|
|
|
|
type instance struct {
|
|
shouldLoadDomains bool
|
|
domainRepo instanceDomain
|
|
}
|
|
|
|
func InstanceRepository() domain.InstanceRepository {
|
|
return new(instance)
|
|
}
|
|
|
|
func (instance) qualifiedTableName() string {
|
|
return "zitadel.instances"
|
|
}
|
|
|
|
func (instance) unqualifiedTableName() string {
|
|
return "instances"
|
|
}
|
|
|
|
// -------------------------------------------------------------
|
|
// repository
|
|
// -------------------------------------------------------------
|
|
|
|
const (
	// queryInstanceStmt selects the instance columns plus a JSON array of the
	// instance's domains, aggregated with jsonb_agg. The FILTER clause skips
	// rows where no domain was joined, so the aggregate is NULL for an
	// instance without joined domains. The statement ends at the FROM clause;
	// joins, conditions and GROUP BY are appended by query options.
	queryInstanceStmt = `SELECT instances.id` +
		`, instances.name` +
		`, instances.default_org_id` +
		`, instances.iam_project_id` +
		`, instances.console_client_id` +
		`, instances.console_app_id` +
		`, instances.default_language` +
		`, instances.created_at` +
		`, instances.updated_at` +
		` , jsonb_agg(json_build_object(` +
		`'domain', instance_domains.domain` +
		`, 'isPrimary', instance_domains.is_primary` +
		`, 'isGenerated', instance_domains.is_generated` +
		`, 'createdAt', instance_domains.created_at` +
		`, 'updatedAt', instance_domains.updated_at` +
		`)) FILTER (WHERE instance_domains.instance_id IS NOT NULL) AS domains` +
		` FROM zitadel.instances`
)
|
|
|
|
// Get implements [domain.InstanceRepository].
|
|
func (i instance) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Instance, error) {
|
|
opts = append(opts,
|
|
i.joinDomains(),
|
|
database.WithGroupBy(i.IDColumn()),
|
|
)
|
|
|
|
options := new(database.QueryOpts)
|
|
for _, opt := range opts {
|
|
opt(options)
|
|
}
|
|
|
|
var builder database.StatementBuilder
|
|
builder.WriteString(queryInstanceStmt)
|
|
options.Write(&builder)
|
|
|
|
return scanInstance(ctx, client, &builder)
|
|
}
|
|
|
|
// List implements [domain.InstanceRepository].
|
|
func (i instance) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Instance, error) {
|
|
opts = append(opts,
|
|
i.joinDomains(),
|
|
database.WithGroupBy(i.IDColumn()),
|
|
)
|
|
|
|
options := new(database.QueryOpts)
|
|
for _, opt := range opts {
|
|
opt(options)
|
|
}
|
|
|
|
var builder database.StatementBuilder
|
|
builder.WriteString(queryInstanceStmt)
|
|
options.Write(&builder)
|
|
|
|
return scanInstances(ctx, client, &builder)
|
|
}
|
|
|
|
// Create implements [domain.InstanceRepository].
|
|
func (i instance) Create(ctx context.Context, client database.QueryExecutor, instance *domain.Instance) error {
|
|
var (
|
|
builder database.StatementBuilder
|
|
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
|
)
|
|
if !instance.CreatedAt.IsZero() {
|
|
createdAt = instance.CreatedAt
|
|
}
|
|
if !instance.UpdatedAt.IsZero() {
|
|
updatedAt = instance.UpdatedAt
|
|
}
|
|
|
|
builder.WriteString(`INSERT INTO `)
|
|
builder.WriteString(i.qualifiedTableName())
|
|
builder.WriteString(` (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
|
|
builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt)
|
|
builder.WriteString(`) RETURNING created_at, updated_at`)
|
|
|
|
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
|
|
}
|
|
|
|
// Update implements [domain.InstanceRepository].
|
|
func (i instance) Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error) {
|
|
if len(changes) == 0 {
|
|
return 0, database.ErrNoChanges
|
|
}
|
|
if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
|
|
changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
|
|
}
|
|
|
|
var builder database.StatementBuilder
|
|
builder.WriteString(`UPDATE zitadel.instances SET `)
|
|
database.Changes(changes).Write(&builder)
|
|
idCondition := i.IDCondition(id)
|
|
writeCondition(&builder, idCondition)
|
|
|
|
stmt := builder.String()
|
|
|
|
return client.Exec(ctx, stmt, builder.Args()...)
|
|
}
|
|
|
|
// Delete implements [domain.InstanceRepository].
|
|
func (i instance) Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error) {
|
|
var builder database.StatementBuilder
|
|
|
|
builder.WriteString(`DELETE FROM `)
|
|
builder.WriteString(i.qualifiedTableName())
|
|
|
|
idCondition := i.IDCondition(id)
|
|
writeCondition(&builder, idCondition)
|
|
|
|
return client.Exec(ctx, builder.String(), builder.Args()...)
|
|
}
|
|
|
|
// -------------------------------------------------------------
|
|
// changes
|
|
// -------------------------------------------------------------
|
|
|
|
// SetName implements [domain.instanceChanges].
|
|
func (i instance) SetName(name string) database.Change {
|
|
return database.NewChange(i.NameColumn(), name)
|
|
}
|
|
|
|
// SetUpdatedAt implements [domain.instanceChanges].
|
|
func (i instance) SetUpdatedAt(time time.Time) database.Change {
|
|
return database.NewChange(i.UpdatedAtColumn(), time)
|
|
}
|
|
|
|
func (i instance) SetIAMProject(id string) database.Change {
|
|
return database.NewChange(i.IAMProjectIDColumn(), id)
|
|
}
|
|
func (i instance) SetDefaultOrg(id string) database.Change {
|
|
return database.NewChange(i.DefaultOrgIDColumn(), id)
|
|
}
|
|
func (i instance) SetDefaultLanguage(lang language.Tag) database.Change {
|
|
return database.NewChange(i.DefaultLanguageColumn(), lang.String())
|
|
}
|
|
func (i instance) SetConsoleClientID(id string) database.Change {
|
|
return database.NewChange(i.ConsoleClientIDColumn(), id)
|
|
}
|
|
func (i instance) SetConsoleAppID(id string) database.Change {
|
|
return database.NewChange(i.ConsoleAppIDColumn(), id)
|
|
}
|
|
|
|
// -------------------------------------------------------------
|
|
// conditions
|
|
// -------------------------------------------------------------
|
|
|
|
// IDCondition implements [domain.instanceConditions].
|
|
func (i instance) IDCondition(id string) database.Condition {
|
|
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
|
|
}
|
|
|
|
// NameCondition implements [domain.instanceConditions].
|
|
func (i instance) NameCondition(op database.TextOperation, name string) database.Condition {
|
|
return database.NewTextCondition(i.NameColumn(), op, name)
|
|
}
|
|
|
|
// ExistsDomain creates a correlated [database.Exists] condition on instance_domains.
|
|
// Use this filter to make sure the Instance returned contains a specific domain.
|
|
// of the instance in the aggregated result.
|
|
// Example usage:
|
|
//
|
|
// domainRepo := instanceRepo.Domains(true) // ensure domains are loaded/aggregated
|
|
// instance, _ := instanceRepo.Get(ctx,
|
|
// database.WithCondition(
|
|
// database.And(
|
|
// instanceRepo.InstanceIDCondition(instanceID),
|
|
// instanceRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
|
|
// ),
|
|
// ),
|
|
// )
|
|
func (i instance) ExistsDomain(cond database.Condition) database.Condition {
|
|
return database.Exists(
|
|
i.domainRepo.qualifiedTableName(),
|
|
database.And(
|
|
database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()),
|
|
cond,
|
|
),
|
|
)
|
|
}
|
|
|
|
// -------------------------------------------------------------
|
|
// columns
|
|
// -------------------------------------------------------------
|
|
|
|
// IDColumn implements [domain.instanceColumns].
|
|
func (i instance) IDColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "id")
|
|
}
|
|
|
|
// NameColumn implements [domain.instanceColumns].
|
|
func (i instance) NameColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "name")
|
|
}
|
|
|
|
// CreatedAtColumn implements [domain.instanceColumns].
|
|
func (i instance) CreatedAtColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "created_at")
|
|
}
|
|
|
|
// DefaultOrgIdColumn implements [domain.instanceColumns].
|
|
func (i instance) DefaultOrgIDColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "default_org_id")
|
|
}
|
|
|
|
// IAMProjectIDColumn implements [domain.instanceColumns].
|
|
func (i instance) IAMProjectIDColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "iam_project_id")
|
|
}
|
|
|
|
// ConsoleClientIDColumn implements [domain.instanceColumns].
|
|
func (i instance) ConsoleClientIDColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "console_client_id")
|
|
}
|
|
|
|
// ConsoleAppIDColumn implements [domain.instanceColumns].
|
|
func (i instance) ConsoleAppIDColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "console_app_id")
|
|
}
|
|
|
|
// DefaultLanguageColumn implements [domain.instanceColumns].
|
|
func (i instance) DefaultLanguageColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "default_language")
|
|
}
|
|
|
|
// UpdatedAtColumn implements [domain.instanceColumns].
|
|
func (i instance) UpdatedAtColumn() database.Column {
|
|
return database.NewColumn(i.unqualifiedTableName(), "updated_at")
|
|
}
|
|
|
|
// -------------------------------------------------------------
|
|
// scanners
|
|
// -------------------------------------------------------------
|
|
|
|
type rawInstance struct {
|
|
*domain.Instance
|
|
Domains JSONArray[domain.InstanceDomain] `json:"domains,omitempty" db:"domains"`
|
|
}
|
|
|
|
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
|
|
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var instance rawInstance
|
|
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil {
|
|
return nil, err
|
|
}
|
|
instance.Instance.Domains = instance.Domains
|
|
|
|
return instance.Instance, nil
|
|
}
|
|
|
|
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Instance, error) {
|
|
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var instances []*rawInstance
|
|
if err := rows.(database.CollectableRows).Collect(&instances); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]*domain.Instance, len(instances))
|
|
for i, inst := range instances {
|
|
result[i] = inst.Instance
|
|
result[i].Domains = inst.Domains
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// -------------------------------------------------------------
|
|
// sub repositories
|
|
// -------------------------------------------------------------
|
|
|
|
func (i *instance) LoadDomains() domain.InstanceRepository {
|
|
return &instance{
|
|
shouldLoadDomains: true,
|
|
}
|
|
}
|
|
|
|
func (i *instance) joinDomains() database.QueryOption {
|
|
columns := make([]database.Condition, 0, 2)
|
|
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()))
|
|
|
|
// If domains should not be joined, we make sure to return null for the domain columns
|
|
// the query optimizer of the dialect should optimize this away if no domains are requested
|
|
if !i.shouldLoadDomains {
|
|
columns = append(columns, database.IsNull(i.domainRepo.InstanceIDColumn()))
|
|
}
|
|
|
|
return database.WithLeftJoin(
|
|
i.domainRepo.qualifiedTableName(),
|
|
database.And(columns...),
|
|
)
|
|
}
|