feat(db): adding org table to relational model (#10066)

# Which Problems Are Solved

As an outcome of [this
issue](https://github.com/zitadel/zitadel/issues/9599) we want to
implement relational tables in Zitadel. For that we use new tables as a
successor of the current tables used by Zitadel in `projections`, `auth`
and `admin` schemas. The new logic is based on [this
proposal](https://github.com/zitadel/zitadel/pull/9870). This issue does
not contain the switch from CQRS to the new tables. This is change will
be implemented in a later stage.

We focus on the most critical tables which is user authentication.

We need a table to manage organizations. 

### organization fields

The following fields must be managed in this table:

- `id`
- `instance_id`
- `name`
- `state` enum (active, inactive)
- `created_at`
- `updated_at`
- `deleted_at`

DISCUSS: should we add a `primary_domain` to this table so that we do
not have to join on domains to return a simple org?

We must ensure the unique constraints for this table matches the current
commands.

### organization repository

The repository must provide the following functions:

Manipulations:
- create
  - `instance_id`
  - `name`
- update
  - `name`
- delete

Queries:
- get returns single organization matching the criteria and pagination,
should return error if multiple were found
- list returns list of organizations matching the criteria, pagination

Criteria are the following:
- by id
- by name

pagination:
- by created_at
- by updated_at
- by name

### organization events

The following events must be applied on the table using a projection
(`internal/query/projection`)

- `org.added` results in create
- `org.changed` sets the `name` field
- `org.deactivated` sets the `state` field
- `org.reactivated` sets the `state` field
- `org.removed` sets the `deleted_at` field
- if answer is yes to discussion: `org.domain.primary.set` sets the
`primary_domain` field
- `instance.removed` sets the the `deleted_at` field if not already set

### acceptance criteria

- [x] migration is implemented and gets executed
- [x] domain interfaces are implemented and documented for service layer
- [x] repository is implemented and implements domain interface
- [x] testing
  - [x] the repository methods
  - [x] events get reduced correctly
  - [x] unique constraints
# Additional Context

Replace this example with links to related issues, discussions, discord
threads, or other sources with more context.
Use the Closing #issue syntax for issues that are resolved with this PR.
- Closes #https://github.com/zitadel/zitadel/issues/9936

---------

Co-authored-by: adlerhurst <27845747+adlerhurst@users.noreply.github.com>
This commit is contained in:
Iraq
2025-07-14 21:27:14 +02:00
committed by GitHub
parent 9595a1bcca
commit 8d020e56bb
28 changed files with 2238 additions and 590 deletions

View File

@@ -86,12 +86,12 @@ type InstanceRepository interface {
// Member returns the member repository which is a sub repository of the instance repository.
// Member() MemberRepository
Get(ctx context.Context, opts ...database.Condition) (*Instance, error)
Get(ctx context.Context, id string) (*Instance, error)
List(ctx context.Context, opts ...database.Condition) ([]*Instance, error)
Create(ctx context.Context, instance *Instance) error
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
Delete(ctx context.Context, condition database.Condition) error
Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id string) (int64, error)
}
type CreateInstance struct {

View File

@@ -1,120 +0,0 @@
package domain
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type OrgState uint8
const (
OrgStateActive OrgState = iota + 1
OrgStateInactive
)
// Org is used by all other packages to represent an organization.
type Org struct {
ID string `json:"id"`
Name string `json:"name"`
State OrgState `json:"state"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
type orgCacheIndex uint8
const (
orgCacheIndexUndefined orgCacheIndex = iota
orgCacheIndexID
)
// Keys implements [cache.Entry].
func (o *Org) Keys(index orgCacheIndex) (key []string) {
if index == orgCacheIndexID {
return []string{o.ID}
}
return nil
}
var _ cache.Entry[orgCacheIndex, string] = (*Org)(nil)
// orgColumns define all the columns of the org table.
type orgColumns interface {
// InstanceIDColumn returns the column for the instance id field.
InstanceIDColumn() database.Column
// IDColumn returns the column for the id field.
IDColumn() database.Column
// NameColumn returns the column for the name field.
NameColumn() database.Column
// StateColumn returns the column for the state field.
StateColumn() database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
// DeletedAtColumn returns the column for the deleted at field.
DeletedAtColumn() database.Column
}
// orgConditions define all the conditions for the org table.
type orgConditions interface {
// InstanceIDCondition returns an equal filter on the instance id field.
InstanceIDCondition(instanceID string) database.Condition
// IDCondition returns an equal filter on the id field.
IDCondition(orgID string) database.Condition
// NameCondition returns a filter on the name field.
NameCondition(op database.TextOperation, name string) database.Condition
// StateCondition returns a filter on the state field.
StateCondition(op database.NumberOperation, state OrgState) database.Condition
}
// orgChanges define all the changes for the org table.
type orgChanges interface {
// SetName sets the name column.
SetName(name string) database.Change
// SetState sets the state column.
SetState(state OrgState) database.Change
}
// OrgRepository is the interface for the org repository.
// It is used to interact with the org table in the database.
type OrgRepository interface {
orgColumns
orgConditions
orgChanges
// Member returns the member repository.
Member() MemberRepository
// Domain returns the domain repository.
Domain() DomainRepository
// Get returns an org based on the given condition.
Get(ctx context.Context, opts ...database.QueryOption) (*Org, error)
// List returns a list of orgs based on the given condition.
List(ctx context.Context, opts ...database.QueryOption) ([]*Org, error)
// Create creates a new org.
Create(ctx context.Context, org *Org) error
// Delete removes orgs based on the given condition.
Delete(ctx context.Context, condition database.Condition) error
// Update executes the given changes based on the given condition.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
}
// MemberRepository is a sub repository of the org repository and maybe the instance repository.
type MemberRepository interface {
AddMember(ctx context.Context, orgID, userID string, roles []string) error
SetMemberRoles(ctx context.Context, orgID, userID string, roles []string) error
RemoveMember(ctx context.Context, orgID, userID string) error
}
// DomainRepository is a sub repository of the org repository and maybe the instance repository.
type DomainRepository interface {
AddDomain(ctx context.Context, domain string) error
SetDomainVerified(ctx context.Context, domain string) error
RemoveDomain(ctx context.Context, domain string) error
}

View File

@@ -0,0 +1,103 @@
package domain
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
//go:generate enumer -type OrgState -transform lower -trimprefix OrgState
type OrgState uint8
const (
OrgStateActive OrgState = iota
OrgStateInactive
)
type Organization struct {
ID string `json:"id,omitempty" db:"id"`
Name string `json:"name,omitempty" db:"name"`
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
State string `json:"state,omitempty" db:"state"`
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"`
DeletedAt *time.Time `json:"deletedAt,omitempty" db:"deleted_at"`
}
// OrgIdentifierCondition is used to help specify a single Organization,
// it will either be used as the organization ID or organization name,
// as organizations can be identified either using (instnaceID + ID) OR (instanceID + name)
type OrgIdentifierCondition interface {
database.Condition
}
// organizationColumns define all the columns of the instance table.
type organizationColumns interface {
// IDColumn returns the column for the id field.
IDColumn() database.Column
// NameColumn returns the column for the name field.
NameColumn() database.Column
// InstanceIDColumn returns the column for the default org id field
InstanceIDColumn() database.Column
// StateColumn returns the column for the name field.
StateColumn() database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
// DeletedAtColumn returns the column for the deleted at field.
DeletedAtColumn() database.Column
}
// organizationConditions define all the conditions for the instance table.
type organizationConditions interface {
// IDCondition returns an equal filter on the id field.
IDCondition(instanceID string) OrgIdentifierCondition
// NameCondition returns a filter on the name field.
NameCondition(name string) OrgIdentifierCondition
// InstanceIDCondition returns a filter on the instance id field.
InstanceIDCondition(instanceID string) database.Condition
// StateCondition returns a filter on the name field.
StateCondition(state OrgState) database.Condition
}
// organizationChanges define all the changes for the instance table.
type organizationChanges interface {
// SetName sets the name column.
SetName(name string) database.Change
// SetState sets the name column.
SetState(state OrgState) database.Change
}
// OrganizationRepository is the interface for the instance repository.
type OrganizationRepository interface {
organizationColumns
organizationConditions
organizationChanges
Get(ctx context.Context, id OrgIdentifierCondition, instance_id string, opts ...database.Condition) (*Organization, error)
List(ctx context.Context, opts ...database.Condition) ([]*Organization, error)
Create(ctx context.Context, instance *Organization) error
Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
}
type CreateOrganization struct {
Name string `json:"name"`
}
// MemberRepository is a sub repository of the org repository and maybe the instance repository.
type MemberRepository interface {
AddMember(ctx context.Context, orgID, userID string, roles []string) error
SetMemberRoles(ctx context.Context, orgID, userID string, roles []string) error
RemoveMember(ctx context.Context, orgID, userID string) error
}
// DomainRepository is a sub repository of the org repository and maybe the instance repository.
type DomainRepository interface {
AddDomain(ctx context.Context, domain string) error
SetDomainVerified(ctx context.Context, domain string) error
RemoveDomain(ctx context.Context, domain string) error
}

View File

@@ -0,0 +1,78 @@
// Code generated by "enumer -type OrgState -transform lower -trimprefix OrgState"; DO NOT EDIT.
package domain
import (
"fmt"
"strings"
)
const _OrgStateName = "activeinactive"
var _OrgStateIndex = [...]uint8{0, 6, 14}
const _OrgStateLowerName = "activeinactive"
func (i OrgState) String() string {
if i < 0 || i >= OrgState(len(_OrgStateIndex)-1) {
return fmt.Sprintf("OrgState(%d)", i)
}
return _OrgStateName[_OrgStateIndex[i]:_OrgStateIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _OrgStateNoOp() {
var x [1]struct{}
_ = x[OrgStateActive-(0)]
_ = x[OrgStateInactive-(1)]
}
var _OrgStateValues = []OrgState{OrgStateActive, OrgStateInactive}
var _OrgStateNameToValueMap = map[string]OrgState{
_OrgStateName[0:6]: OrgStateActive,
_OrgStateLowerName[0:6]: OrgStateActive,
_OrgStateName[6:14]: OrgStateInactive,
_OrgStateLowerName[6:14]: OrgStateInactive,
}
var _OrgStateNames = []string{
_OrgStateName[0:6],
_OrgStateName[6:14],
}
// OrgStateString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func OrgStateString(s string) (OrgState, error) {
if val, ok := _OrgStateNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _OrgStateNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to OrgState values", s)
}
// OrgStateValues returns all values of the enum
func OrgStateValues() []OrgState {
return _OrgStateValues
}
// OrgStateStrings returns a slice of all String values of the enum
func OrgStateStrings() []string {
strs := make([]string, len(_OrgStateNames))
copy(strs, _OrgStateNames)
return strs
}
// IsAOrgState returns "true" if the value is listed in the enum definition. "false" otherwise
func (i OrgState) IsAOrgState() bool {
for _, v := range _OrgStateValues {
if i == v {
return true
}
}
return false
}

View File

@@ -59,8 +59,25 @@ type Row interface {
// Rows is an abstraction of sql.Rows.
type Rows interface {
Row
Scanner
Next() bool
Close() error
Err() error
}
type CollectableRows interface {
// Collect collects all rows and scans them into dest.
// dest must be a pointer to a slice of pointer to structs
// e.g. *[]*MyStruct
// Rows are closed after this call.
Collect(dest any) error
// CollectFirst collects the first row and scans it into dest.
// dest must be a pointer to a struct
// e.g. *MyStruct{}
// Rows are closed after this call.
CollectFirst(dest any) error
// CollectExactlyOneRow collects exactly one row and scans it into dest.
// e.g. *MyStruct{}
// Rows are closed after this call.
CollectExactlyOneRow(dest any) error
}

View File

@@ -6,8 +6,8 @@ CREATE TABLE IF NOT EXISTS zitadel.instances(
console_client_id TEXT, -- NOT NULL,
console_app_id TEXT, -- NOT NULL,
default_language TEXT, -- NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
deleted_at TIMESTAMPTZ DEFAULT NULL
);

View File

@@ -0,0 +1,16 @@
package migration
import (
_ "embed"
)
var (
//go:embed 002_organization_table/up.sql
up002OrganizationTable string
//go:embed 002_organization_table/down.sql
down002OrganizationTable string
)
func init() {
registerSQLMigration(2, up002OrganizationTable, down002OrganizationTable)
}

View File

@@ -0,0 +1,2 @@
DROP TABLE zitadel.organizations;
DROP Type zitadel.organization_state;

View File

@@ -0,0 +1,33 @@
CREATE TYPE zitadel.organization_state AS ENUM (
'active',
'inactive'
);
CREATE TABLE zitadel.organizations(
id TEXT NOT NULL CHECK (id <> ''),
name TEXT NOT NULL CHECK (name <> ''),
instance_id TEXT NOT NULL REFERENCES zitadel.instances (id),
state zitadel.organization_state NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
deleted_at TIMESTAMPTZ DEFAULT NULL,
PRIMARY KEY (instance_id, id)
);
CREATE UNIQUE INDEX org_unique_instance_id_name_idx
ON zitadel.organizations (instance_id, name)
WHERE deleted_at IS NULL;
-- users are able to set the id for organizations
CREATE INDEX org_id_not_deleted_idx ON zitadel.organizations (id)
WHERE deleted_at IS NULL;
CREATE INDEX org_name_not_deleted_idx ON zitadel.organizations (name)
WHERE deleted_at IS NULL;
CREATE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.organizations
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -1,15 +1,55 @@
package postgres
import (
"github.com/georgysavva/scany/v2/pgxscan"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ database.Rows = (*Rows)(nil)
var (
_ database.Rows = (*Rows)(nil)
_ database.CollectableRows = (*Rows)(nil)
)
type Rows struct{ pgx.Rows }
// Collect implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) Collect(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return pgxscan.ScanAll(dest, r.Rows)
}
// CollectFirst implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) CollectFirst(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return pgxscan.ScanRow(dest, r.Rows)
}
// CollectExactlyOneRow implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) CollectExactlyOneRow(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return pgxscan.ScanOne(dest, r.Rows)
}
// Close implements [database.Rows].
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
func (r *Rows) Close() error {

View File

@@ -0,0 +1,67 @@
//go:build integration
package events_test
import (
"context"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
"github.com/zitadel/zitadel/internal/integration"
v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta"
"github.com/zitadel/zitadel/pkg/grpc/system"
)
const ConnString = "host=localhost port=5432 user=zitadel dbname=zitadel sslmode=disable"
var (
dbPool *pgxpool.Pool
CTX context.Context
Instance *integration.Instance
SystemClient system.SystemServiceClient
OrgClient v2beta_org.OrganizationServiceClient
)
var pool database.Pool
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
Instance = integration.NewInstance(ctx)
CTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner)
SystemClient = integration.SystemClient()
OrgClient = Instance.Client.OrgV2beta
var err error
dbConfig, err := pgxpool.ParseConfig(ConnString)
if err != nil {
panic(err)
}
dbConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
orgState, err := conn.LoadType(ctx, "zitadel.organization_state")
if err != nil {
return err
}
conn.TypeMap().RegisterType(orgState)
return nil
}
dbPool, err = pgxpool.NewWithConfig(context.Background(), dbConfig)
if err != nil {
panic(err)
}
pool = postgres.PGxPool(dbPool)
return m.Run()
}())
}

