mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 16:27:32 +00:00
implementation done
This commit is contained in:
@@ -2,7 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
var _ domain.OrganizationRepository = (*org)(nil)
|
||||
|
||||
type org struct {
|
||||
shouldJoinDomains bool
|
||||
repository
|
||||
domainRepo domain.OrganizationDomainRepository
|
||||
shouldLoadDomains bool
|
||||
domainRepo domain.OrganizationDomainRepository
|
||||
}
|
||||
|
||||
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
|
||||
@@ -28,39 +28,68 @@ func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRe
|
||||
}
|
||||
}
|
||||
|
||||
const queryOrganizationStmt = `SELECT id, name, instance_id, state, created_at, updated_at` +
|
||||
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
|
||||
` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` +
|
||||
` FROM zitadel.organizations`
|
||||
|
||||
// Get implements [domain.OrganizationRepository].
|
||||
func (o *org) Get(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, conditions ...database.Condition) (*domain.Organization, error) {
|
||||
builder := database.StatementBuilder{}
|
||||
func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
|
||||
opts = append(opts,
|
||||
o.joinDomains(),
|
||||
database.WithGroupBy(o.InstanceIDColumn(true), o.IDColumn(true)),
|
||||
)
|
||||
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationStmt)
|
||||
|
||||
instanceIDCondition := o.InstanceIDCondition(instanceID)
|
||||
|
||||
conditions = append(conditions, id, instanceIDCondition)
|
||||
writeCondition(&builder, database.And(conditions...))
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganization(ctx, o.client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.OrganizationRepository].
|
||||
func (o *org) List(ctx context.Context, conditions ...database.Condition) ([]*domain.Organization, error) {
|
||||
builder := database.StatementBuilder{}
|
||||
func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
|
||||
opts = append(opts,
|
||||
o.joinDomains(),
|
||||
database.WithGroupBy(o.InstanceIDColumn(true), o.IDColumn(true)),
|
||||
)
|
||||
|
||||
builder.WriteString(queryOrganizationStmt)
|
||||
|
||||
if conditions != nil {
|
||||
writeCondition(&builder, database.And(conditions...))
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
orderBy := database.OrderBy(o.CreatedAtColumn())
|
||||
orderBy.Write(&builder)
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganizations(ctx, o.client, &builder)
|
||||
}
|
||||
|
||||
func (o *org) joinDomains() database.QueryOption {
|
||||
columns := make([]database.Condition, 0, 3)
|
||||
columns = append(columns,
|
||||
database.NewColumnCondition(o.InstanceIDColumn(true), o.Domains(false).InstanceIDColumn(true)),
|
||||
database.NewColumnCondition(o.IDColumn(true), o.Domains(false).OrgIDColumn(true)),
|
||||
)
|
||||
|
||||
// If domains should not be joined, we make sure to return null for the domain columns
|
||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||
if !o.shouldLoadDomains {
|
||||
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn(true)))
|
||||
}
|
||||
|
||||
return database.WithLeftJoin(
|
||||
"zitadel.org_domains",
|
||||
database.And(columns...),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
|
||||
` VALUES ($1, $2, $3, $4)` +
|
||||
` RETURNING created_at, updated_at`
|
||||
@@ -77,8 +106,8 @@ func (o *org) Create(ctx context.Context, organization *domain.Organization) err
|
||||
|
||||
// Update implements [domain.OrganizationRepository].
|
||||
func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
|
||||
if changes == nil {
|
||||
return 0, errors.New("Update must contain a condition") // (otherwise ALL organizations will be updated)
|
||||
if len(changes) == 0 {
|
||||
return 0, database.NoChangesError
|
||||
}
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(`UPDATE zitadel.organizations SET `)
|
||||
@@ -115,12 +144,12 @@ func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, inst
|
||||
|
||||
// SetName implements [domain.organizationChanges].
|
||||
func (o org) SetName(name string) database.Change {
|
||||
return database.NewChange(o.NameColumn(), name)
|
||||
return database.NewChange(o.NameColumn(false), name)
|
||||
}
|
||||
|
||||
// SetState implements [domain.organizationChanges].
|
||||
func (o org) SetState(state domain.OrgState) database.Change {
|
||||
return database.NewChange(o.StateColumn(), state)
|
||||
return database.NewChange(o.StateColumn(false), state)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -129,22 +158,22 @@ func (o org) SetState(state domain.OrgState) database.Change {
|
||||
|
||||
// IDCondition implements [domain.organizationConditions].
|
||||
func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
|
||||
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
|
||||
return database.NewTextCondition(o.IDColumn(true), database.TextOperationEqual, id)
|
||||
}
|
||||
|
||||
// NameCondition implements [domain.organizationConditions].
|
||||
func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
|
||||
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
|
||||
return database.NewTextCondition(o.NameColumn(true), database.TextOperationEqual, name)
|
||||
}
|
||||
|
||||
// InstanceIDCondition implements [domain.organizationConditions].
|
||||
func (o org) InstanceIDCondition(instanceID string) database.Condition {
|
||||
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID)
|
||||
return database.NewTextCondition(o.InstanceIDColumn(true), database.TextOperationEqual, instanceID)
|
||||
}
|
||||
|
||||
// StateCondition implements [domain.organizationConditions].
|
||||
func (o org) StateCondition(state domain.OrgState) database.Condition {
|
||||
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state)
|
||||
return database.NewTextCondition(o.StateColumn(true), database.TextOperationEqual, state)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -152,32 +181,50 @@ func (o org) StateCondition(state domain.OrgState) database.Condition {
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// IDColumn implements [domain.organizationColumns].
|
||||
func (org) IDColumn() database.Column {
|
||||
func (org) IDColumn(qualified bool) database.Column {
|
||||
if qualified {
|
||||
return database.NewColumn("organizations.id")
|
||||
}
|
||||
return database.NewColumn("id")
|
||||
}
|
||||
|
||||
// NameColumn implements [domain.organizationColumns].
|
||||
func (org) NameColumn() database.Column {
|
||||
func (org) NameColumn(qualified bool) database.Column {
|
||||
if qualified {
|
||||
return database.NewColumn("organizations.name")
|
||||
}
|
||||
return database.NewColumn("name")
|
||||
}
|
||||
|
||||
// InstanceIDColumn implements [domain.organizationColumns].
|
||||
func (org) InstanceIDColumn() database.Column {
|
||||
func (org) InstanceIDColumn(qualified bool) database.Column {
|
||||
if qualified {
|
||||
return database.NewColumn("organizations.instance_id")
|
||||
}
|
||||
return database.NewColumn("instance_id")
|
||||
}
|
||||
|
||||
// StateColumn implements [domain.organizationColumns].
|
||||
func (org) StateColumn() database.Column {
|
||||
func (org) StateColumn(qualified bool) database.Column {
|
||||
if qualified {
|
||||
return database.NewColumn("organizations.state")
|
||||
}
|
||||
return database.NewColumn("state")
|
||||
}
|
||||
|
||||
// CreatedAtColumn implements [domain.organizationColumns].
|
||||
func (org) CreatedAtColumn() database.Column {
|
||||
func (org) CreatedAtColumn(qualified bool) database.Column {
|
||||
if qualified {
|
||||
return database.NewColumn("organizations.created_at")
|
||||
}
|
||||
return database.NewColumn("created_at")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.organizationColumns].
|
||||
func (org) UpdatedAtColumn() database.Column {
|
||||
func (org) UpdatedAtColumn(qualified bool) database.Column {
|
||||
if qualified {
|
||||
return database.NewColumn("organizations.updated_at")
|
||||
}
|
||||
return database.NewColumn("updated_at")
|
||||
}
|
||||
|
||||
@@ -185,18 +232,28 @@ func (org) UpdatedAtColumn() database.Column {
|
||||
// scanners
|
||||
// -------------------------------------------------------------
|
||||
|
||||
type rawOrganization struct {
|
||||
*domain.Organization
|
||||
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
|
||||
}
|
||||
|
||||
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
|
||||
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
organization := &domain.Organization{}
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil {
|
||||
var org rawOrganization
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(org.RawDomains) > 0 {
|
||||
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return organization, nil
|
||||
return org.Organization, nil
|
||||
}
|
||||
|
||||
func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) {
|
||||
@@ -205,10 +262,20 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
|
||||
return nil, err
|
||||
}
|
||||
|
||||
organizations := []*domain.Organization{}
|
||||
if err := rows.(database.CollectableRows).Collect(&organizations); err != nil {
|
||||
var rawOrgs []*rawOrganization
|
||||
if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
organizations := make([]*domain.Organization, len(rawOrgs))
|
||||
for i, org := range rawOrgs {
|
||||
if len(org.RawDomains) > 0 {
|
||||
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
organizations[i] = org.Organization
|
||||
}
|
||||
return organizations, nil
|
||||
}
|
||||
|
||||
@@ -216,8 +283,11 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
|
||||
// sub repositories
|
||||
// -------------------------------------------------------------
|
||||
|
||||
func (o *org) Domains() domain.OrganizationDomainRepository {
|
||||
o.shouldJoinDomains = true
|
||||
// Domains implements [domain.OrganizationRepository].
|
||||
func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
|
||||
if !o.shouldLoadDomains {
|
||||
o.shouldLoadDomains = shouldLoad
|
||||
}
|
||||
|
||||
if o.domainRepo != nil {
|
||||
return o.domainRepo
|
||||
|
Reference in New Issue
Block a user