mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-24 02:48:09 +00:00
fix(repo): correct mapping for domains (#10653)
This pull request fixes an issue where the repository would fail to scan organization or instance structs if the `domains` column was `NULL`. ## Which problems are solved If the `domains` column of `orgs` or `instances` was `NULL`, the repository failed scanning into the structs. This happened because the scanning mechanism did not correctly handle `NULL` JSONB columns. ## How the problems are solved A new generic type `JSONArray[T]` is introduced, which implements the `sql.Scanner` interface. This type can correctly scan JSON arrays from the database, including handling `NULL` values gracefully. The repositories for instances and organizations have been updated to use this new type for the domains field. The SQL queries have also been improved to use `FILTER` with `jsonb_agg` for better readability and performance when aggregating domains. ## Additional changes * An unnecessary cleanup step in the organization domain tests for already removed domains has been removed. * The `pgxscan` library has been replaced with `sqlscan` for scanning `database/sql`.Rows. * Minor cleanups in integration tests.
This commit is contained in:
3
.github/workflows/core.yml
vendored
3
.github/workflows/core.yml
vendored
@@ -17,6 +17,7 @@ on:
|
||||
|
||||
env:
|
||||
cache_path: |
|
||||
backend
|
||||
internal/statik/statik.go
|
||||
internal/notification/statik/statik.go
|
||||
internal/api/ui/login/static/resources/themes/zitadel/css/zitadel.css*
|
||||
@@ -43,7 +44,7 @@ jobs:
|
||||
continue-on-error: true
|
||||
id: cache
|
||||
with:
|
||||
key: core-${{ hashFiles( 'go.*', 'openapi', 'cmd', 'pkg/grpc/**/*.go', 'proto', 'internal') }}
|
||||
key: core-${{ hashFiles( 'go.*', 'openapi', 'cmd', 'pkg/grpc/**/*.go', 'proto', 'internal', 'backend/**') }}
|
||||
restore-keys: |
|
||||
core-
|
||||
path: ${{ env.cache_path }}
|
||||
|
||||
@@ -3,7 +3,7 @@ package sql
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
pgxscan "github.com/georgysavva/scany/v2/dbscan"
|
||||
"github.com/georgysavva/scany/v2/sqlscan"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
@@ -44,7 +44,7 @@ func (r *Rows) Collect(dest any) (err error) {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanAll(dest, r.Rows))
|
||||
return wrapError(sqlscan.ScanAll(dest, r.Rows))
|
||||
}
|
||||
|
||||
// CollectFirst implements [database.CollectableRows].
|
||||
@@ -56,7 +56,7 @@ func (r *Rows) CollectFirst(dest any) (err error) {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanRow(dest, r.Rows))
|
||||
return wrapError(sqlscan.ScanRow(dest, r.Rows))
|
||||
}
|
||||
|
||||
// CollectExactlyOneRow implements [database.CollectableRows].
|
||||
@@ -68,7 +68,7 @@ func (r *Rows) CollectExactlyOneRow(dest any) (err error) {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanOne(dest, r.Rows))
|
||||
return wrapError(sqlscan.ScanOne(dest, r.Rows))
|
||||
}
|
||||
|
||||
// Close implements [database.Rows].
|
||||
|
||||
@@ -35,11 +35,11 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for instance to be created
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.Instance.Id)),
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.ID())),
|
||||
)
|
||||
assert.NoError(ttt, err)
|
||||
assert.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
t.Run("test instance custom domain add reduces", func(t *testing.T) {
|
||||
@@ -65,7 +65,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -75,13 +75,13 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event instance.domain.added
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.Equal(ttt, instance.Instance.Id, domain.InstanceID)
|
||||
assert.False(ttt, *domain.IsPrimary)
|
||||
assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
assert.Equal(t, domainName, domain.Domain)
|
||||
assert.Equal(t, instance.Instance.Id, domain.InstanceID)
|
||||
assert.False(t, *domain.IsPrimary)
|
||||
assert.WithinRange(t, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(t, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for domain to be created
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -134,9 +134,9 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.False(ttt, *domain.IsPrimary)
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
require.NoError(t, err)
|
||||
require.False(t, *domain.IsPrimary)
|
||||
assert.Equal(t, domainName, domain.Domain)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Set domain as primary
|
||||
@@ -150,7 +150,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that set primary reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -160,11 +160,11 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event instance.domain.primary.set
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.True(ttt, *domain.IsPrimary)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeSetPrimary, afterSetPrimary)
|
||||
assert.Equal(t, domainName, domain.Domain)
|
||||
assert.True(t, *domain.IsPrimary)
|
||||
assert.WithinRange(t, domain.UpdatedAt, beforeSetPrimary, afterSetPrimary)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
@@ -179,7 +179,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for domain to be created and verify it exists
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -189,7 +189,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Remove the domain
|
||||
@@ -201,7 +201,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -212,8 +212,8 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
assert.Nil(t, domain)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -250,12 +250,12 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event instance.domain.added
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.Equal(ttt, instance.Instance.Id, domain.InstanceID)
|
||||
assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
assert.Equal(t, domainName, domain.Domain)
|
||||
assert.Equal(t, instance.Instance.Id, domain.InstanceID)
|
||||
assert.WithinRange(t, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(t, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
@@ -270,7 +270,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for domain to be created and verify it exists
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -280,7 +280,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Remove the domain
|
||||
@@ -292,7 +292,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -303,8 +303,8 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
assert.Nil(t, domain)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,13 +45,13 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
})
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event instance.added
|
||||
assert.Equal(ttt, instanceName, instance.Name)
|
||||
assert.Equal(t, instanceName, instance.Name)
|
||||
// event instance.default.org.set
|
||||
assert.NotNil(t, instance.DefaultOrgID)
|
||||
// event instance.iam.project.set
|
||||
|
||||
@@ -36,11 +36,11 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for org to be created
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(orgRepo.IDCondition(org.GetId())),
|
||||
)
|
||||
assert.NoError(ttt, err)
|
||||
assert.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// The API call also sets the domain as primary, so we don't do a separate test for that.
|
||||
@@ -67,7 +67,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
gottenDomain, err := orgDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -77,7 +77,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event org.domain.added
|
||||
assert.Equal(t, domainName, gottenDomain.Domain)
|
||||
assert.Equal(t, Instance.Instance.Id, gottenDomain.InstanceID)
|
||||
@@ -97,16 +97,6 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
Domain: domainName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete domain on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Remove the domain
|
||||
_, err = OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
@@ -116,7 +106,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := orgDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -126,8 +116,8 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
assert.Nil(t, domain)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
}
|
||||
|
||||
21
backend/v3/storage/database/repository/array.go
Normal file
21
backend/v3/storage/database/repository/array.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type JSONArray[T any] []*T
|
||||
|
||||
func (a JSONArray[T]) Scan(src any) error {
|
||||
switch s := src.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(s), &a)
|
||||
case []byte:
|
||||
return json.Unmarshal(s, &a)
|
||||
case nil:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("unsupported scan source")
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
@@ -34,7 +32,7 @@ func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository
|
||||
|
||||
const (
|
||||
queryInstanceStmt = `SELECT instances.id, instances.name, instances.default_org_id, instances.iam_project_id, instances.console_client_id, instances.console_app_id, instances.default_language, instances.created_at, instances.updated_at` +
|
||||
` , CASE WHEN count(instance_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', instance_domains.domain, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) ELSE NULL::JSONB END domains` +
|
||||
` , jsonb_agg(json_build_object('domain', instance_domains.domain, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) FILTER (WHERE instance_domains.instance_id IS NOT NULL) AS domains` +
|
||||
` FROM zitadel.instances`
|
||||
)
|
||||
|
||||
@@ -236,9 +234,13 @@ func (instance) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("instances", "updated_at")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// scanners
|
||||
// -------------------------------------------------------------
|
||||
|
||||
type rawInstance struct {
|
||||
*domain.Instance
|
||||
RawDomains sql.Null[json.RawMessage] `json:"domains,omitzero" db:"domains"`
|
||||
Domains JSONArray[domain.InstanceDomain] `json:"domains,omitempty" db:"domains"`
|
||||
}
|
||||
|
||||
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
|
||||
@@ -251,12 +253,7 @@ func scanInstance(ctx context.Context, querier database.Querier, builder *databa
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if instance.RawDomains.Valid {
|
||||
if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
instance.Instance.Domains = instance.Domains
|
||||
|
||||
return instance.Instance, nil
|
||||
}
|
||||
@@ -267,21 +264,18 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawInstances []*rawInstance
|
||||
if err := rows.(database.CollectableRows).Collect(&rawInstances); err != nil {
|
||||
var instances []*rawInstance
|
||||
if err := rows.(database.CollectableRows).Collect(&instances); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
instances := make([]*domain.Instance, len(rawInstances))
|
||||
for i, instance := range rawInstances {
|
||||
if instance.RawDomains.Valid {
|
||||
if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
instances[i] = instance.Instance
|
||||
result := make([]*domain.Instance, len(instances))
|
||||
for i, inst := range instances {
|
||||
result[i] = inst.Instance
|
||||
result[i].Domains = inst.Domains
|
||||
}
|
||||
return instances, nil
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
|
||||
@@ -2,7 +2,6 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
@@ -29,7 +28,7 @@ func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRe
|
||||
}
|
||||
|
||||
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
|
||||
` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` +
|
||||
` , jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
|
||||
` FROM zitadel.organizations`
|
||||
|
||||
// Get implements [domain.OrganizationRepository].
|
||||
@@ -212,9 +211,9 @@ func (org) UpdatedAtColumn() database.Column {
|
||||
// scanners
|
||||
// -------------------------------------------------------------
|
||||
|
||||
type rawOrganization struct {
|
||||
type rawOrg struct {
|
||||
*domain.Organization
|
||||
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
|
||||
Domains JSONArray[domain.OrganizationDomain] `json:"domains,omitempty" db:"domains"`
|
||||
}
|
||||
|
||||
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
|
||||
@@ -223,15 +222,11 @@ func scanOrganization(ctx context.Context, querier database.Querier, builder *da
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var org rawOrganization
|
||||
var org rawOrg
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(org.RawDomains) > 0 {
|
||||
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
org.Organization.Domains = org.Domains
|
||||
|
||||
return org.Organization, nil
|
||||
}
|
||||
@@ -242,21 +237,18 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawOrgs []*rawOrganization
|
||||
if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil {
|
||||
var orgs []*rawOrg
|
||||
if err := rows.(database.CollectableRows).Collect(&orgs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
organizations := make([]*domain.Organization, len(rawOrgs))
|
||||
for i, org := range rawOrgs {
|
||||
if len(org.RawDomains) > 0 {
|
||||
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
organizations[i] = org.Organization
|
||||
result := make([]*domain.Organization, len(orgs))
|
||||
for i, org := range orgs {
|
||||
result[i] = org.Organization
|
||||
result[i].Domains = org.Domains
|
||||
}
|
||||
return organizations, nil
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
|
||||
@@ -646,7 +646,7 @@ func TestServer_ListOrganizations(t *testing.T) {
|
||||
Filter: tt.query,
|
||||
})
|
||||
if tt.err != nil {
|
||||
require.ErrorContains(t, err, tt.err.Error())
|
||||
require.ErrorContains(ttt, err, tt.err.Error())
|
||||
return
|
||||
}
|
||||
require.NoError(ttt, err)
|
||||
@@ -828,7 +828,9 @@ func TestServer_ActivateOrganization(t *testing.T) {
|
||||
},
|
||||
})
|
||||
require.NoError(ttt, err)
|
||||
require.Equal(ttt, v2beta_org.OrgState_ORG_STATE_INACTIVE, listOrgRes.Organizations[0].State)
|
||||
if assert.GreaterOrEqual(ttt, len(listOrgRes.Organizations), 1) {
|
||||
require.Equal(ttt, v2beta_org.OrgState_ORG_STATE_INACTIVE, listOrgRes.Organizations[0].State)
|
||||
}
|
||||
}, retryDuration, tick, "timeout waiting for expected organizations being created")
|
||||
|
||||
return orgId
|
||||
@@ -1048,14 +1050,14 @@ func TestServer_AddOrganizationDomain(t *testing.T) {
|
||||
queryRes, err := Client.ListOrganizationDomains(CTX, &v2beta_org.ListOrganizationDomainsRequest{
|
||||
OrganizationId: orgId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(ttt, err)
|
||||
found := false
|
||||
for _, res := range queryRes.Domains {
|
||||
if res.DomainName == domain {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
require.True(t, found, "unable to find added domain")
|
||||
require.True(ttt, found, "unable to find added domain")
|
||||
}, retryDuration, tick, "timeout waiting for expected organizations being created")
|
||||
|
||||
return orgId
|
||||
@@ -1209,20 +1211,20 @@ func TestServer_ListOrganizationDomains(t *testing.T) {
|
||||
queryRes, err = Client.ListOrganizationDomains(CTX, &v2beta_org.ListOrganizationDomainsRequest{
|
||||
OrganizationId: orgId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(ttt, err)
|
||||
found := false
|
||||
for _, res := range queryRes.Domains {
|
||||
if res.DomainName == tt.domain {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
require.True(t, found, "unable to find added domain")
|
||||
require.True(ttt, found, "unable to find added domain")
|
||||
}, retryDuration, tick, "timeout waiting for adding domain")
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_DeleteOerganizationDomain(t *testing.T) {
|
||||
func TestServer_DeleteOrganizationDomain(t *testing.T) {
|
||||
domain := gofakeit.URL()
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1260,14 +1262,14 @@ func TestServer_DeleteOerganizationDomain(t *testing.T) {
|
||||
queryRes, err := Client.ListOrganizationDomains(CTX, &v2beta_org.ListOrganizationDomainsRequest{
|
||||
OrganizationId: orgId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(ttt, err)
|
||||
found := false
|
||||
for _, res := range queryRes.Domains {
|
||||
if res.DomainName == domain {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
require.True(t, found, "unable to find added domain")
|
||||
require.True(ttt, found, "unable to find added domain")
|
||||
}, retryDuration, tick, "timeout waiting for expected organizations being created")
|
||||
|
||||
return orgId
|
||||
@@ -1303,14 +1305,14 @@ func TestServer_DeleteOerganizationDomain(t *testing.T) {
|
||||
queryRes, err := Client.ListOrganizationDomains(CTX, &v2beta_org.ListOrganizationDomainsRequest{
|
||||
OrganizationId: orgId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(ttt, err)
|
||||
found := false
|
||||
for _, res := range queryRes.Domains {
|
||||
if res.DomainName == domain {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
require.True(t, found, "unable to find added domain")
|
||||
require.True(ttt, found, "unable to find added domain")
|
||||
}, retryDuration, tick, "timeout waiting for expected organizations being created")
|
||||
|
||||
_, err = Client.DeleteOrganizationDomain(CTX, &v2beta_org.DeleteOrganizationDomainRequest{
|
||||
@@ -1691,7 +1693,7 @@ func TestServer_SetOrganizationMetadata(t *testing.T) {
|
||||
listMetadataRes, err := Client.ListOrganizationMetadata(tt.ctx, &v2beta_org.ListOrganizationMetadataRequest{
|
||||
OrganizationId: orgId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(ttt, err)
|
||||
foundMetadata := false
|
||||
foundMetadataKeyCount := 0
|
||||
for _, res := range listMetadataRes.Metadata {
|
||||
@@ -1719,11 +1721,11 @@ func TestServer_ListOrganizationMetadata(t *testing.T) {
|
||||
orgId := orgs[0].Id
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
setupFunc func()
|
||||
orgId string
|
||||
keyValuPars []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
setupFunc func()
|
||||
orgId string
|
||||
keyValuePairs []struct {
|
||||
key string
|
||||
value string
|
||||
}
|
||||
@@ -1744,7 +1746,7 @@ func TestServer_ListOrganizationMetadata(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
orgId: orgId,
|
||||
keyValuPars: []struct{ key, value string }{
|
||||
keyValuePairs: []struct{ key, value string }{
|
||||
{
|
||||
key: "key1",
|
||||
value: "value1",
|
||||
@@ -1775,7 +1777,7 @@ func TestServer_ListOrganizationMetadata(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
orgId: orgId,
|
||||
keyValuPars: []struct{ key, value string }{
|
||||
keyValuePairs: []struct{ key, value string }{
|
||||
{
|
||||
key: "key2",
|
||||
value: "value2",
|
||||
@@ -1791,10 +1793,10 @@ func TestServer_ListOrganizationMetadata(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list org metadata for non existent org",
|
||||
ctx: Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner),
|
||||
orgId: "non existent orgid",
|
||||
keyValuPars: []struct{ key, value string }{},
|
||||
name: "list org metadata for non existent org",
|
||||
ctx: Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner),
|
||||
orgId: "non existent orgid",
|
||||
keyValuePairs: []struct{ key, value string }{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -1808,10 +1810,10 @@ func TestServer_ListOrganizationMetadata(t *testing.T) {
|
||||
got, err := Client.ListOrganizationMetadata(tt.ctx, &v2beta_org.ListOrganizationMetadataRequest{
|
||||
OrganizationId: tt.orgId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(ttt, err)
|
||||
|
||||
foundMetadataCount := 0
|
||||
for _, kv := range tt.keyValuPars {
|
||||
for _, kv := range tt.keyValuePairs {
|
||||
for _, res := range got.Metadata {
|
||||
if res.Key == kv.key &&
|
||||
string(res.Value) == kv.value {
|
||||
@@ -1819,7 +1821,7 @@ func TestServer_ListOrganizationMetadata(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, len(tt.keyValuPars), foundMetadataCount)
|
||||
require.Equal(ttt, len(tt.keyValuePairs), foundMetadataCount)
|
||||
}, retryDuration, tick, "timeout waiting for expected organizations being created")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user