mirror of
https://github.com/zitadel/zitadel.git
synced 2026-01-05 01:34:31 +00:00
fix(repo): correct mapping for domains (#10653)
This pull request fixes an issue where the repository would fail to scan organization or instance structs if the `domains` column was `NULL`. ## Which problems are solved If the `domains` column of `orgs` or `instances` was `NULL`, the repository failed scanning into the structs. This happened because the scanning mechanism did not correctly handle `NULL` JSONB columns. ## How the problems are solved A new generic type `JSONArray[T]` is introduced, which implements the `sql.Scanner` interface. This type can correctly scan JSON arrays from the database, including handling `NULL` values gracefully. The repositories for instances and organizations have been updated to use this new type for the domains field. The SQL queries have also been improved to use `FILTER` with `jsonb_agg` for better readability and performance when aggregating domains. ## Additional changes * An unnecessary cleanup step in the organization domain tests for already removed domains has been removed. * The `pgxscan` library has been replaced with `sqlscan` for scanning `database/sql`.Rows. * Minor cleanups in integration tests.
This commit is contained in:
@@ -3,7 +3,7 @@ package sql
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
pgxscan "github.com/georgysavva/scany/v2/dbscan"
|
||||
"github.com/georgysavva/scany/v2/sqlscan"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
@@ -44,7 +44,7 @@ func (r *Rows) Collect(dest any) (err error) {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanAll(dest, r.Rows))
|
||||
return wrapError(sqlscan.ScanAll(dest, r.Rows))
|
||||
}
|
||||
|
||||
// CollectFirst implements [database.CollectableRows].
|
||||
@@ -56,7 +56,7 @@ func (r *Rows) CollectFirst(dest any) (err error) {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanRow(dest, r.Rows))
|
||||
return wrapError(sqlscan.ScanRow(dest, r.Rows))
|
||||
}
|
||||
|
||||
// CollectExactlyOneRow implements [database.CollectableRows].
|
||||
@@ -68,7 +68,7 @@ func (r *Rows) CollectExactlyOneRow(dest any) (err error) {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanOne(dest, r.Rows))
|
||||
return wrapError(sqlscan.ScanOne(dest, r.Rows))
|
||||
}
|
||||
|
||||
// Close implements [database.Rows].
|
||||
|
||||
@@ -35,11 +35,11 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for instance to be created
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.Instance.Id)),
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.ID())),
|
||||
)
|
||||
assert.NoError(ttt, err)
|
||||
assert.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
t.Run("test instance custom domain add reduces", func(t *testing.T) {
|
||||
@@ -65,7 +65,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -75,13 +75,13 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event instance.domain.added
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.Equal(ttt, instance.Instance.Id, domain.InstanceID)
|
||||
assert.False(ttt, *domain.IsPrimary)
|
||||
assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
assert.Equal(t, domainName, domain.Domain)
|
||||
assert.Equal(t, instance.Instance.Id, domain.InstanceID)
|
||||
assert.False(t, *domain.IsPrimary)
|
||||
assert.WithinRange(t, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(t, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for domain to be created
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -134,9 +134,9 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.False(ttt, *domain.IsPrimary)
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
require.NoError(t, err)
|
||||
require.False(t, *domain.IsPrimary)
|
||||
assert.Equal(t, domainName, domain.Domain)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Set domain as primary
|
||||
@@ -150,7 +150,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that set primary reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -160,11 +160,11 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event instance.domain.primary.set
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.True(ttt, *domain.IsPrimary)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeSetPrimary, afterSetPrimary)
|
||||
assert.Equal(t, domainName, domain.Domain)
|
||||
assert.True(t, *domain.IsPrimary)
|
||||
assert.WithinRange(t, domain.UpdatedAt, beforeSetPrimary, afterSetPrimary)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
@@ -179,7 +179,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for domain to be created and verify it exists
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -189,7 +189,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Remove the domain
|
||||
@@ -201,7 +201,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -212,8 +212,8 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
assert.Nil(t, domain)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -250,12 +250,12 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event instance.domain.added
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.Equal(ttt, instance.Instance.Id, domain.InstanceID)
|
||||
assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
assert.Equal(t, domainName, domain.Domain)
|
||||
assert.Equal(t, instance.Instance.Id, domain.InstanceID)
|
||||
assert.WithinRange(t, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(t, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
@@ -270,7 +270,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for domain to be created and verify it exists
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -280,7 +280,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Remove the domain
|
||||
@@ -292,7 +292,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -303,8 +303,8 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
assert.Nil(t, domain)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,13 +45,13 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
})
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event instance.added
|
||||
assert.Equal(ttt, instanceName, instance.Name)
|
||||
assert.Equal(t, instanceName, instance.Name)
|
||||
// event instance.default.org.set
|
||||
assert.NotNil(t, instance.DefaultOrgID)
|
||||
// event instance.iam.project.set
|
||||
|
||||
@@ -36,11 +36,11 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
|
||||
// Wait for org to be created
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(orgRepo.IDCondition(org.GetId())),
|
||||
)
|
||||
assert.NoError(ttt, err)
|
||||
assert.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// The API call also sets the domain as primary, so we don't do a separate test for that.
|
||||
@@ -67,7 +67,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
gottenDomain, err := orgDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -77,7 +77,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.NoError(t, err)
|
||||
// event org.domain.added
|
||||
assert.Equal(t, domainName, gottenDomain.Domain)
|
||||
assert.Equal(t, Instance.Instance.Id, gottenDomain.InstanceID)
|
||||
@@ -97,16 +97,6 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
Domain: domainName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete domain on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Remove the domain
|
||||
_, err = OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
@@ -116,7 +106,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := orgDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
@@ -126,8 +116,8 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
assert.Nil(t, domain)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
}
|
||||
|
||||
21
backend/v3/storage/database/repository/array.go
Normal file
21
backend/v3/storage/database/repository/array.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type JSONArray[T any] []*T
|
||||
|
||||
func (a JSONArray[T]) Scan(src any) error {
|
||||
switch s := src.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(s), &a)
|
||||
case []byte:
|
||||
return json.Unmarshal(s, &a)
|
||||
case nil:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("unsupported scan source")
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
@@ -34,7 +32,7 @@ func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository
|
||||
|
||||
const (
|
||||
queryInstanceStmt = `SELECT instances.id, instances.name, instances.default_org_id, instances.iam_project_id, instances.console_client_id, instances.console_app_id, instances.default_language, instances.created_at, instances.updated_at` +
|
||||
` , CASE WHEN count(instance_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', instance_domains.domain, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) ELSE NULL::JSONB END domains` +
|
||||
` , jsonb_agg(json_build_object('domain', instance_domains.domain, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) FILTER (WHERE instance_domains.instance_id IS NOT NULL) AS domains` +
|
||||
` FROM zitadel.instances`
|
||||
)
|
||||
|
||||
@@ -236,9 +234,13 @@ func (instance) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("instances", "updated_at")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// scanners
|
||||
// -------------------------------------------------------------
|
||||
|
||||
type rawInstance struct {
|
||||
*domain.Instance
|
||||
RawDomains sql.Null[json.RawMessage] `json:"domains,omitzero" db:"domains"`
|
||||
Domains JSONArray[domain.InstanceDomain] `json:"domains,omitempty" db:"domains"`
|
||||
}
|
||||
|
||||
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
|
||||
@@ -251,12 +253,7 @@ func scanInstance(ctx context.Context, querier database.Querier, builder *databa
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if instance.RawDomains.Valid {
|
||||
if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
instance.Instance.Domains = instance.Domains
|
||||
|
||||
return instance.Instance, nil
|
||||
}
|
||||
@@ -267,21 +264,18 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawInstances []*rawInstance
|
||||
if err := rows.(database.CollectableRows).Collect(&rawInstances); err != nil {
|
||||
var instances []*rawInstance
|
||||
if err := rows.(database.CollectableRows).Collect(&instances); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
instances := make([]*domain.Instance, len(rawInstances))
|
||||
for i, instance := range rawInstances {
|
||||
if instance.RawDomains.Valid {
|
||||
if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
instances[i] = instance.Instance
|
||||
result := make([]*domain.Instance, len(instances))
|
||||
for i, inst := range instances {
|
||||
result[i] = inst.Instance
|
||||
result[i].Domains = inst.Domains
|
||||
}
|
||||
return instances, nil
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
|
||||
@@ -2,7 +2,6 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
@@ -29,7 +28,7 @@ func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRe
|
||||
}
|
||||
|
||||
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
|
||||
` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` +
|
||||
` , jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
|
||||
` FROM zitadel.organizations`
|
||||
|
||||
// Get implements [domain.OrganizationRepository].
|
||||
@@ -212,9 +211,9 @@ func (org) UpdatedAtColumn() database.Column {
|
||||
// scanners
|
||||
// -------------------------------------------------------------
|
||||
|
||||
type rawOrganization struct {
|
||||
type rawOrg struct {
|
||||
*domain.Organization
|
||||
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
|
||||
Domains JSONArray[domain.OrganizationDomain] `json:"domains,omitempty" db:"domains"`
|
||||
}
|
||||
|
||||
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
|
||||
@@ -223,15 +222,11 @@ func scanOrganization(ctx context.Context, querier database.Querier, builder *da
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var org rawOrganization
|
||||
var org rawOrg
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(org.RawDomains) > 0 {
|
||||
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
org.Organization.Domains = org.Domains
|
||||
|
||||
return org.Organization, nil
|
||||
}
|
||||
@@ -242,21 +237,18 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawOrgs []*rawOrganization
|
||||
if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil {
|
||||
var orgs []*rawOrg
|
||||
if err := rows.(database.CollectableRows).Collect(&orgs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
organizations := make([]*domain.Organization, len(rawOrgs))
|
||||
for i, org := range rawOrgs {
|
||||
if len(org.RawDomains) > 0 {
|
||||
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
organizations[i] = org.Organization
|
||||
result := make([]*domain.Organization, len(orgs))
|
||||
for i, org := range orgs {
|
||||
result[i] = org.Organization
|
||||
result[i].Domains = org.Domains
|
||||
}
|
||||
return organizations, nil
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user