View File

@@ -1,171 +1,138 @@
//go:build integration
package instance_test
package events_test
import (
"context"
"os"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/system"
)
const ConnString = "host=localhost port=5432 user=zitadel dbname=zitadel sslmode=disable"
var (
dbPool *pgxpool.Pool
CTX context.Context
Instance *integration.Instance
SystemClient system.SystemServiceClient
)
var pool database.Pool
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
Instance = integration.NewInstance(ctx)
CTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner)
SystemClient = integration.SystemClient()
var err error
dbPool, err = pgxpool.New(context.Background(), ConnString)
if err != nil {
panic(err)
}
pool = postgres.PGxPool(dbPool)
return m.Run()
}())
}
func TestServer_TestInstanceAddReduces(t *testing.T) {
instanceName := gofakeit.Name()
beforeCreate := time.Now()
_, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
func TestServer_TestInstanceReduces(t *testing.T) {
t.Run("test instance add reduces", func(t *testing.T) {
instanceName := gofakeit.Name()
beforeCreate := time.Now()
instance, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
},
})
afterCreate := time.Now()
require.NoError(t, err)
instanceRepo := repository.InstanceRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instance.GetInstanceId(),
)
require.NoError(ttt, err)
// event instance.added
assert.Equal(ttt, instanceName, instance.Name)
// event instance.default.org.set
assert.NotNil(t, instance.DefaultOrgID)
// event instance.iam.project.set
assert.NotNil(t, instance.IAMProjectID)
// event instance.iam.console.set
assert.NotNil(t, instance.ConsoleAppID)
// event instance.default.language.set
assert.NotNil(t, instance.DefaultLanguage)
// event instance.added
assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
// event instance.added
assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
assert.Nil(t, instance.DeletedAt)
}, retryDuration, tick)
})
afterCreate := time.Now()
require.NoError(t, err)
instanceRepo := repository.InstanceRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
require.NoError(ttt, err)
// event instance.added
require.Equal(ttt, instanceName, instance.Name)
// event instance.default.org.set
require.NotNil(t, instance.DefaultOrgID)
// event instance.iam.project.set
require.NotNil(t, instance.IAMProjectID)
// event instance.iam.console.set
require.NotNil(t, instance.ConsoleAppID)
// event instance.default.language.set
require.NotNil(t, instance.DefaultLanguage)
// event instance.added
assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
// event instance.added
assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
require.Nil(t, instance.DeletedAt)
}, retryDuration, tick)
}
func TestServer_TestInstanceUpdateNameReduces(t *testing.T) {
instanceName := gofakeit.Name()
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
t.Run("test instance update reduces", func(t *testing.T) {
instanceName := gofakeit.Name()
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
},
})
require.NoError(t, err)
instanceName += "new"
beforeUpdate := time.Now()
_, err = SystemClient.UpdateInstance(CTX, &system.UpdateInstanceRequest{
InstanceId: res.InstanceId,
InstanceName: instanceName,
})
require.NoError(t, err)
afterUpdate := time.Now()
instanceRepo := repository.InstanceRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
res.InstanceId,
)
require.NoError(ttt, err)
// event instance.changed
assert.Equal(ttt, instanceName, instance.Name)
assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
}, retryDuration, tick)
})
require.NoError(t, err)
instanceName += "new"
_, err = SystemClient.UpdateInstance(CTX, &system.UpdateInstanceRequest{
InstanceId: res.InstanceId,
InstanceName: instanceName,
})
require.NoError(t, err)
instanceRepo := repository.InstanceRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
require.NoError(ttt, err)
// event instance.changed
require.Equal(ttt, instanceName, instance.Name)
}, retryDuration, tick)
}
func TestServer_TestInstanceDeleteReduces(t *testing.T) {
instanceName := gofakeit.Name()
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
t.Run("test instance delete reduces", func(t *testing.T) {
instanceName := gofakeit.Name()
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
},
})
require.NoError(t, err)
instanceRepo := repository.InstanceRepository(pool)
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
res.InstanceId,
)
require.NoError(ttt, err)
assert.Equal(ttt, instanceName, instance.Name)
}, retryDuration, tick)
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
InstanceId: res.InstanceId,
})
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
res.InstanceId,
)
// event instance.removed
assert.Nil(t, instance)
require.NoError(t, err)
}, retryDuration, tick)
})
require.NoError(t, err)
instanceRepo := repository.InstanceRepository(pool)
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
require.NoError(ttt, err)
require.Equal(ttt, instanceName, instance.Name)
}, retryDuration, tick)
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
InstanceId: res.InstanceId,
})
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
// event instance.removed
require.Nil(t, instance)
require.NoError(ttt, err)
}, retryDuration, tick)
}

View File

@@ -0,0 +1,218 @@
//go:build integration
package events_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta"
)
func TestServer_TestOrganizationReduces(t *testing.T) {
instanceID := Instance.ID()
t.Run("test org add reduces", func(t *testing.T) {
beforeCreate := time.Now()
orgName := gofakeit.Name()
_, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
afterCreate := time.Now()
orgRepo := repository.OrganizationRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(tt *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
orgRepo.NameCondition(orgName),
instanceID,
)
require.NoError(tt, err)
// event org.added
assert.NotNil(t, organization.ID)
assert.Equal(t, orgName, organization.Name)
assert.NotNil(t, organization.InstanceID)
assert.Equal(t, domain.OrgStateActive.String(), organization.State)
assert.WithinRange(t, organization.CreatedAt, beforeCreate, afterCreate)
assert.WithinRange(t, organization.UpdatedAt, beforeCreate, afterCreate)
assert.Nil(t, organization.DeletedAt)
}, retryDuration, tick)
})
t.Run("test org change reduces", func(t *testing.T) {
orgName := gofakeit.Name()
// 1. create org
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
// 2. update org name
beforeUpdate := time.Now()
orgName = orgName + "_new"
_, err = OrgClient.UpdateOrganization(CTX, &v2beta_org.UpdateOrganizationRequest{
Id: organization.Id,
Name: orgName,
})
require.NoError(t, err)
afterUpdate := time.Now()
orgRepo := repository.OrganizationRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
orgRepo.NameCondition(orgName),
instanceID,
)
require.NoError(t, err)
// event org.changed
assert.Equal(t, orgName, organization.Name)
assert.WithinRange(t, organization.UpdatedAt, beforeUpdate, afterUpdate)
}, retryDuration, tick)
})
t.Run("test org deactivate reduces", func(t *testing.T) {
orgName := gofakeit.Name()
// 1. create org
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
// 2. deactivate org name
beforeDeactivate := time.Now()
_, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{
Id: organization.Id,
})
require.NoError(t, err)
afterDeactivate := time.Now()
orgRepo := repository.OrganizationRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
orgRepo.NameCondition(orgName),
instanceID,
)
require.NoError(t, err)
// event org.deactivate
assert.Equal(t, orgName, organization.Name)
assert.Equal(t, domain.OrgStateInactive.String(), organization.State)
assert.WithinRange(t, organization.UpdatedAt, beforeDeactivate, afterDeactivate)
}, retryDuration, tick)
})
t.Run("test org activate reduces", func(t *testing.T) {
orgName := gofakeit.Name()
// 1. create org
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
// 2. deactivate org name
_, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{
Id: organization.Id,
})
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(pool)
// 3. check org deactivated
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
orgRepo.NameCondition(orgName),
instanceID,
)
require.NoError(t, err)
assert.Equal(t, orgName, organization.Name)
assert.Equal(t, domain.OrgStateInactive.String(), organization.State)
}, retryDuration, tick)
// 4. activate org name
beforeActivate := time.Now()
_, err = OrgClient.ActivateOrganization(CTX, &v2beta_org.ActivateOrganizationRequest{
Id: organization.Id,
})
require.NoError(t, err)
afterActivate := time.Now()
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
orgRepo.NameCondition(orgName),
instanceID,
)
require.NoError(t, err)
// event org.reactivate
assert.Equal(t, orgName, organization.Name)
assert.Equal(t, domain.OrgStateActive.String(), organization.State)
assert.WithinRange(t, organization.UpdatedAt, beforeActivate, afterActivate)
}, retryDuration, tick)
})
t.Run("test org remove reduces", func(t *testing.T) {
orgName := gofakeit.Name()
// 1. create org
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
// 2. check org retrivable
orgRepo := repository.OrganizationRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
orgRepo.NameCondition(orgName),
instanceID,
)
require.NoError(t, err)
if organization == nil {
assert.Fail(t, "this error is here because of a race condition")
}
assert.Equal(t, orgName, organization.Name)
}, retryDuration, tick)
// 3. delete org
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
Id: organization.Id,
})
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
orgRepo.NameCondition(orgName),
instanceID,
)
require.NoError(t, err)
// event org.remove
assert.Nil(t, organization)
}, retryDuration, tick)
})
}

View File

@@ -0,0 +1,21 @@
package database
// Order represents a SQL condition.
// Its written after the ORDER BY keyword in a SQL statement.
type Order interface {
Write(builder *StatementBuilder)
}
type orderBy struct {
column Column
}
func OrderBy(column Column) Order {
return &orderBy{column: column}
}
// Write implements [Order].
func (o *orderBy) Write(builder *StatementBuilder) {
builder.WriteString(" ORDER BY ")
o.column.Write(builder)
}

View File

@@ -33,16 +33,17 @@ const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, cons
` FROM zitadel.instances`
// Get implements [domain.InstanceRepository].
func (i *instance) Get(ctx context.Context, opts ...database.Condition) (*domain.Instance, error) {
func (i *instance) Get(ctx context.Context, id string) (*domain.Instance, error) {
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
idCondition := i.IDCondition(id)
// return only non deleted instances
opts = append(opts, database.IsNull(i.DeletedAtColumn()))
i.writeCondition(&builder, database.And(opts...))
conditions := []database.Condition{idCondition, database.IsNull(i.DeletedAtColumn())}
writeCondition(&builder, database.And(conditions...))
return scanInstance(i.client.QueryRow(ctx, builder.String(), builder.Args()...))
return scanInstance(ctx, i.client, &builder)
}
// List implements [domain.InstanceRepository].
@@ -54,15 +55,9 @@ func (i *instance) List(ctx context.Context, opts ...database.Condition) ([]*dom
// return only non deleted instances
opts = append(opts, database.IsNull(i.DeletedAtColumn()))
notDeletedCondition := database.And(opts...)
i.writeCondition(&builder, notDeletedCondition)
writeCondition(&builder, notDeletedCondition)
rows, err := i.client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanInstances(rows)
return scanInstances(ctx, i.client, &builder)
}
const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language)` +
@@ -101,15 +96,20 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
}
// Update implements [domain.InstanceRepository].
func (i instance) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
if changes == nil {
return 0, errors.New("Update must contain a change")
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instances SET `)
// don't update deleted instances
conditions := []database.Condition{condition, database.IsNull(i.DeletedAtColumn())}
database.Changes(changes).Write(&builder)
i.writeCondition(&builder, database.And(conditions...))
idCondition := i.IDCondition(id)
// don't update deleted instances
conditions := []database.Condition{idCondition, database.IsNull(i.DeletedAtColumn())}
writeCondition(&builder, database.And(conditions...))
stmt := builder.String()
@@ -118,18 +118,18 @@ func (i instance) Update(ctx context.Context, condition database.Condition, chan
}
// Delete implements [domain.InstanceRepository].
func (i instance) Delete(ctx context.Context, condition database.Condition) error {
if condition == nil {
return errors.New("Delete must contain a condition") // (otherwise ALL instances will be deleted)
}
func (i instance) Delete(ctx context.Context, id string) (int64, error) {
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instances SET deleted_at = $1`)
builder.AppendArgs(time.Now())
i.writeCondition(&builder, condition)
_, err := i.client.Exec(ctx, builder.String(), builder.Args()...)
return err
// don't delete already deleted instance
idCondition := i.IDCondition(id)
conditions := []database.Condition{idCondition, database.IsNull(i.DeletedAtColumn())}
writeCondition(&builder, database.And(conditions...))
return i.client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -209,32 +209,30 @@ func (instance) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at")
}
func (i *instance) writeCondition(
builder *database.StatementBuilder,
condition database.Condition,
) {
if condition == nil {
return
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
builder.WriteString(" WHERE ")
condition.Write(builder)
instance := new(domain.Instance)
if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil {
if err.Error() == "no rows in result set" {
return nil, ErrResourceDoesNotExist
}
return nil, err
}
return instance, nil
}
func scanInstance(scanner database.Scanner) (*domain.Instance, error) {
var instance domain.Instance
err := scanner.Scan(
&instance.ID,
&instance.Name,
&instance.DefaultOrgID,
&instance.IAMProjectID,
&instance.ConsoleClientID,
&instance.ConsoleAppID,
&instance.DefaultLanguage,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
if err := rows.(database.CollectableRows).Collect(&instances); err != nil {
// if no results returned, this is not a error
// it just means the instance was not found
// the caller should check if the returned instance is nil
@@ -244,32 +242,5 @@ func scanInstance(scanner database.Scanner) (*domain.Instance, error) {
return nil, err
}
return &instance, nil
}
func scanInstances(rows database.Rows) ([]*domain.Instance, error) {
instances := make([]*domain.Instance, 0)
for rows.Next() {
var instance domain.Instance
err := rows.Scan(
&instance.ID,
&instance.Name,
&instance.DefaultOrgID,
&instance.IAMProjectID,
&instance.ConsoleClientID,
&instance.ConsoleAppID,
&instance.DefaultLanguage,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
if err != nil {
return nil, err
}
instances = append(instances, &instance)
}
return instances, nil
}

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
@@ -74,11 +75,61 @@ func TestCreateInstance(t *testing.T) {
}
err := instanceRepo.Create(ctx, &inst)
// change the name to make sure same only the id clashes
inst.Name = gofakeit.Name()
require.NoError(t, err)
return &inst
},
err: errors.New("instance id already exists"),
},
func() struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
instance domain.Instance
err error
} {
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
return struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
instance domain.Instance
err error
}{
name: "adding instance with same name twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
// change the id
inst.ID = instanceId
return &inst
},
instance: domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
},
// two instances can have the sane name
err: nil,
}
}(),
{
name: "adding instance with no id",
instance: func() domain.Instance {
@@ -113,7 +164,7 @@ func TestCreateInstance(t *testing.T) {
// create instance
beforeCreate := time.Now()
err := instanceRepo.Create(ctx, instance)
require.Equal(t, tt.err, err)
assert.Equal(t, tt.err, err)
if err != nil {
return
}
@@ -121,20 +172,20 @@ func TestCreateInstance(t *testing.T) {
// check instance values
instance, err = instanceRepo.Get(ctx,
instanceRepo.NameCondition(database.TextOperationEqual, instance.Name),
instance.ID,
)
require.NoError(t, err)
require.Equal(t, tt.instance.ID, instance.ID)
require.Equal(t, tt.instance.Name, instance.Name)
require.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID)
require.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage)
require.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
require.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
require.Nil(t, instance.DeletedAt)
assert.Equal(t, tt.instance.ID, instance.ID)
assert.Equal(t, tt.instance.Name, instance.Name)
assert.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID)
assert.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID)
assert.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID)
assert.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID)
assert.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage)
assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
assert.Nil(t, instance.DeletedAt)
})
}
}
@@ -144,6 +195,7 @@ func TestUpdateInstance(t *testing.T) {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
rowsAffected int64
getErr error
}{
{
name: "happy path",
@@ -191,10 +243,11 @@ func TestUpdateInstance(t *testing.T) {
require.NoError(t, err)
// delete instance
err = instanceRepo.Delete(ctx,
instanceRepo.IDCondition(inst.ID),
affectedRows, err := instanceRepo.Delete(ctx,
inst.ID,
)
require.NoError(t, err)
assert.Equal(t, int64(1), affectedRows)
return &inst
},
@@ -211,6 +264,7 @@ func TestUpdateInstance(t *testing.T) {
return &inst
},
rowsAffected: 0,
getErr: repository.ErrResourceDoesNotExist,
},
}
for _, tt := range tests {
@@ -224,13 +278,13 @@ func TestUpdateInstance(t *testing.T) {
// update name
newName := "new_" + instance.Name
rowsAffected, err := instanceRepo.Update(ctx,
instanceRepo.IDCondition(instance.ID),
instance.ID,
instanceRepo.SetName(newName),
)
afterUpdate := time.Now()
require.NoError(t, err)
require.Equal(t, tt.rowsAffected, rowsAffected)
assert.Equal(t, tt.rowsAffected, rowsAffected)
if rowsAffected == 0 {
return
@@ -238,13 +292,13 @@ func TestUpdateInstance(t *testing.T) {
// check instance values
instance, err = instanceRepo.Get(ctx,
instanceRepo.IDCondition(instance.ID),
instance.ID,
)
require.NoError(t, err)
require.Equal(t, tt.getErr, err)
require.Equal(t, newName, instance.Name)
require.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
require.Nil(t, instance.DeletedAt)
assert.Equal(t, newName, instance.Name)
assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
assert.Nil(t, instance.DeletedAt)
})
}
}
@@ -252,9 +306,9 @@ func TestUpdateInstance(t *testing.T) {
func TestGetInstance(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
conditionClauses []database.Condition
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
err error
}
tests := []test{
@@ -280,45 +334,17 @@ func TestGetInstance(t *testing.T) {
require.NoError(t, err)
return &inst
},
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
}
}(),
func() test {
instanceName := gofakeit.Name()
return test{
name: "happy path get using name",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)},
}
}(),
{
name: "get non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
_ = domain.Instance{
ID: instanceId,
inst := domain.Instance{
ID: "get non existent instance",
}
return nil
return &inst
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name")},
err: repository.ErrResourceDoesNotExist,
},
}
for _, tt := range tests {
@@ -333,22 +359,25 @@ func TestGetInstance(t *testing.T) {
// check instance values
returnedInstance, err := instanceRepo.Get(ctx,
tt.conditionClauses...,
instance.ID,
)
require.NoError(t, err)
if instance == nil {
require.Nil(t, instance, returnedInstance)
if tt.err != nil {
require.Equal(t, tt.err, err)
return
}
require.NoError(t, err)
require.Equal(t, returnedInstance.ID, instance.ID)
require.Equal(t, returnedInstance.Name, instance.Name)
require.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID)
require.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage)
if instance.ID == "get non existent instance" {
assert.Nil(t, returnedInstance)
return
}
assert.Equal(t, returnedInstance.ID, instance.ID)
assert.Equal(t, returnedInstance.Name, instance.Name)
assert.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID)
assert.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID)
assert.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID)
assert.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID)
assert.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage)
})
}
}
@@ -504,19 +533,19 @@ func TestListInstance(t *testing.T) {
)
require.NoError(t, err)
if tt.noInstanceReturned {
require.Nil(t, returnedInstances)
assert.Nil(t, returnedInstances)
return
}
require.Equal(t, len(instances), len(returnedInstances))
assert.Equal(t, len(instances), len(returnedInstances))
for i, instance := range instances {
require.Equal(t, returnedInstances[i].ID, instance.ID)
require.Equal(t, returnedInstances[i].Name, instance.Name)
require.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID)
require.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage)
assert.Equal(t, returnedInstances[i].ID, instance.ID)
assert.Equal(t, returnedInstances[i].Name, instance.Name)
assert.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID)
assert.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID)
assert.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID)
assert.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID)
assert.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage)
}
})
}
@@ -524,18 +553,19 @@ func TestListInstance(t *testing.T) {
func TestDeleteInstance(t *testing.T) {
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T)
conditionClauses database.Condition
name string
testFunc func(ctx context.Context, t *testing.T)
instanceID string
noOfDeletedRows int64
}
tests := []test{
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
var noOfInstances int64 = 1
return test{
name: "happy path delete single instance filter id",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
@@ -556,75 +586,15 @@ func TestDeleteInstance(t *testing.T) {
instances[i] = &inst
}
},
conditionClauses: instanceRepo.IDCondition(instanceId),
instanceID: instanceId,
noOfDeletedRows: noOfInstances,
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "happy path delete single instance filter name",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
non_existent_instance_name := gofakeit.Name()
return test{
name: "delete non existent instance",
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, non_existent_instance_name),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "multiple instance filter on name",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
name: "delete non existent instance",
instanceID: non_existent_instance_name,
}
}(),
func() test {
@@ -655,12 +625,15 @@ func TestDeleteInstance(t *testing.T) {
}
// delete instance
err := instanceRepo.Delete(ctx,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
affectedRows, err := instanceRepo.Delete(ctx,
instances[0].ID,
)
require.NoError(t, err)
assert.Equal(t, int64(1), affectedRows)
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
instanceID: instanceName,
// this test should return 0 affected rows as the instance was already deleted
noOfDeletedRows: 0,
}
}(),
}
@@ -674,17 +647,18 @@ func TestDeleteInstance(t *testing.T) {
}
// delete instance
err := instanceRepo.Delete(ctx,
tt.conditionClauses,
noOfDeletedRows, err := instanceRepo.Delete(ctx,
tt.instanceID,
)
require.NoError(t, err)
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
// check instance was deleted
instance, err := instanceRepo.Get(ctx,
tt.conditionClauses,
tt.instanceID,
)
require.NoError(t, err)
require.Nil(t, instance)
require.Equal(t, err, repository.ErrResourceDoesNotExist)
assert.Nil(t, instance)
})
}
}

View File

@@ -2,8 +2,11 @@ package repository
import (
"context"
"errors"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
@@ -12,11 +15,13 @@ import (
// repository
// -------------------------------------------------------------
var _ domain.OrganizationRepository = (*org)(nil)
type org struct {
repository
}
func OrgRepository(client database.QueryExecutor) domain.OrgRepository {
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
return &org{
repository: repository{
client: client,
@@ -24,52 +29,140 @@ func OrgRepository(client database.QueryExecutor) domain.OrgRepository {
}
}
// Create implements [domain.OrgRepository].
func (o *org) Create(ctx context.Context, org *domain.Org) error {
org.CreatedAt = time.Now()
org.UpdatedAt = org.CreatedAt
const queryOrganizationStmt = `SELECT id, name, instance_id, state, created_at, updated_at, deleted_at` +
` FROM zitadel.organizations`
// Get implements [domain.OrganizationRepository].
func (o *org) Get(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, conditions ...database.Condition) (*domain.Organization, error) {
builder := database.StatementBuilder{}
builder.WriteString(queryOrganizationStmt)
instanceIDCondition := o.InstanceIDCondition(instanceID)
// don't update deleted organizations
nonDeletedOrgs := database.IsNull(o.DeletedAtColumn())
conditions = append(conditions, id, instanceIDCondition, nonDeletedOrgs)
writeCondition(&builder, database.And(conditions...))
return scanOrganization(ctx, o.client, &builder)
}
// List implements [domain.OrganizationRepository].
func (o *org) List(ctx context.Context, opts ...database.Condition) ([]*domain.Organization, error) {
builder := database.StatementBuilder{}
builder.WriteString(queryOrganizationStmt)
// return only non deleted organizations
opts = append(opts, database.IsNull(o.DeletedAtColumn()))
writeCondition(&builder, database.And(opts...))
orderBy := database.OrderBy(o.CreatedAtColumn())
orderBy.Write(&builder)
return scanOrganizations(ctx, o.client, &builder)
}
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
` VALUES ($1, $2, $3, $4)` +
` RETURNING created_at, updated_at`
// Create implements [domain.OrganizationRepository].
func (o *org) Create(ctx context.Context, organization *domain.Organization) error {
builder := database.StatementBuilder{}
builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State)
builder.WriteString(createOrganizationStmt)
err := o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
if err != nil {
return checkCreateOrgErr(err)
}
return nil
}
// Delete implements [domain.OrgRepository].
func (o *org) Delete(ctx context.Context, condition database.Condition) error {
return nil
func checkCreateOrgErr(err error) error {
var pgErr *pgconn.PgError
if !errors.As(err, &pgErr) {
return err
}
// constraint violation
if pgErr.Code == "23514" {
if pgErr.ConstraintName == "organizations_name_check" {
return errors.New("organization name not provided")
}
if pgErr.ConstraintName == "organizations_id_check" {
return errors.New("organization id not provided")
}
}
// duplicate
if pgErr.Code == "23505" {
if pgErr.ConstraintName == "organizations_pkey" {
return errors.New("organization id already exists")
}
if pgErr.ConstraintName == "org_unique_instance_id_name_idx" {
return errors.New("organization name already exists for instance")
}
}
// invalid instance id
if pgErr.Code == "23503" {
if pgErr.ConstraintName == "organizations_instance_id_fkey" {
return errors.New("invalid instance id")
}
}
return err
}
// Get implements [domain.OrgRepository].
func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Org, error) {
panic("unimplemented")
// Update implements [domain.OrganizationRepository].
func (o org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
if changes == nil {
return 0, errors.New("Update must contain a condition") // (otherwise ALL organizations will be updated)
}
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE zitadel.organizations SET `)
instanceIDCondition := o.InstanceIDCondition(instanceID)
// don't update deleted organizations
nonDeletedOrgs := database.IsNull(o.DeletedAtColumn())
conditions := []database.Condition{id, instanceIDCondition, nonDeletedOrgs}
database.Changes(changes).Write(&builder)
writeCondition(&builder, database.And(conditions...))
stmt := builder.String()
rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...)
return rowsAffected, err
}
// List implements [domain.OrgRepository].
func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Org, error) {
panic("unimplemented")
}
// Delete implements [domain.OrganizationRepository].
func (o org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) {
builder := database.StatementBuilder{}
// Update implements [domain.OrgRepository].
func (o *org) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
panic("unimplemented")
}
builder.WriteString(`UPDATE zitadel.organizations SET deleted_at = $1`)
builder.AppendArgs(time.Now())
func (o *org) Member() domain.MemberRepository {
return &orgMember{o}
}
instanceIDCondition := o.InstanceIDCondition(instanceID)
// don't update deleted organizations
nonDeletedOrgs := database.IsNull(o.DeletedAtColumn())
func (o *org) Domain() domain.DomainRepository {
return &orgDomain{o}
conditions := []database.Condition{id, instanceIDCondition, nonDeletedOrgs}
writeCondition(&builder, database.And(conditions...))
return o.client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetName implements [domain.orgChanges].
func (o *org) SetName(name string) database.Change {
// SetName implements [domain.organizationChanges].
func (o org) SetName(name string) database.Change {
return database.NewChange(o.NameColumn(), name)
}
// SetState implements [domain.orgChanges].
func (o *org) SetState(state domain.OrgState) database.Change {
// SetState implements [domain.organizationChanges].
func (o org) SetState(state domain.OrgState) database.Change {
return database.NewChange(o.StateColumn(), state)
}
@@ -77,63 +170,97 @@ func (o *org) SetState(state domain.OrgState) database.Change {
// conditions
// -------------------------------------------------------------
// IDCondition implements [domain.orgConditions].
func (o *org) IDCondition(orgID string) database.Condition {
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, orgID)
// IDCondition implements [domain.organizationConditions].
func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
}
// InstanceIDCondition implements [domain.orgConditions].
func (o *org) InstanceIDCondition(instanceID string) database.Condition {
// NameCondition implements [domain.organizationConditions].
func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
}
// InstanceIDCondition implements [domain.organizationConditions].
func (o org) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID)
}
// NameCondition implements [domain.orgConditions].
func (o *org) NameCondition(op database.TextOperation, name string) database.Condition {
return database.NewTextCondition(o.NameColumn(), op, name)
}
// StateCondition implements [domain.orgConditions].
func (o *org) StateCondition(op database.NumberOperation, state domain.OrgState) database.Condition {
return database.NewNumberCondition(o.StateColumn(), op, state)
// StateCondition implements [domain.organizationConditions].
func (o org) StateCondition(state domain.OrgState) database.Condition {
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String())
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// CreatedAtColumn implements [domain.orgColumns].
func (o *org) CreatedAtColumn() database.Column {
return database.NewColumn("created_at")
}
// DeletedAtColumn implements [domain.orgColumns].
func (o *org) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at")
}
// IDColumn implements [domain.orgColumns].
func (o *org) IDColumn() database.Column {
// IDColumn implements [domain.organizationColumns].
func (org) IDColumn() database.Column {
return database.NewColumn("id")
}
// InstanceIDColumn implements [domain.orgColumns].
func (o *org) InstanceIDColumn() database.Column {
return database.NewColumn("instance_id")
}
// NameColumn implements [domain.orgColumns].
func (o *org) NameColumn() database.Column {
// NameColumn implements [domain.organizationColumns].
func (org) NameColumn() database.Column {
return database.NewColumn("name")
}
// StateColumn implements [domain.orgColumns].
func (o *org) StateColumn() database.Column {
// InstanceIDColumn implements [domain.organizationColumns].
func (org) InstanceIDColumn() database.Column {
return database.NewColumn("instance_id")
}
// StateColumn implements [domain.organizationColumns].
func (org) StateColumn() database.Column {
return database.NewColumn("state")
}
// UpdatedAtColumn implements [domain.orgColumns].
func (o *org) UpdatedAtColumn() database.Column {
// CreatedAtColumn implements [domain.organizationColumns].
func (org) CreatedAtColumn() database.Column {
return database.NewColumn("created_at")
}
// UpdatedAtColumn implements [domain.organizationColumns].
func (org) UpdatedAtColumn() database.Column {
return database.NewColumn("updated_at")
}
var _ domain.OrgRepository = (*org)(nil)
// DeletedAtColumn implements [domain.organizationColumns].
func (org) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at")
}
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
organization := &domain.Organization{}
if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil {
if err.Error() == "no rows in result set" {
return nil, ErrResourceDoesNotExist
}
return nil, err
}
return organization, nil
}
func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
organizations := []*domain.Organization{}
if err := rows.(database.CollectableRows).Collect(&organizations); err != nil {
// if no results returned, this is not a error
// it just means the organization was not found
// the caller should check if the returned organization is nil
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
return organizations, nil
}

View File

@@ -1,10 +1,942 @@
package repository_test
// iraq: I had to comment out so that the UTs would pass
// TestBla is an example and can be removed later
// func TestBla(t *testing.T) {
// var count int
// err := pool.QueryRow(context.Background(), "select count(*) from zitadel.instances").Scan(&count)
// assert.NoError(t, err)
// assert.Equal(t, 0, count)
// }
import (
"context"
"errors"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
)
func TestCreateOrganization(t *testing.T) {
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
assert.Nil(t, err)
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
organization domain.Organization
err error
}{
{
name: "happy path",
organization: func() domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
organization := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
return organization
}(),
},
{
name: "create organization without name",
organization: func() domain.Organization {
organizationId := gofakeit.Name()
// organizationName := gofakeit.Name()
organization := domain.Organization{
ID: organizationId,
Name: "",
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
return organization
}(),
err: errors.New("organization name not provided"),
},
{
name: "adding org with same id twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationRepo := repository.OrganizationRepository(pool)
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
org := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
// change the name to make sure same only the id clashes
org.Name = gofakeit.Name()
return &org
},
err: errors.New("organization id already exists"),
},
{
name: "adding org with same name twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationRepo := repository.OrganizationRepository(pool)
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
org := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
// change the id to make sure same name+instance causes an error
org.ID = gofakeit.Name()
return &org
},
err: errors.New("organization name already exists for instance"),
},
func() struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
organization domain.Organization
err error
} {
orgID := gofakeit.Name()
organizationName := gofakeit.Name()
return struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
organization domain.Organization
err error
}{
name: "adding org with same name, different instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
// create instance
instId := gofakeit.Name()
instance := domain.Instance{
ID: instId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(ctx, &instance)
assert.Nil(t, err)
organizationRepo := repository.OrganizationRepository(pool)
org := domain.Organization{
ID: gofakeit.Name(),
Name: organizationName,
InstanceID: instId,
State: domain.OrgStateActive.String(),
}
err = organizationRepo.Create(ctx, &org)
require.NoError(t, err)
// change the id to make it unique
org.ID = orgID
// change the instanceID to a different instance
org.InstanceID = instanceId
return &org
},
organization: domain.Organization{
ID: orgID,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
},
}
}(),
{
name: "adding organization with no id",
organization: func() domain.Organization {
// organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
organization := domain.Organization{
// ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
return organization
}(),
err: errors.New("organization id not provided"),
},
{
name: "adding organization with no instance id",
organization: func() domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
organization := domain.Organization{
ID: organizationId,
Name: organizationName,
State: domain.OrgStateActive.String(),
}
return organization
}(),
err: errors.New("invalid instance id"),
},
{
name: "adding organization with non existent instance id",
organization: func() domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
organization := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: gofakeit.Name(),
State: domain.OrgStateActive.String(),
}
return organization
}(),
err: errors.New("invalid instance id"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
var organization *domain.Organization
if tt.testFunc != nil {
organization = tt.testFunc(ctx, t)
} else {
organization = &tt.organization
}
organizationRepo := repository.OrganizationRepository(pool)
// create organization
beforeCreate := time.Now()
err = organizationRepo.Create(ctx, organization)
assert.Equal(t, tt.err, err)
if err != nil {
return
}
afterCreate := time.Now()
// check organization values
organization, err = organizationRepo.Get(ctx,
organizationRepo.IDCondition(organization.ID),
organization.InstanceID,
)
require.NoError(t, err)
assert.Equal(t, tt.organization.ID, organization.ID)
assert.Equal(t, tt.organization.Name, organization.Name)
assert.Equal(t, tt.organization.InstanceID, organization.InstanceID)
assert.Equal(t, tt.organization.State, organization.State)
assert.WithinRange(t, organization.CreatedAt, beforeCreate, afterCreate)
assert.WithinRange(t, organization.UpdatedAt, beforeCreate, afterCreate)
assert.Nil(t, organization.DeletedAt)
})
}
}
func TestUpdateOrganization(t *testing.T) {
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
assert.Nil(t, err)
organizationRepo := repository.OrganizationRepository(pool)
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
update []database.Change
rowsAffected int64
}{
{
name: "happy path update name",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
org := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
// update with updated value
org.Name = "new_name"
return &org
},
update: []database.Change{organizationRepo.SetName("new_name")},
rowsAffected: 1,
},
{
name: "update deleted organization",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
org := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
// delete instance
_, err = organizationRepo.Delete(ctx,
organizationRepo.IDCondition(org.ID),
org.InstanceID,
)
require.NoError(t, err)
return &org
},
update: []database.Change{organizationRepo.SetName("new_name")},
rowsAffected: 0,
},
{
name: "happy path change state",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
org := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
// update with updated value
org.State = domain.OrgStateInactive.String()
return &org
},
update: []database.Change{organizationRepo.SetState(domain.OrgStateInactive)},
rowsAffected: 1,
},
{
name: "update non existent organization",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
org := domain.Organization{
ID: organizationId,
}
return &org
},
update: []database.Change{organizationRepo.SetName("new_name")},
rowsAffected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
organizationRepo := repository.OrganizationRepository(pool)
createdOrg := tt.testFunc(ctx, t)
// update org
beforeUpdate := time.Now()
rowsAffected, err := organizationRepo.Update(ctx,
organizationRepo.IDCondition(createdOrg.ID),
createdOrg.InstanceID,
tt.update...,
)
afterUpdate := time.Now()
require.NoError(t, err)
assert.Equal(t, tt.rowsAffected, rowsAffected)
if rowsAffected == 0 {
return
}
// check organization values
organization, err := organizationRepo.Get(ctx,
organizationRepo.IDCondition(createdOrg.ID),
createdOrg.InstanceID,
)
require.NoError(t, err)
assert.Equal(t, createdOrg.ID, organization.ID)
assert.Equal(t, createdOrg.Name, organization.Name)
assert.Equal(t, createdOrg.State, organization.State)
assert.WithinRange(t, organization.UpdatedAt, beforeUpdate, afterUpdate)
assert.Nil(t, organization.DeletedAt)
})
}
}
func TestGetOrganization(t *testing.T) {
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
assert.Nil(t, err)
orgRepo := repository.OrganizationRepository(pool)
// create organization
// this org is created as an additional org which should NOT
// be returned in the results of the tests
org := domain.Organization{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
err = orgRepo.Create(t.Context(), &org)
require.NoError(t, err)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
orgIdentifierCondition domain.OrgIdentifierCondition
err error
}
tests := []test{
func() test {
organizationId := gofakeit.Name()
return test{
name: "happy path get using id",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationName := gofakeit.Name()
org := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := orgRepo.Create(ctx, &org)
require.NoError(t, err)
return &org
},
orgIdentifierCondition: orgRepo.IDCondition(organizationId),
}
}(),
func() test {
organizationName := gofakeit.Name()
return test{
name: "happy path get using name",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
org := domain.Organization{
ID: organizationId,
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := orgRepo.Create(ctx, &org)
require.NoError(t, err)
return &org
},
orgIdentifierCondition: orgRepo.NameCondition(organizationName),
}
}(),
{
name: "get non existent organization",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
org := domain.Organization{
ID: "non existent org",
Name: "non existent org",
}
return &org
},
orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"),
err: repository.ErrResourceDoesNotExist,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
orgRepo := repository.OrganizationRepository(pool)
var org *domain.Organization
if tt.testFunc != nil {
org = tt.testFunc(ctx, t)
}
// get org values
returnedOrg, err := orgRepo.Get(ctx,
tt.orgIdentifierCondition,
org.InstanceID,
)
if tt.err != nil {
require.Equal(t, tt.err, err)
return
}
if org.Name == "non existent org" {
assert.Nil(t, returnedOrg)
return
}
assert.Equal(t, returnedOrg.ID, org.ID)
assert.Equal(t, returnedOrg.Name, org.Name)
assert.Equal(t, returnedOrg.InstanceID, org.InstanceID)
assert.Equal(t, returnedOrg.State, org.State)
})
}
}
func TestListOrganization(t *testing.T) {
ctx := t.Context()
pool, stop, err := newEmbeddedDB(ctx)
require.NoError(t, err)
defer stop()
organizationRepo := repository.OrganizationRepository(pool)
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err = instanceRepo.Create(ctx, &instance)
assert.Nil(t, err)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) []*domain.Organization
conditionClauses []database.Condition
noOrganizationReturned bool
}
tests := []test{
{
name: "happy path single organization no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
return organizations
},
},
{
name: "happy path multiple organization no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
noOfOrganizations := 5
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
return organizations
},
},
func() test {
organizationId := gofakeit.Name()
return test{
name: "organization filter on id",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
// create organization
// this org is created as an additional org which should NOT
// be returned in the results of this test case
org := domain.Organization{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
err = organizationRepo.Create(ctx, &org)
require.NoError(t, err)
noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: organizationId,
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
return organizations
},
conditionClauses: []database.Condition{organizationRepo.IDCondition(organizationId)},
}
}(),
{
name: "multiple organization filter on state",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
// create organization
// this org is created as an additional org which should NOT
// be returned in the results of this test case
org := domain.Organization{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
err = organizationRepo.Create(ctx, &org)
require.NoError(t, err)
noOfOrganizations := 5
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateInactive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
return organizations
},
conditionClauses: []database.Condition{organizationRepo.StateCondition(domain.OrgStateInactive)},
},
func() test {
instanceId_2 := gofakeit.Name()
return test{
name: "multiple organization filter on instance",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
// create instance 1
instanceId_1 := gofakeit.Name()
instance := domain.Instance{
ID: instanceId_1,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err = instanceRepo.Create(ctx, &instance)
assert.Nil(t, err)
// create organization
// this org is created as an additional org which should NOT
// be returned in the results of this test case
org := domain.Organization{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
InstanceID: instanceId_1,
State: domain.OrgStateActive.String(),
}
err = organizationRepo.Create(ctx, &org)
require.NoError(t, err)
// create instance 2
instance_2 := domain.Instance{
ID: instanceId_2,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
err = instanceRepo.Create(ctx, &instance_2)
assert.Nil(t, err)
noOfOrganizations := 5
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
InstanceID: instanceId_2,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
return organizations
},
conditionClauses: []database.Condition{organizationRepo.InstanceIDCondition(instanceId_2)},
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Cleanup(func() {
_, err := pool.Exec(ctx, "DELETE FROM zitadel.organizations")
require.NoError(t, err)
})
organizations := tt.testFunc(ctx, t)
// check organization values
returnedOrgs, err := organizationRepo.List(ctx,
tt.conditionClauses...,
)
require.NoError(t, err)
if tt.noOrganizationReturned {
assert.Nil(t, returnedOrgs)
return
}
assert.Equal(t, len(organizations), len(returnedOrgs))
for i, org := range organizations {
assert.Equal(t, returnedOrgs[i].ID, org.ID)
assert.Equal(t, returnedOrgs[i].Name, org.Name)
assert.Equal(t, returnedOrgs[i].InstanceID, org.InstanceID)
assert.Equal(t, returnedOrgs[i].State, org.State)
}
})
}
}
func TestDeleteOrganization(t *testing.T) {
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
assert.Nil(t, err)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T)
orgIdentifierCondition domain.OrgIdentifierCondition
noOfDeletedRows int64
}
tests := []test{
func() test {
organizationRepo := repository.OrganizationRepository(pool)
organizationId := gofakeit.Name()
var noOfOrganizations int64 = 1
return test{
name: "happy path delete organization filter id",
testFunc: func(ctx context.Context, t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: organizationId,
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
},
orgIdentifierCondition: organizationRepo.IDCondition(organizationId),
noOfDeletedRows: noOfOrganizations,
}
}(),
func() test {
organizationRepo := repository.OrganizationRepository(pool)
organizationName := gofakeit.Name()
var noOfOrganizations int64 = 1
return test{
name: "happy path delete organization filter name",
testFunc: func(ctx context.Context, t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
},
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
noOfDeletedRows: noOfOrganizations,
}
}(),
func() test {
organizationRepo := repository.OrganizationRepository(pool)
non_existent_organization_name := gofakeit.Name()
return test{
name: "delete non existent organization",
orgIdentifierCondition: organizationRepo.NameCondition(non_existent_organization_name),
}
}(),
func() test {
organizationRepo := repository.OrganizationRepository(pool)
organizationName := gofakeit.Name()
return test{
name: "deleted already deleted organization",
testFunc: func(ctx context.Context, t *testing.T) {
noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
Name: organizationName,
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
// create organization
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
// delete organization
affectedRows, err := organizationRepo.Delete(ctx,
organizationRepo.NameCondition(organizationName),
organizations[0].InstanceID,
)
assert.Equal(t, int64(1), affectedRows)
require.NoError(t, err)
},
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
// this test should return 0 affected rows as the org was already deleted
noOfDeletedRows: 0,
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
organizationRepo := repository.OrganizationRepository(pool)
if tt.testFunc != nil {
tt.testFunc(ctx, t)
}
// delete organization
noOfDeletedRows, err := organizationRepo.Delete(ctx,
tt.orgIdentifierCondition,
instanceId,
)
require.NoError(t, err)
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
// check organization was deleted
organization, err := organizationRepo.Get(ctx,
tt.orgIdentifierCondition,
instanceId,
)
require.Equal(t, err, repository.ErrResourceDoesNotExist)
assert.Nil(t, organization)
})
}
}

View File

@@ -1,7 +1,24 @@
package repository
import "github.com/zitadel/zitadel/backend/v3/storage/database"
import (
"errors"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var ErrResourceDoesNotExist = errors.New("resource does not exist")
type repository struct {
client database.QueryExecutor
}
func writeCondition(
builder *database.StatementBuilder,
condition database.Condition,
) {
if condition == nil {
return
}
builder.WriteString(" WHERE ")
condition.Write(builder)
}

View File

@@ -123,7 +123,7 @@ func (u *user) Create(ctx context.Context, user *domain.User) error {
func (u *user) Delete(ctx context.Context, condition database.Condition) error {
builder := database.StatementBuilder{}
builder.WriteString("DELETE FROM users")
u.writeCondition(builder, condition)
writeCondition(&builder, condition)
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
@@ -223,17 +223,6 @@ func (user) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at")
}
func (u *user) writeCondition(
builder database.StatementBuilder,
condition database.Condition,
) {
if condition == nil {
return
}
builder.WriteString(" WHERE ")
condition.Write(&builder)
}
func (u user) columns() database.Columns {
return database.Columns{
u.InstanceIDColumn(),

View File

@@ -26,7 +26,7 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition)
builder := database.StatementBuilder{}
builder.WriteString(userEmailQuery)
u.writeCondition(builder, condition)
writeCondition(&builder, condition)
err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
&email.Address,
@@ -43,7 +43,7 @@ func (h userHuman) Update(ctx context.Context, condition database.Condition, cha
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE human_users SET `)
database.Changes(changes).Write(&builder)
h.writeCondition(builder, condition)
writeCondition(&builder, condition)
stmt := builder.String()

View File

@@ -22,7 +22,7 @@ func (m userMachine) Update(ctx context.Context, condition database.Condition, c
builder := database.StatementBuilder{}
builder.WriteString("UPDATE user_machines SET ")
database.Changes(changes).Write(&builder)
m.writeCondition(builder, condition)
writeCondition(&builder, condition)
m.writeReturning()
_, err := m.client.Exec(ctx, builder.String(), builder.Args()...)