diff --git a/backend/v3/domain/instance.go b/backend/v3/domain/instance.go index b6dd5f21df..2efe6734b3 100644 --- a/backend/v3/domain/instance.go +++ b/backend/v3/domain/instance.go @@ -86,12 +86,12 @@ type InstanceRepository interface { // Member returns the member repository which is a sub repository of the instance repository. // Member() MemberRepository - Get(ctx context.Context, opts ...database.Condition) (*Instance, error) + Get(ctx context.Context, id string) (*Instance, error) List(ctx context.Context, opts ...database.Condition) ([]*Instance, error) Create(ctx context.Context, instance *Instance) error - Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) - Delete(ctx context.Context, condition database.Condition) error + Update(ctx context.Context, id string, changes ...database.Change) (int64, error) + Delete(ctx context.Context, id string) (int64, error) } type CreateInstance struct { diff --git a/backend/v3/domain/org.go b/backend/v3/domain/org.go deleted file mode 100644 index 0f80e58a73..0000000000 --- a/backend/v3/domain/org.go +++ /dev/null @@ -1,120 +0,0 @@ -package domain - -import ( - "context" - "time" - - "github.com/zitadel/zitadel/backend/v3/storage/cache" - "github.com/zitadel/zitadel/backend/v3/storage/database" -) - -type OrgState uint8 - -const ( - OrgStateActive OrgState = iota + 1 - OrgStateInactive -) - -// Org is used by all other packages to represent an organization. -type Org struct { - ID string `json:"id"` - Name string `json:"name"` - - State OrgState `json:"state"` - - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` -} - -type orgCacheIndex uint8 - -const ( - orgCacheIndexUndefined orgCacheIndex = iota - orgCacheIndexID -) - -// Keys implements [cache.Entry]. -func (o *Org) Keys(index orgCacheIndex) (key []string) { - if index == orgCacheIndexID { - return []string{o.ID} - } - return nil -} - -var _ cache.Entry[orgCacheIndex, string] = (*Org)(nil) - -// orgColumns define all the columns of the org table. -type orgColumns interface { - // InstanceIDColumn returns the column for the instance id field. - InstanceIDColumn() database.Column - // IDColumn returns the column for the id field. - IDColumn() database.Column - // NameColumn returns the column for the name field. - NameColumn() database.Column - // StateColumn returns the column for the state field. - StateColumn() database.Column - // CreatedAtColumn returns the column for the created at field. - CreatedAtColumn() database.Column - // UpdatedAtColumn returns the column for the updated at field. - UpdatedAtColumn() database.Column - // DeletedAtColumn returns the column for the deleted at field. - DeletedAtColumn() database.Column -} - -// orgConditions define all the conditions for the org table. -type orgConditions interface { - // InstanceIDCondition returns an equal filter on the instance id field. - InstanceIDCondition(instanceID string) database.Condition - // IDCondition returns an equal filter on the id field. - IDCondition(orgID string) database.Condition - // NameCondition returns a filter on the name field. - NameCondition(op database.TextOperation, name string) database.Condition - // StateCondition returns a filter on the state field. - StateCondition(op database.NumberOperation, state OrgState) database.Condition -} - -// orgChanges define all the changes for the org table. -type orgChanges interface { - // SetName sets the name column. - SetName(name string) database.Change - // SetState sets the state column. - SetState(state OrgState) database.Change -} - -// OrgRepository is the interface for the org repository. -// It is used to interact with the org table in the database. -type OrgRepository interface { - orgColumns - orgConditions - orgChanges - - // Member returns the member repository. - Member() MemberRepository - // Domain returns the domain repository. - Domain() DomainRepository - - // Get returns an org based on the given condition. - Get(ctx context.Context, opts ...database.QueryOption) (*Org, error) - // List returns a list of orgs based on the given condition. - List(ctx context.Context, opts ...database.QueryOption) ([]*Org, error) - // Create creates a new org. - Create(ctx context.Context, org *Org) error - // Delete removes orgs based on the given condition. - Delete(ctx context.Context, condition database.Condition) error - // Update executes the given changes based on the given condition. - Update(ctx context.Context, condition database.Condition, changes ...database.Change) error -} - -// MemberRepository is a sub repository of the org repository and maybe the instance repository. -type MemberRepository interface { - AddMember(ctx context.Context, orgID, userID string, roles []string) error - SetMemberRoles(ctx context.Context, orgID, userID string, roles []string) error - RemoveMember(ctx context.Context, orgID, userID string) error -} - -// DomainRepository is a sub repository of the org repository and maybe the instance repository. -type DomainRepository interface { - AddDomain(ctx context.Context, domain string) error - SetDomainVerified(ctx context.Context, domain string) error - RemoveDomain(ctx context.Context, domain string) error -} diff --git a/backend/v3/domain/organization.go b/backend/v3/domain/organization.go new file mode 100644 index 0000000000..94dd80d72f --- /dev/null +++ b/backend/v3/domain/organization.go @@ -0,0 +1,103 @@ +package domain + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +//go:generate enumer -type OrgState -transform lower -trimprefix OrgState +type OrgState uint8 + +const ( + OrgStateActive OrgState = iota + OrgStateInactive +) + +type Organization struct { + ID string `json:"id,omitempty" db:"id"` + Name string `json:"name,omitempty" db:"name"` + InstanceID string `json:"instanceId,omitempty" db:"instance_id"` + State string `json:"state,omitempty" db:"state"` + CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"` + DeletedAt *time.Time `json:"deletedAt,omitempty" db:"deleted_at"` +} + +// OrgIdentifierCondition is used to help specify a single Organization, +// it will either be used as the organization ID or organization name, +// as organizations can be identified either using (instnaceID + ID) OR (instanceID + name) +type OrgIdentifierCondition interface { + database.Condition +} + +// organizationColumns define all the columns of the instance table. +type organizationColumns interface { + // IDColumn returns the column for the id field. + IDColumn() database.Column + // NameColumn returns the column for the name field. + NameColumn() database.Column + // InstanceIDColumn returns the column for the default org id field + InstanceIDColumn() database.Column + // StateColumn returns the column for the name field. + StateColumn() database.Column + // CreatedAtColumn returns the column for the created at field. + CreatedAtColumn() database.Column + // UpdatedAtColumn returns the column for the updated at field. + UpdatedAtColumn() database.Column + // DeletedAtColumn returns the column for the deleted at field. + DeletedAtColumn() database.Column +} + +// organizationConditions define all the conditions for the instance table. +type organizationConditions interface { + // IDCondition returns an equal filter on the id field. + IDCondition(instanceID string) OrgIdentifierCondition + // NameCondition returns a filter on the name field. + NameCondition(name string) OrgIdentifierCondition + // InstanceIDCondition returns a filter on the instance id field. + InstanceIDCondition(instanceID string) database.Condition + // StateCondition returns a filter on the name field. + StateCondition(state OrgState) database.Condition +} + +// organizationChanges define all the changes for the instance table. +type organizationChanges interface { + // SetName sets the name column. + SetName(name string) database.Change + // SetState sets the name column. + SetState(state OrgState) database.Change +} + +// OrganizationRepository is the interface for the instance repository. +type OrganizationRepository interface { + organizationColumns + organizationConditions + organizationChanges + + Get(ctx context.Context, id OrgIdentifierCondition, instance_id string, opts ...database.Condition) (*Organization, error) + List(ctx context.Context, opts ...database.Condition) ([]*Organization, error) + + Create(ctx context.Context, instance *Organization) error + Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error) + Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error) +} + +type CreateOrganization struct { + Name string `json:"name"` +} + +// MemberRepository is a sub repository of the org repository and maybe the instance repository. +type MemberRepository interface { + AddMember(ctx context.Context, orgID, userID string, roles []string) error + SetMemberRoles(ctx context.Context, orgID, userID string, roles []string) error + RemoveMember(ctx context.Context, orgID, userID string) error +} + +// DomainRepository is a sub repository of the org repository and maybe the instance repository. +type DomainRepository interface { + AddDomain(ctx context.Context, domain string) error + SetDomainVerified(ctx context.Context, domain string) error + RemoveDomain(ctx context.Context, domain string) error +} diff --git a/backend/v3/domain/orgstate_enumer.go b/backend/v3/domain/orgstate_enumer.go new file mode 100644 index 0000000000..a5a1d0ca57 --- /dev/null +++ b/backend/v3/domain/orgstate_enumer.go @@ -0,0 +1,78 @@ +// Code generated by "enumer -type OrgState -transform lower -trimprefix OrgState"; DO NOT EDIT. + +package domain + +import ( + "fmt" + "strings" +) + +const _OrgStateName = "activeinactive" + +var _OrgStateIndex = [...]uint8{0, 6, 14} + +const _OrgStateLowerName = "activeinactive" + +func (i OrgState) String() string { + if i < 0 || i >= OrgState(len(_OrgStateIndex)-1) { + return fmt.Sprintf("OrgState(%d)", i) + } + return _OrgStateName[_OrgStateIndex[i]:_OrgStateIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _OrgStateNoOp() { + var x [1]struct{} + _ = x[OrgStateActive-(0)] + _ = x[OrgStateInactive-(1)] +} + +var _OrgStateValues = []OrgState{OrgStateActive, OrgStateInactive} + +var _OrgStateNameToValueMap = map[string]OrgState{ + _OrgStateName[0:6]: OrgStateActive, + _OrgStateLowerName[0:6]: OrgStateActive, + _OrgStateName[6:14]: OrgStateInactive, + _OrgStateLowerName[6:14]: OrgStateInactive, +} + +var _OrgStateNames = []string{ + _OrgStateName[0:6], + _OrgStateName[6:14], +} + +// OrgStateString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func OrgStateString(s string) (OrgState, error) { + if val, ok := _OrgStateNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _OrgStateNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to OrgState values", s) +} + +// OrgStateValues returns all values of the enum +func OrgStateValues() []OrgState { + return _OrgStateValues +} + +// OrgStateStrings returns a slice of all String values of the enum +func OrgStateStrings() []string { + strs := make([]string, len(_OrgStateNames)) + copy(strs, _OrgStateNames) + return strs +} + +// IsAOrgState returns "true" if the value is listed in the enum definition. "false" otherwise +func (i OrgState) IsAOrgState() bool { + for _, v := range _OrgStateValues { + if i == v { + return true + } + } + return false +} diff --git a/backend/v3/storage/database/database.go b/backend/v3/storage/database/database.go index 7dd8d490ba..00a852e7a8 100644 --- a/backend/v3/storage/database/database.go +++ b/backend/v3/storage/database/database.go @@ -59,8 +59,25 @@ type Row interface { // Rows is an abstraction of sql.Rows. type Rows interface { - Row + Scanner Next() bool Close() error Err() error } + +type CollectableRows interface { + // Collect collects all rows and scans them into dest. + // dest must be a pointer to a slice of pointer to structs + // e.g. *[]*MyStruct + // Rows are closed after this call. + Collect(dest any) error + // CollectFirst collects the first row and scans it into dest. + // dest must be a pointer to a struct + // e.g. *MyStruct{} + // Rows are closed after this call. + CollectFirst(dest any) error + // CollectExactlyOneRow collects exactly one row and scans it into dest. + // e.g. *MyStruct{} + // Rows are closed after this call. + CollectExactlyOneRow(dest any) error +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/up.sql b/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/up.sql index 111719632c..75ddcfb6cc 100644 --- a/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/up.sql +++ b/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/up.sql @@ -6,8 +6,8 @@ CREATE TABLE IF NOT EXISTS zitadel.instances( console_client_id TEXT, -- NOT NULL, console_app_id TEXT, -- NOT NULL, default_language TEXT, -- NOT NULL, - created_at TIMESTAMPTZ DEFAULT NOW(), - updated_at TIMESTAMPTZ DEFAULT NOW(), + created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, + updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, deleted_at TIMESTAMPTZ DEFAULT NULL ); diff --git a/backend/v3/storage/database/dialect/postgres/migration/002_organization_table.go b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table.go new file mode 100644 index 0000000000..2f0a04eee0 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table.go @@ -0,0 +1,16 @@ +package migration + +import ( + _ "embed" +) + +var ( + //go:embed 002_organization_table/up.sql + up002OrganizationTable string + //go:embed 002_organization_table/down.sql + down002OrganizationTable string +) + +func init() { + registerSQLMigration(2, up002OrganizationTable, down002OrganizationTable) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/down.sql b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/down.sql new file mode 100644 index 0000000000..654858cdac --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/down.sql @@ -0,0 +1,2 @@ +DROP TABLE zitadel.organizations; +DROP Type zitadel.organization_state; diff --git a/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/up.sql b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/up.sql new file mode 100644 index 0000000000..c09757c003 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/up.sql @@ -0,0 +1,33 @@ +CREATE TYPE zitadel.organization_state AS ENUM ( + 'active', + 'inactive' +); + +CREATE TABLE zitadel.organizations( + id TEXT NOT NULL CHECK (id <> ''), + name TEXT NOT NULL CHECK (name <> ''), + instance_id TEXT NOT NULL REFERENCES zitadel.instances (id), + state zitadel.organization_state NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, + updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, + deleted_at TIMESTAMPTZ DEFAULT NULL, + + PRIMARY KEY (instance_id, id) +); + +CREATE UNIQUE INDEX org_unique_instance_id_name_idx + ON zitadel.organizations (instance_id, name) + WHERE deleted_at IS NULL; + +-- users are able to set the id for organizations +CREATE INDEX org_id_not_deleted_idx ON zitadel.organizations (id) + WHERE deleted_at IS NULL; + +CREATE INDEX org_name_not_deleted_idx ON zitadel.organizations (name) + WHERE deleted_at IS NULL; + +CREATE TRIGGER trigger_set_updated_at +BEFORE UPDATE ON zitadel.organizations +FOR EACH ROW +WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) +EXECUTE FUNCTION zitadel.set_updated_at(); diff --git a/backend/v3/storage/database/dialect/postgres/rows.go b/backend/v3/storage/database/dialect/postgres/rows.go index 891a2a3f46..8dafc88f4f 100644 --- a/backend/v3/storage/database/dialect/postgres/rows.go +++ b/backend/v3/storage/database/dialect/postgres/rows.go @@ -1,15 +1,55 @@ package postgres import ( + "github.com/georgysavva/scany/v2/pgxscan" "github.com/jackc/pgx/v5" "github.com/zitadel/zitadel/backend/v3/storage/database" ) -var _ database.Rows = (*Rows)(nil) +var ( + _ database.Rows = (*Rows)(nil) + _ database.CollectableRows = (*Rows)(nil) +) type Rows struct{ pgx.Rows } +// Collect implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) Collect(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return pgxscan.ScanAll(dest, r.Rows) +} + +// CollectFirst implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) CollectFirst(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return pgxscan.ScanRow(dest, r.Rows) +} + +// CollectExactlyOneRow implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) CollectExactlyOneRow(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return pgxscan.ScanOne(dest, r.Rows) +} + // Close implements [database.Rows]. // Subtle: this method shadows the method (Rows).Close of Rows.Rows. func (r *Rows) Close() error { diff --git a/backend/v3/storage/database/events_testing/events_test.go b/backend/v3/storage/database/events_testing/events_test.go new file mode 100644 index 0000000000..e1bc903bea --- /dev/null +++ b/backend/v3/storage/database/events_testing/events_test.go @@ -0,0 +1,67 @@ +//go:build integration + +package events_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres" + "github.com/zitadel/zitadel/internal/integration" + v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/system" +) + +const ConnString = "host=localhost port=5432 user=zitadel dbname=zitadel sslmode=disable" + +var ( + dbPool *pgxpool.Pool + CTX context.Context + Instance *integration.Instance + SystemClient system.SystemServiceClient + OrgClient v2beta_org.OrganizationServiceClient +) + +var pool database.Pool + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) + SystemClient = integration.SystemClient() + OrgClient = Instance.Client.OrgV2beta + + var err error + dbConfig, err := pgxpool.ParseConfig(ConnString) + if err != nil { + panic(err) + } + dbConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + orgState, err := conn.LoadType(ctx, "zitadel.organization_state") + if err != nil { + return err + } + conn.TypeMap().RegisterType(orgState) + return nil + } + + dbPool, err = pgxpool.NewWithConfig(context.Background(), dbConfig) + if err != nil { + panic(err) + } + + pool = postgres.PGxPool(dbPool) + + return m.Run() + }()) +} diff --git a/backend/v3/storage/database/events_testing/instance_test.go b/backend/v3/storage/database/events_testing/instance_test.go index dba35b52ea..dd289b5822 100644 --- a/backend/v3/storage/database/events_testing/instance_test.go +++ b/backend/v3/storage/database/events_testing/instance_test.go @@ -1,171 +1,138 @@ //go:build integration -package instance_test +package events_test import ( - "context" - "os" "testing" "time" "github.com/brianvoe/gofakeit/v6" - "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/zitadel/backend/v3/storage/database" - "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres" "github.com/zitadel/zitadel/backend/v3/storage/database/repository" "github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/pkg/grpc/system" ) -const ConnString = "host=localhost port=5432 user=zitadel dbname=zitadel sslmode=disable" - -var ( - dbPool *pgxpool.Pool - CTX context.Context - Instance *integration.Instance - SystemClient system.SystemServiceClient -) - -var pool database.Pool - -func TestMain(m *testing.M) { - os.Exit(func() int { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - - Instance = integration.NewInstance(ctx) - - CTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) - SystemClient = integration.SystemClient() - - var err error - dbPool, err = pgxpool.New(context.Background(), ConnString) - if err != nil { - panic(err) - } - - pool = postgres.PGxPool(dbPool) - - return m.Run() - }()) -} - -func TestServer_TestInstanceAddReduces(t *testing.T) { - instanceName := gofakeit.Name() - beforeCreate := time.Now() - _, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ - InstanceName: instanceName, - Owner: &system.CreateInstanceRequest_Machine_{ - Machine: &system.CreateInstanceRequest_Machine{ - UserName: "owner", - Name: "owner", - PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, +func TestServer_TestInstanceReduces(t *testing.T) { + t.Run("test instance add reduces", func(t *testing.T) { + instanceName := gofakeit.Name() + beforeCreate := time.Now() + instance, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ + InstanceName: instanceName, + Owner: &system.CreateInstanceRequest_Machine_{ + Machine: &system.CreateInstanceRequest_Machine{ + UserName: "owner", + Name: "owner", + PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, + }, }, - }, + }) + afterCreate := time.Now() + + require.NoError(t, err) + + instanceRepo := repository.InstanceRepository(pool) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + instance.GetInstanceId(), + ) + require.NoError(ttt, err) + // event instance.added + assert.Equal(ttt, instanceName, instance.Name) + // event instance.default.org.set + assert.NotNil(t, instance.DefaultOrgID) + // event instance.iam.project.set + assert.NotNil(t, instance.IAMProjectID) + // event instance.iam.console.set + assert.NotNil(t, instance.ConsoleAppID) + // event instance.default.language.set + assert.NotNil(t, instance.DefaultLanguage) + // event instance.added + assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate) + // event instance.added + assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate) + assert.Nil(t, instance.DeletedAt) + }, retryDuration, tick) }) - afterCreate := time.Now() - require.NoError(t, err) - - instanceRepo := repository.InstanceRepository(pool) - retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) - assert.EventuallyWithT(t, func(ttt *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, - instanceRepo.NameCondition(database.TextOperationEqual, instanceName), - ) - require.NoError(ttt, err) - // event instance.added - require.Equal(ttt, instanceName, instance.Name) - // event instance.default.org.set - require.NotNil(t, instance.DefaultOrgID) - // event instance.iam.project.set - require.NotNil(t, instance.IAMProjectID) - // event instance.iam.console.set - require.NotNil(t, instance.ConsoleAppID) - // event instance.default.language.set - require.NotNil(t, instance.DefaultLanguage) - // event instance.added - assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate) - // event instance.added - assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate) - require.Nil(t, instance.DeletedAt) - }, retryDuration, tick) -} - -func TestServer_TestInstanceUpdateNameReduces(t *testing.T) { - instanceName := gofakeit.Name() - res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ - InstanceName: instanceName, - Owner: &system.CreateInstanceRequest_Machine_{ - Machine: &system.CreateInstanceRequest_Machine{ - UserName: "owner", - Name: "owner", - PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, + t.Run("test instance update reduces", func(t *testing.T) { + instanceName := gofakeit.Name() + res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ + InstanceName: instanceName, + Owner: &system.CreateInstanceRequest_Machine_{ + Machine: &system.CreateInstanceRequest_Machine{ + UserName: "owner", + Name: "owner", + PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, + }, }, - }, + }) + require.NoError(t, err) + + instanceName += "new" + beforeUpdate := time.Now() + _, err = SystemClient.UpdateInstance(CTX, &system.UpdateInstanceRequest{ + InstanceId: res.InstanceId, + InstanceName: instanceName, + }) + require.NoError(t, err) + afterUpdate := time.Now() + + instanceRepo := repository.InstanceRepository(pool) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + res.InstanceId, + ) + require.NoError(ttt, err) + // event instance.changed + assert.Equal(ttt, instanceName, instance.Name) + assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate) + }, retryDuration, tick) }) - require.NoError(t, err) - instanceName += "new" - _, err = SystemClient.UpdateInstance(CTX, &system.UpdateInstanceRequest{ - InstanceId: res.InstanceId, - InstanceName: instanceName, - }) - require.NoError(t, err) - - instanceRepo := repository.InstanceRepository(pool) - retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) - assert.EventuallyWithT(t, func(ttt *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, - instanceRepo.NameCondition(database.TextOperationEqual, instanceName), - ) - require.NoError(ttt, err) - // event instance.changed - require.Equal(ttt, instanceName, instance.Name) - }, retryDuration, tick) -} - -func TestServer_TestInstanceDeleteReduces(t *testing.T) { - instanceName := gofakeit.Name() - res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ - InstanceName: instanceName, - Owner: &system.CreateInstanceRequest_Machine_{ - Machine: &system.CreateInstanceRequest_Machine{ - UserName: "owner", - Name: "owner", - PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, + t.Run("test instance delete reduces", func(t *testing.T) { + instanceName := gofakeit.Name() + res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ + InstanceName: instanceName, + Owner: &system.CreateInstanceRequest_Machine_{ + Machine: &system.CreateInstanceRequest_Machine{ + UserName: "owner", + Name: "owner", + PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, + }, }, - }, + }) + require.NoError(t, err) + + instanceRepo := repository.InstanceRepository(pool) + + // check instance exists + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + res.InstanceId, + ) + require.NoError(ttt, err) + assert.Equal(ttt, instanceName, instance.Name) + }, retryDuration, tick) + + _, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{ + InstanceId: res.InstanceId, + }) + require.NoError(t, err) + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + res.InstanceId, + ) + // event instance.removed + assert.Nil(t, instance) + require.NoError(t, err) + }, retryDuration, tick) }) - require.NoError(t, err) - - instanceRepo := repository.InstanceRepository(pool) - - // check instance exists - retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) - assert.EventuallyWithT(t, func(ttt *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, - instanceRepo.NameCondition(database.TextOperationEqual, instanceName), - ) - require.NoError(ttt, err) - require.Equal(ttt, instanceName, instance.Name) - }, retryDuration, tick) - - _, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{ - InstanceId: res.InstanceId, - }) - require.NoError(t, err) - - retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) - assert.EventuallyWithT(t, func(ttt *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, - instanceRepo.NameCondition(database.TextOperationEqual, instanceName), - ) - // event instance.removed - require.Nil(t, instance) - require.NoError(ttt, err) - }, retryDuration, tick) } diff --git a/backend/v3/storage/database/events_testing/organization_test.go b/backend/v3/storage/database/events_testing/organization_test.go new file mode 100644 index 0000000000..a97c15a8c9 --- /dev/null +++ b/backend/v3/storage/database/events_testing/organization_test.go @@ -0,0 +1,218 @@ +//go:build integration + +package events_test + +import ( + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/internal/integration" + v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta" +) + +func TestServer_TestOrganizationReduces(t *testing.T) { + instanceID := Instance.ID() + + t.Run("test org add reduces", func(t *testing.T) { + beforeCreate := time.Now() + orgName := gofakeit.Name() + + _, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + afterCreate := time.Now() + + orgRepo := repository.OrganizationRepository(pool) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(tt *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + orgRepo.NameCondition(orgName), + instanceID, + ) + require.NoError(tt, err) + + // event org.added + assert.NotNil(t, organization.ID) + assert.Equal(t, orgName, organization.Name) + assert.NotNil(t, organization.InstanceID) + assert.Equal(t, domain.OrgStateActive.String(), organization.State) + assert.WithinRange(t, organization.CreatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, organization.UpdatedAt, beforeCreate, afterCreate) + assert.Nil(t, organization.DeletedAt) + }, retryDuration, tick) + }) + + t.Run("test org change reduces", func(t *testing.T) { + orgName := gofakeit.Name() + + // 1. create org + organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + + // 2. update org name + beforeUpdate := time.Now() + orgName = orgName + "_new" + _, err = OrgClient.UpdateOrganization(CTX, &v2beta_org.UpdateOrganizationRequest{ + Id: organization.Id, + Name: orgName, + }) + require.NoError(t, err) + afterUpdate := time.Now() + + orgRepo := repository.OrganizationRepository(pool) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + orgRepo.NameCondition(orgName), + instanceID, + ) + require.NoError(t, err) + + // event org.changed + assert.Equal(t, orgName, organization.Name) + assert.WithinRange(t, organization.UpdatedAt, beforeUpdate, afterUpdate) + }, retryDuration, tick) + }) + + t.Run("test org deactivate reduces", func(t *testing.T) { + orgName := gofakeit.Name() + + // 1. create org + organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + + // 2. deactivate org name + beforeDeactivate := time.Now() + _, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{ + Id: organization.Id, + }) + + require.NoError(t, err) + afterDeactivate := time.Now() + + orgRepo := repository.OrganizationRepository(pool) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + orgRepo.NameCondition(orgName), + instanceID, + ) + require.NoError(t, err) + + // event org.deactivate + assert.Equal(t, orgName, organization.Name) + assert.Equal(t, domain.OrgStateInactive.String(), organization.State) + assert.WithinRange(t, organization.UpdatedAt, beforeDeactivate, afterDeactivate) + }, retryDuration, tick) + }) + + t.Run("test org activate reduces", func(t *testing.T) { + orgName := gofakeit.Name() + + // 1. create org + organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + + // 2. deactivate org name + _, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{ + Id: organization.Id, + }) + require.NoError(t, err) + + orgRepo := repository.OrganizationRepository(pool) + // 3. check org deactivated + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + orgRepo.NameCondition(orgName), + instanceID, + ) + require.NoError(t, err) + + assert.Equal(t, orgName, organization.Name) + assert.Equal(t, domain.OrgStateInactive.String(), organization.State) + }, retryDuration, tick) + + // 4. activate org name + beforeActivate := time.Now() + _, err = OrgClient.ActivateOrganization(CTX, &v2beta_org.ActivateOrganizationRequest{ + Id: organization.Id, + }) + require.NoError(t, err) + afterActivate := time.Now() + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + orgRepo.NameCondition(orgName), + instanceID, + ) + require.NoError(t, err) + + // event org.reactivate + assert.Equal(t, orgName, organization.Name) + assert.Equal(t, domain.OrgStateActive.String(), organization.State) + assert.WithinRange(t, organization.UpdatedAt, beforeActivate, afterActivate) + }, retryDuration, tick) + }) + + t.Run("test org remove reduces", func(t *testing.T) { + orgName := gofakeit.Name() + + // 1. create org + organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + + // 2. check org retrivable + orgRepo := repository.OrganizationRepository(pool) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + orgRepo.NameCondition(orgName), + instanceID, + ) + require.NoError(t, err) + + if organization == nil { + assert.Fail(t, "this error is here because of a race condition") + } + assert.Equal(t, orgName, organization.Name) + }, retryDuration, tick) + + // 3. delete org + _, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{ + Id: organization.Id, + }) + require.NoError(t, err) + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + orgRepo.NameCondition(orgName), + instanceID, + ) + require.NoError(t, err) + + // event org.remove + assert.Nil(t, organization) + }, retryDuration, tick) + }) +} diff --git a/backend/v3/storage/database/order.go b/backend/v3/storage/database/order.go new file mode 100644 index 0000000000..72995c3178 --- /dev/null +++ b/backend/v3/storage/database/order.go @@ -0,0 +1,21 @@ +package database + +// Order represents a SQL condition. +// Its written after the ORDER BY keyword in a SQL statement. +type Order interface { + Write(builder *StatementBuilder) +} + +type orderBy struct { + column Column +} + +func OrderBy(column Column) Order { + return &orderBy{column: column} +} + +// Write implements [Order]. +func (o *orderBy) Write(builder *StatementBuilder) { + builder.WriteString(" ORDER BY ") + o.column.Write(builder) +} diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go index 4d81b8711e..80e852785f 100644 --- a/backend/v3/storage/database/repository/instance.go +++ b/backend/v3/storage/database/repository/instance.go @@ -33,16 +33,17 @@ const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, cons ` FROM zitadel.instances` // Get implements [domain.InstanceRepository]. -func (i *instance) Get(ctx context.Context, opts ...database.Condition) (*domain.Instance, error) { +func (i *instance) Get(ctx context.Context, id string) (*domain.Instance, error) { var builder database.StatementBuilder builder.WriteString(queryInstanceStmt) + idCondition := i.IDCondition(id) // return only non deleted instances - opts = append(opts, database.IsNull(i.DeletedAtColumn())) - i.writeCondition(&builder, database.And(opts...)) + conditions := []database.Condition{idCondition, database.IsNull(i.DeletedAtColumn())} + writeCondition(&builder, database.And(conditions...)) - return scanInstance(i.client.QueryRow(ctx, builder.String(), builder.Args()...)) + return scanInstance(ctx, i.client, &builder) } // List implements [domain.InstanceRepository]. @@ -54,15 +55,9 @@ func (i *instance) List(ctx context.Context, opts ...database.Condition) ([]*dom // return only non deleted instances opts = append(opts, database.IsNull(i.DeletedAtColumn())) notDeletedCondition := database.And(opts...) - i.writeCondition(&builder, notDeletedCondition) + writeCondition(&builder, notDeletedCondition) - rows, err := i.client.Query(ctx, builder.String(), builder.Args()...) - if err != nil { - return nil, err - } - defer rows.Close() - - return scanInstances(rows) + return scanInstances(ctx, i.client, &builder) } const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language)` + @@ -101,15 +96,20 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error } // Update implements [domain.InstanceRepository]. -func (i instance) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { +func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) { + if changes == nil { + return 0, errors.New("Update must contain a change") + } var builder database.StatementBuilder builder.WriteString(`UPDATE zitadel.instances SET `) - // don't update deleted instances - conditions := []database.Condition{condition, database.IsNull(i.DeletedAtColumn())} database.Changes(changes).Write(&builder) - i.writeCondition(&builder, database.And(conditions...)) + + idCondition := i.IDCondition(id) + // don't update deleted instances + conditions := []database.Condition{idCondition, database.IsNull(i.DeletedAtColumn())} + writeCondition(&builder, database.And(conditions...)) stmt := builder.String() @@ -118,18 +118,18 @@ func (i instance) Update(ctx context.Context, condition database.Condition, chan } // Delete implements [domain.InstanceRepository]. -func (i instance) Delete(ctx context.Context, condition database.Condition) error { - if condition == nil { - return errors.New("Delete must contain a condition") // (otherwise ALL instances will be deleted) - } +func (i instance) Delete(ctx context.Context, id string) (int64, error) { var builder database.StatementBuilder builder.WriteString(`UPDATE zitadel.instances SET deleted_at = $1`) builder.AppendArgs(time.Now()) - i.writeCondition(&builder, condition) - _, err := i.client.Exec(ctx, builder.String(), builder.Args()...) - return err + // don't delete already deleted instance + idCondition := i.IDCondition(id) + conditions := []database.Condition{idCondition, database.IsNull(i.DeletedAtColumn())} + writeCondition(&builder, database.And(conditions...)) + + return i.client.Exec(ctx, builder.String(), builder.Args()...) } // ------------------------------------------------------------- @@ -209,32 +209,30 @@ func (instance) DeletedAtColumn() database.Column { return database.NewColumn("deleted_at") } -func (i *instance) writeCondition( - builder *database.StatementBuilder, - condition database.Condition, -) { - if condition == nil { - return +func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err } - builder.WriteString(" WHERE ") - condition.Write(builder) + + instance := new(domain.Instance) + if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil { + if err.Error() == "no rows in result set" { + return nil, ErrResourceDoesNotExist + } + return nil, err + } + + return instance, nil } -func scanInstance(scanner database.Scanner) (*domain.Instance, error) { - var instance domain.Instance - err := scanner.Scan( - &instance.ID, - &instance.Name, - &instance.DefaultOrgID, - &instance.IAMProjectID, - &instance.ConsoleClientID, - &instance.ConsoleAppID, - &instance.DefaultLanguage, - &instance.CreatedAt, - &instance.UpdatedAt, - &instance.DeletedAt, - ) +func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) if err != nil { + return nil, err + } + + if err := rows.(database.CollectableRows).Collect(&instances); err != nil { // if no results returned, this is not a error // it just means the instance was not found // the caller should check if the returned instance is nil @@ -244,32 +242,5 @@ func scanInstance(scanner database.Scanner) (*domain.Instance, error) { return nil, err } - return &instance, nil -} - -func scanInstances(rows database.Rows) ([]*domain.Instance, error) { - instances := make([]*domain.Instance, 0) - for rows.Next() { - - var instance domain.Instance - err := rows.Scan( - &instance.ID, - &instance.Name, - &instance.DefaultOrgID, - &instance.IAMProjectID, - &instance.ConsoleClientID, - &instance.ConsoleAppID, - &instance.DefaultLanguage, - &instance.CreatedAt, - &instance.UpdatedAt, - &instance.DeletedAt, - ) - if err != nil { - return nil, err - } - - instances = append(instances, &instance) - - } return instances, nil } diff --git a/backend/v3/storage/database/repository/instance_test.go b/backend/v3/storage/database/repository/instance_test.go index a7f4c9add1..e6fb6964fe 100644 --- a/backend/v3/storage/database/repository/instance_test.go +++ b/backend/v3/storage/database/repository/instance_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zitadel/zitadel/backend/v3/domain" @@ -74,11 +75,61 @@ func TestCreateInstance(t *testing.T) { } err := instanceRepo.Create(ctx, &inst) + // change the name to make sure same only the id clashes + inst.Name = gofakeit.Name() require.NoError(t, err) return &inst }, err: errors.New("instance id already exists"), }, + func() struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Instance + instance domain.Instance + err error + } { + instanceId := gofakeit.Name() + instanceName := gofakeit.Name() + return struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Instance + instance domain.Instance + err error + }{ + name: "adding instance with same name twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + instanceRepo := repository.InstanceRepository(pool) + + inst := domain.Instance{ + ID: gofakeit.Name(), + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + // change the id + inst.ID = instanceId + return &inst + }, + instance: domain.Instance{ + ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + }, + // two instances can have the sane name + err: nil, + } + }(), { name: "adding instance with no id", instance: func() domain.Instance { @@ -113,7 +164,7 @@ func TestCreateInstance(t *testing.T) { // create instance beforeCreate := time.Now() err := instanceRepo.Create(ctx, instance) - require.Equal(t, tt.err, err) + assert.Equal(t, tt.err, err) if err != nil { return } @@ -121,20 +172,20 @@ func TestCreateInstance(t *testing.T) { // check instance values instance, err = instanceRepo.Get(ctx, - instanceRepo.NameCondition(database.TextOperationEqual, instance.Name), + instance.ID, ) require.NoError(t, err) - require.Equal(t, tt.instance.ID, instance.ID) - require.Equal(t, tt.instance.Name, instance.Name) - require.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID) - require.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID) - require.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID) - require.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID) - require.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage) - require.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate) - require.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate) - require.Nil(t, instance.DeletedAt) + assert.Equal(t, tt.instance.ID, instance.ID) + assert.Equal(t, tt.instance.Name, instance.Name) + assert.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID) + assert.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID) + assert.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID) + assert.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID) + assert.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage) + assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate) + assert.Nil(t, instance.DeletedAt) }) } } @@ -144,6 +195,7 @@ func TestUpdateInstance(t *testing.T) { name string testFunc func(ctx context.Context, t *testing.T) *domain.Instance rowsAffected int64 + getErr error }{ { name: "happy path", @@ -191,10 +243,11 @@ func TestUpdateInstance(t *testing.T) { require.NoError(t, err) // delete instance - err = instanceRepo.Delete(ctx, - instanceRepo.IDCondition(inst.ID), + affectedRows, err := instanceRepo.Delete(ctx, + inst.ID, ) require.NoError(t, err) + assert.Equal(t, int64(1), affectedRows) return &inst }, @@ -211,6 +264,7 @@ func TestUpdateInstance(t *testing.T) { return &inst }, rowsAffected: 0, + getErr: repository.ErrResourceDoesNotExist, }, } for _, tt := range tests { @@ -224,13 +278,13 @@ func TestUpdateInstance(t *testing.T) { // update name newName := "new_" + instance.Name rowsAffected, err := instanceRepo.Update(ctx, - instanceRepo.IDCondition(instance.ID), + instance.ID, instanceRepo.SetName(newName), ) afterUpdate := time.Now() require.NoError(t, err) - require.Equal(t, tt.rowsAffected, rowsAffected) + assert.Equal(t, tt.rowsAffected, rowsAffected) if rowsAffected == 0 { return @@ -238,13 +292,13 @@ func TestUpdateInstance(t *testing.T) { // check instance values instance, err = instanceRepo.Get(ctx, - instanceRepo.IDCondition(instance.ID), + instance.ID, ) - require.NoError(t, err) + require.Equal(t, tt.getErr, err) - require.Equal(t, newName, instance.Name) - require.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate) - require.Nil(t, instance.DeletedAt) + assert.Equal(t, newName, instance.Name) + assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate) + assert.Nil(t, instance.DeletedAt) }) } } @@ -252,9 +306,9 @@ func TestUpdateInstance(t *testing.T) { func TestGetInstance(t *testing.T) { instanceRepo := repository.InstanceRepository(pool) type test struct { - name string - testFunc func(ctx context.Context, t *testing.T) *domain.Instance - conditionClauses []database.Condition + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Instance + err error } tests := []test{ @@ -280,45 +334,17 @@ func TestGetInstance(t *testing.T) { require.NoError(t, err) return &inst }, - conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)}, - } - }(), - func() test { - instanceName := gofakeit.Name() - return test{ - name: "happy path get using name", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceId := gofakeit.Name() - - inst := domain.Instance{ - ID: instanceId, - Name: instanceName, - DefaultOrgID: "defaultOrgId", - IAMProjectID: "iamProject", - ConsoleClientID: "consoleCLient", - ConsoleAppID: "consoleApp", - DefaultLanguage: "defaultLanguage", - } - - // create instance - err := instanceRepo.Create(ctx, &inst) - require.NoError(t, err) - return &inst - }, - conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)}, } }(), { name: "get non existent instance", testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceId := gofakeit.Name() - - _ = domain.Instance{ - ID: instanceId, + inst := domain.Instance{ + ID: "get non existent instance", } - return nil + return &inst }, - conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name")}, + err: repository.ErrResourceDoesNotExist, }, } for _, tt := range tests { @@ -333,22 +359,25 @@ func TestGetInstance(t *testing.T) { // check instance values returnedInstance, err := instanceRepo.Get(ctx, - tt.conditionClauses..., + instance.ID, ) - require.NoError(t, err) - if instance == nil { - require.Nil(t, instance, returnedInstance) + if tt.err != nil { + require.Equal(t, tt.err, err) return } - require.NoError(t, err) - require.Equal(t, returnedInstance.ID, instance.ID) - require.Equal(t, returnedInstance.Name, instance.Name) - require.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID) - require.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID) - require.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID) - require.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID) - require.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage) + if instance.ID == "get non existent instance" { + assert.Nil(t, returnedInstance) + return + } + + assert.Equal(t, returnedInstance.ID, instance.ID) + assert.Equal(t, returnedInstance.Name, instance.Name) + assert.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID) + assert.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID) + assert.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID) + assert.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID) + assert.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage) }) } } @@ -504,19 +533,19 @@ func TestListInstance(t *testing.T) { ) require.NoError(t, err) if tt.noInstanceReturned { - require.Nil(t, returnedInstances) + assert.Nil(t, returnedInstances) return } - require.Equal(t, len(instances), len(returnedInstances)) + assert.Equal(t, len(instances), len(returnedInstances)) for i, instance := range instances { - require.Equal(t, returnedInstances[i].ID, instance.ID) - require.Equal(t, returnedInstances[i].Name, instance.Name) - require.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID) - require.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID) - require.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID) - require.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID) - require.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage) + assert.Equal(t, returnedInstances[i].ID, instance.ID) + assert.Equal(t, returnedInstances[i].Name, instance.Name) + assert.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID) + assert.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID) + assert.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID) + assert.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID) + assert.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage) } }) } @@ -524,18 +553,19 @@ func TestListInstance(t *testing.T) { func TestDeleteInstance(t *testing.T) { type test struct { - name string - testFunc func(ctx context.Context, t *testing.T) - conditionClauses database.Condition + name string + testFunc func(ctx context.Context, t *testing.T) + instanceID string + noOfDeletedRows int64 } tests := []test{ func() test { instanceRepo := repository.InstanceRepository(pool) instanceId := gofakeit.Name() + var noOfInstances int64 = 1 return test{ name: "happy path delete single instance filter id", testFunc: func(ctx context.Context, t *testing.T) { - noOfInstances := 1 instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { @@ -556,75 +586,15 @@ func TestDeleteInstance(t *testing.T) { instances[i] = &inst } }, - conditionClauses: instanceRepo.IDCondition(instanceId), + instanceID: instanceId, + noOfDeletedRows: noOfInstances, } }(), func() test { - instanceRepo := repository.InstanceRepository(pool) - instanceName := gofakeit.Name() - return test{ - name: "happy path delete single instance filter name", - testFunc: func(ctx context.Context, t *testing.T) { - noOfInstances := 1 - instances := make([]*domain.Instance, noOfInstances) - for i := range noOfInstances { - - inst := domain.Instance{ - ID: gofakeit.Name(), - Name: instanceName, - DefaultOrgID: "defaultOrgId", - IAMProjectID: "iamProject", - ConsoleClientID: "consoleCLient", - ConsoleAppID: "consoleApp", - DefaultLanguage: "defaultLanguage", - } - - // create instance - err := instanceRepo.Create(ctx, &inst) - require.NoError(t, err) - - instances[i] = &inst - } - }, - conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName), - } - }(), - func() test { - instanceRepo := repository.InstanceRepository(pool) non_existent_instance_name := gofakeit.Name() return test{ - name: "delete non existent instance", - conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, non_existent_instance_name), - } - }(), - func() test { - instanceRepo := repository.InstanceRepository(pool) - instanceName := gofakeit.Name() - return test{ - name: "multiple instance filter on name", - testFunc: func(ctx context.Context, t *testing.T) { - noOfInstances := 5 - instances := make([]*domain.Instance, noOfInstances) - for i := range noOfInstances { - - inst := domain.Instance{ - ID: gofakeit.Name(), - Name: instanceName, - DefaultOrgID: "defaultOrgId", - IAMProjectID: "iamProject", - ConsoleClientID: "consoleCLient", - ConsoleAppID: "consoleApp", - DefaultLanguage: "defaultLanguage", - } - - // create instance - err := instanceRepo.Create(ctx, &inst) - require.NoError(t, err) - - instances[i] = &inst - } - }, - conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName), + name: "delete non existent instance", + instanceID: non_existent_instance_name, } }(), func() test { @@ -655,12 +625,15 @@ func TestDeleteInstance(t *testing.T) { } // delete instance - err := instanceRepo.Delete(ctx, - instanceRepo.NameCondition(database.TextOperationEqual, instanceName), + affectedRows, err := instanceRepo.Delete(ctx, + instances[0].ID, ) require.NoError(t, err) + assert.Equal(t, int64(1), affectedRows) }, - conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName), + instanceID: instanceName, + // this test should return 0 affected rows as the instance was already deleted + noOfDeletedRows: 0, } }(), } @@ -674,17 +647,18 @@ func TestDeleteInstance(t *testing.T) { } // delete instance - err := instanceRepo.Delete(ctx, - tt.conditionClauses, + noOfDeletedRows, err := instanceRepo.Delete(ctx, + tt.instanceID, ) require.NoError(t, err) + assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows) // check instance was deleted instance, err := instanceRepo.Get(ctx, - tt.conditionClauses, + tt.instanceID, ) - require.NoError(t, err) - require.Nil(t, instance) + require.Equal(t, err, repository.ErrResourceDoesNotExist) + assert.Nil(t, instance) }) } } diff --git a/backend/v3/storage/database/repository/org.go b/backend/v3/storage/database/repository/org.go index 2dea176730..90670500fa 100644 --- a/backend/v3/storage/database/repository/org.go +++ b/backend/v3/storage/database/repository/org.go @@ -2,8 +2,11 @@ package repository import ( "context" + "errors" "time" + "github.com/jackc/pgx/v5/pgconn" + "github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/storage/database" ) @@ -12,11 +15,13 @@ import ( // repository // ------------------------------------------------------------- +var _ domain.OrganizationRepository = (*org)(nil) + type org struct { repository } -func OrgRepository(client database.QueryExecutor) domain.OrgRepository { +func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository { return &org{ repository: repository{ client: client, @@ -24,52 +29,140 @@ func OrgRepository(client database.QueryExecutor) domain.OrgRepository { } } -// Create implements [domain.OrgRepository]. -func (o *org) Create(ctx context.Context, org *domain.Org) error { - org.CreatedAt = time.Now() - org.UpdatedAt = org.CreatedAt +const queryOrganizationStmt = `SELECT id, name, instance_id, state, created_at, updated_at, deleted_at` + + ` FROM zitadel.organizations` + +// Get implements [domain.OrganizationRepository]. +func (o *org) Get(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, conditions ...database.Condition) (*domain.Organization, error) { + builder := database.StatementBuilder{} + + builder.WriteString(queryOrganizationStmt) + + instanceIDCondition := o.InstanceIDCondition(instanceID) + // don't update deleted organizations + nonDeletedOrgs := database.IsNull(o.DeletedAtColumn()) + + conditions = append(conditions, id, instanceIDCondition, nonDeletedOrgs) + writeCondition(&builder, database.And(conditions...)) + + return scanOrganization(ctx, o.client, &builder) +} + +// List implements [domain.OrganizationRepository]. +func (o *org) List(ctx context.Context, opts ...database.Condition) ([]*domain.Organization, error) { + builder := database.StatementBuilder{} + + builder.WriteString(queryOrganizationStmt) + + // return only non deleted organizations + opts = append(opts, database.IsNull(o.DeletedAtColumn())) + writeCondition(&builder, database.And(opts...)) + + orderBy := database.OrderBy(o.CreatedAtColumn()) + orderBy.Write(&builder) + + return scanOrganizations(ctx, o.client, &builder) +} + +const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` + + ` VALUES ($1, $2, $3, $4)` + + ` RETURNING created_at, updated_at` + +// Create implements [domain.OrganizationRepository]. +func (o *org) Create(ctx context.Context, organization *domain.Organization) error { + builder := database.StatementBuilder{} + builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State) + builder.WriteString(createOrganizationStmt) + + err := o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt) + if err != nil { + return checkCreateOrgErr(err) + } return nil } -// Delete implements [domain.OrgRepository]. -func (o *org) Delete(ctx context.Context, condition database.Condition) error { - return nil +func checkCreateOrgErr(err error) error { + var pgErr *pgconn.PgError + if !errors.As(err, &pgErr) { + return err + } + // constraint violation + if pgErr.Code == "23514" { + if pgErr.ConstraintName == "organizations_name_check" { + return errors.New("organization name not provided") + } + if pgErr.ConstraintName == "organizations_id_check" { + return errors.New("organization id not provided") + } + } + // duplicate + if pgErr.Code == "23505" { + if pgErr.ConstraintName == "organizations_pkey" { + return errors.New("organization id already exists") + } + if pgErr.ConstraintName == "org_unique_instance_id_name_idx" { + return errors.New("organization name already exists for instance") + } + } + // invalid instance id + if pgErr.Code == "23503" { + if pgErr.ConstraintName == "organizations_instance_id_fkey" { + return errors.New("invalid instance id") + } + } + return err } -// Get implements [domain.OrgRepository]. -func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Org, error) { - panic("unimplemented") +// Update implements [domain.OrganizationRepository]. +func (o org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) { + if changes == nil { + return 0, errors.New("Update must contain a condition") // (otherwise ALL organizations will be updated) + } + builder := database.StatementBuilder{} + builder.WriteString(`UPDATE zitadel.organizations SET `) + + instanceIDCondition := o.InstanceIDCondition(instanceID) + // don't update deleted organizations + nonDeletedOrgs := database.IsNull(o.DeletedAtColumn()) + + conditions := []database.Condition{id, instanceIDCondition, nonDeletedOrgs} + database.Changes(changes).Write(&builder) + writeCondition(&builder, database.And(conditions...)) + + stmt := builder.String() + + rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...) + return rowsAffected, err } -// List implements [domain.OrgRepository]. -func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Org, error) { - panic("unimplemented") -} +// Delete implements [domain.OrganizationRepository]. +func (o org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) { + builder := database.StatementBuilder{} -// Update implements [domain.OrgRepository]. -func (o *org) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error { - panic("unimplemented") -} + builder.WriteString(`UPDATE zitadel.organizations SET deleted_at = $1`) + builder.AppendArgs(time.Now()) -func (o *org) Member() domain.MemberRepository { - return &orgMember{o} -} + instanceIDCondition := o.InstanceIDCondition(instanceID) + // don't update deleted organizations + nonDeletedOrgs := database.IsNull(o.DeletedAtColumn()) -func (o *org) Domain() domain.DomainRepository { - return &orgDomain{o} + conditions := []database.Condition{id, instanceIDCondition, nonDeletedOrgs} + writeCondition(&builder, database.And(conditions...)) + + return o.client.Exec(ctx, builder.String(), builder.Args()...) } // ------------------------------------------------------------- // changes // ------------------------------------------------------------- -// SetName implements [domain.orgChanges]. -func (o *org) SetName(name string) database.Change { +// SetName implements [domain.organizationChanges]. +func (o org) SetName(name string) database.Change { return database.NewChange(o.NameColumn(), name) } -// SetState implements [domain.orgChanges]. -func (o *org) SetState(state domain.OrgState) database.Change { +// SetState implements [domain.organizationChanges]. +func (o org) SetState(state domain.OrgState) database.Change { return database.NewChange(o.StateColumn(), state) } @@ -77,63 +170,97 @@ func (o *org) SetState(state domain.OrgState) database.Change { // conditions // ------------------------------------------------------------- -// IDCondition implements [domain.orgConditions]. -func (o *org) IDCondition(orgID string) database.Condition { - return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, orgID) +// IDCondition implements [domain.organizationConditions]. +func (o org) IDCondition(id string) domain.OrgIdentifierCondition { + return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id) } -// InstanceIDCondition implements [domain.orgConditions]. -func (o *org) InstanceIDCondition(instanceID string) database.Condition { +// NameCondition implements [domain.organizationConditions]. +func (o org) NameCondition(name string) domain.OrgIdentifierCondition { + return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name) +} + +// InstanceIDCondition implements [domain.organizationConditions]. +func (o org) InstanceIDCondition(instanceID string) database.Condition { return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID) } -// NameCondition implements [domain.orgConditions]. -func (o *org) NameCondition(op database.TextOperation, name string) database.Condition { - return database.NewTextCondition(o.NameColumn(), op, name) -} - -// StateCondition implements [domain.orgConditions]. -func (o *org) StateCondition(op database.NumberOperation, state domain.OrgState) database.Condition { - return database.NewNumberCondition(o.StateColumn(), op, state) +// StateCondition implements [domain.organizationConditions]. +func (o org) StateCondition(state domain.OrgState) database.Condition { + return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String()) } // ------------------------------------------------------------- // columns // ------------------------------------------------------------- -// CreatedAtColumn implements [domain.orgColumns]. -func (o *org) CreatedAtColumn() database.Column { - return database.NewColumn("created_at") -} - -// DeletedAtColumn implements [domain.orgColumns]. -func (o *org) DeletedAtColumn() database.Column { - return database.NewColumn("deleted_at") -} - -// IDColumn implements [domain.orgColumns]. -func (o *org) IDColumn() database.Column { +// IDColumn implements [domain.organizationColumns]. +func (org) IDColumn() database.Column { return database.NewColumn("id") } -// InstanceIDColumn implements [domain.orgColumns]. -func (o *org) InstanceIDColumn() database.Column { - return database.NewColumn("instance_id") -} - -// NameColumn implements [domain.orgColumns]. -func (o *org) NameColumn() database.Column { +// NameColumn implements [domain.organizationColumns]. +func (org) NameColumn() database.Column { return database.NewColumn("name") } -// StateColumn implements [domain.orgColumns]. -func (o *org) StateColumn() database.Column { +// InstanceIDColumn implements [domain.organizationColumns]. +func (org) InstanceIDColumn() database.Column { + return database.NewColumn("instance_id") +} + +// StateColumn implements [domain.organizationColumns]. +func (org) StateColumn() database.Column { return database.NewColumn("state") } -// UpdatedAtColumn implements [domain.orgColumns]. -func (o *org) UpdatedAtColumn() database.Column { +// CreatedAtColumn implements [domain.organizationColumns]. +func (org) CreatedAtColumn() database.Column { + return database.NewColumn("created_at") +} + +// UpdatedAtColumn implements [domain.organizationColumns]. +func (org) UpdatedAtColumn() database.Column { return database.NewColumn("updated_at") } -var _ domain.OrgRepository = (*org)(nil) +// DeletedAtColumn implements [domain.organizationColumns]. +func (org) DeletedAtColumn() database.Column { + return database.NewColumn("deleted_at") +} + +func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + organization := &domain.Organization{} + if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil { + if err.Error() == "no rows in result set" { + return nil, ErrResourceDoesNotExist + } + return nil, err + } + + return organization, nil +} + +func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + organizations := []*domain.Organization{} + if err := rows.(database.CollectableRows).Collect(&organizations); err != nil { + // if no results returned, this is not a error + // it just means the organization was not found + // the caller should check if the returned organization is nil + if err.Error() == "no rows in result set" { + return nil, nil + } + return nil, err + } + return organizations, nil +} diff --git a/backend/v3/storage/database/repository/org_test.go b/backend/v3/storage/database/repository/org_test.go index 1877a29458..243ac4f0b4 100644 --- a/backend/v3/storage/database/repository/org_test.go +++ b/backend/v3/storage/database/repository/org_test.go @@ -1,10 +1,942 @@ package repository_test -// iraq: I had to comment out so that the UTs would pass -// TestBla is an example and can be removed later -// func TestBla(t *testing.T) { -// var count int -// err := pool.QueryRow(context.Background(), "select count(*) from zitadel.instances").Scan(&count) -// assert.NoError(t, err) -// assert.Equal(t, 0, count) -// } +import ( + "context" + "errors" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestCreateOrganization(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + assert.Nil(t, err) + + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + organization domain.Organization + err error + }{ + { + name: "happy path", + organization: func() domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + organization := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + return organization + }(), + }, + { + name: "create organization without name", + organization: func() domain.Organization { + organizationId := gofakeit.Name() + // organizationName := gofakeit.Name() + organization := domain.Organization{ + ID: organizationId, + Name: "", + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + return organization + }(), + err: errors.New("organization name not provided"), + }, + { + name: "adding org with same id twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationRepo := repository.OrganizationRepository(pool) + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + // change the name to make sure same only the id clashes + org.Name = gofakeit.Name() + return &org + }, + err: errors.New("organization id already exists"), + }, + { + name: "adding org with same name twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationRepo := repository.OrganizationRepository(pool) + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + // change the id to make sure same name+instance causes an error + org.ID = gofakeit.Name() + return &org + }, + err: errors.New("organization name already exists for instance"), + }, + func() struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + organization domain.Organization + err error + } { + orgID := gofakeit.Name() + organizationName := gofakeit.Name() + + return struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + organization domain.Organization + err error + }{ + name: "adding org with same name, different instance", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + // create instance + instId := gofakeit.Name() + instance := domain.Instance{ + ID: instId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(ctx, &instance) + assert.Nil(t, err) + + organizationRepo := repository.OrganizationRepository(pool) + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: organizationName, + InstanceID: instId, + State: domain.OrgStateActive.String(), + } + + err = organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // change the id to make it unique + org.ID = orgID + // change the instanceID to a different instance + org.InstanceID = instanceId + return &org + }, + organization: domain.Organization{ + ID: orgID, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + }, + } + }(), + { + name: "adding organization with no id", + organization: func() domain.Organization { + // organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + organization := domain.Organization{ + // ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + return organization + }(), + err: errors.New("organization id not provided"), + }, + { + name: "adding organization with no instance id", + organization: func() domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + organization := domain.Organization{ + ID: organizationId, + Name: organizationName, + State: domain.OrgStateActive.String(), + } + return organization + }(), + err: errors.New("invalid instance id"), + }, + { + name: "adding organization with non existent instance id", + organization: func() domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + organization := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: gofakeit.Name(), + State: domain.OrgStateActive.String(), + } + return organization + }(), + err: errors.New("invalid instance id"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + var organization *domain.Organization + if tt.testFunc != nil { + organization = tt.testFunc(ctx, t) + } else { + organization = &tt.organization + } + organizationRepo := repository.OrganizationRepository(pool) + + // create organization + beforeCreate := time.Now() + err = organizationRepo.Create(ctx, organization) + assert.Equal(t, tt.err, err) + if err != nil { + return + } + afterCreate := time.Now() + + // check organization values + organization, err = organizationRepo.Get(ctx, + organizationRepo.IDCondition(organization.ID), + organization.InstanceID, + ) + require.NoError(t, err) + + assert.Equal(t, tt.organization.ID, organization.ID) + assert.Equal(t, tt.organization.Name, organization.Name) + assert.Equal(t, tt.organization.InstanceID, organization.InstanceID) + assert.Equal(t, tt.organization.State, organization.State) + assert.WithinRange(t, organization.CreatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, organization.UpdatedAt, beforeCreate, afterCreate) + assert.Nil(t, organization.DeletedAt) + }) + } +} + +func TestUpdateOrganization(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + assert.Nil(t, err) + organizationRepo := repository.OrganizationRepository(pool) + + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + update []database.Change + rowsAffected int64 + }{ + { + name: "happy path update name", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // update with updated value + org.Name = "new_name" + return &org + }, + update: []database.Change{organizationRepo.SetName("new_name")}, + rowsAffected: 1, + }, + { + name: "update deleted organization", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // delete instance + _, err = organizationRepo.Delete(ctx, + organizationRepo.IDCondition(org.ID), + org.InstanceID, + ) + require.NoError(t, err) + + return &org + }, + update: []database.Change{organizationRepo.SetName("new_name")}, + rowsAffected: 0, + }, + { + name: "happy path change state", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // update with updated value + org.State = domain.OrgStateInactive.String() + return &org + }, + update: []database.Change{organizationRepo.SetState(domain.OrgStateInactive)}, + rowsAffected: 1, + }, + { + name: "update non existent organization", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + } + return &org + }, + update: []database.Change{organizationRepo.SetName("new_name")}, + rowsAffected: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + organizationRepo := repository.OrganizationRepository(pool) + + createdOrg := tt.testFunc(ctx, t) + + // update org + beforeUpdate := time.Now() + rowsAffected, err := organizationRepo.Update(ctx, + organizationRepo.IDCondition(createdOrg.ID), + createdOrg.InstanceID, + tt.update..., + ) + afterUpdate := time.Now() + require.NoError(t, err) + + assert.Equal(t, tt.rowsAffected, rowsAffected) + + if rowsAffected == 0 { + return + } + + // check organization values + organization, err := organizationRepo.Get(ctx, + organizationRepo.IDCondition(createdOrg.ID), + createdOrg.InstanceID, + ) + require.NoError(t, err) + + assert.Equal(t, createdOrg.ID, organization.ID) + assert.Equal(t, createdOrg.Name, organization.Name) + assert.Equal(t, createdOrg.State, organization.State) + assert.WithinRange(t, organization.UpdatedAt, beforeUpdate, afterUpdate) + assert.Nil(t, organization.DeletedAt) + }) + } +} + +func TestGetOrganization(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + assert.Nil(t, err) + + orgRepo := repository.OrganizationRepository(pool) + + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of the tests + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + err = orgRepo.Create(t.Context(), &org) + require.NoError(t, err) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + orgIdentifierCondition domain.OrgIdentifierCondition + err error + } + + tests := []test{ + func() test { + organizationId := gofakeit.Name() + return test{ + name: "happy path get using id", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := orgRepo.Create(ctx, &org) + require.NoError(t, err) + + return &org + }, + orgIdentifierCondition: orgRepo.IDCondition(organizationId), + } + }(), + func() test { + organizationName := gofakeit.Name() + return test{ + name: "happy path get using name", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := orgRepo.Create(ctx, &org) + require.NoError(t, err) + + return &org + }, + orgIdentifierCondition: orgRepo.NameCondition(organizationName), + } + }(), + { + name: "get non existent organization", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + org := domain.Organization{ + ID: "non existent org", + Name: "non existent org", + } + return &org + }, + orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"), + err: repository.ErrResourceDoesNotExist, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + orgRepo := repository.OrganizationRepository(pool) + + var org *domain.Organization + if tt.testFunc != nil { + org = tt.testFunc(ctx, t) + } + + // get org values + returnedOrg, err := orgRepo.Get(ctx, + tt.orgIdentifierCondition, + org.InstanceID, + ) + if tt.err != nil { + require.Equal(t, tt.err, err) + return + } + + if org.Name == "non existent org" { + assert.Nil(t, returnedOrg) + return + } + + assert.Equal(t, returnedOrg.ID, org.ID) + assert.Equal(t, returnedOrg.Name, org.Name) + assert.Equal(t, returnedOrg.InstanceID, org.InstanceID) + assert.Equal(t, returnedOrg.State, org.State) + }) + } +} + +func TestListOrganization(t *testing.T) { + ctx := t.Context() + pool, stop, err := newEmbeddedDB(ctx) + require.NoError(t, err) + defer stop() + organizationRepo := repository.OrganizationRepository(pool) + + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err = instanceRepo.Create(ctx, &instance) + assert.Nil(t, err) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) []*domain.Organization + conditionClauses []database.Condition + noOrganizationReturned bool + } + tests := []test{ + { + name: "happy path single organization no filter", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + noOfOrganizations := 1 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + }, + { + name: "happy path multiple organization no filter", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + noOfOrganizations := 5 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + }, + func() test { + organizationId := gofakeit.Name() + return test{ + name: "organization filter on id", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of this test case + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + err = organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + noOfOrganizations := 1 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: organizationId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + conditionClauses: []database.Condition{organizationRepo.IDCondition(organizationId)}, + } + }(), + { + name: "multiple organization filter on state", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of this test case + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + err = organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + noOfOrganizations := 5 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateInactive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + conditionClauses: []database.Condition{organizationRepo.StateCondition(domain.OrgStateInactive)}, + }, + func() test { + instanceId_2 := gofakeit.Name() + return test{ + name: "multiple organization filter on instance", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + // create instance 1 + instanceId_1 := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId_1, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err = instanceRepo.Create(ctx, &instance) + assert.Nil(t, err) + + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of this test case + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId_1, + State: domain.OrgStateActive.String(), + } + err = organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // create instance 2 + instance_2 := domain.Instance{ + ID: instanceId_2, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + err = instanceRepo.Create(ctx, &instance_2) + assert.Nil(t, err) + + noOfOrganizations := 5 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId_2, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + conditionClauses: []database.Condition{organizationRepo.InstanceIDCondition(instanceId_2)}, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + _, err := pool.Exec(ctx, "DELETE FROM zitadel.organizations") + require.NoError(t, err) + }) + + organizations := tt.testFunc(ctx, t) + + // check organization values + returnedOrgs, err := organizationRepo.List(ctx, + tt.conditionClauses..., + ) + require.NoError(t, err) + if tt.noOrganizationReturned { + assert.Nil(t, returnedOrgs) + return + } + + assert.Equal(t, len(organizations), len(returnedOrgs)) + for i, org := range organizations { + assert.Equal(t, returnedOrgs[i].ID, org.ID) + assert.Equal(t, returnedOrgs[i].Name, org.Name) + assert.Equal(t, returnedOrgs[i].InstanceID, org.InstanceID) + assert.Equal(t, returnedOrgs[i].State, org.State) + } + }) + } +} + +func TestDeleteOrganization(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + assert.Nil(t, err) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) + orgIdentifierCondition domain.OrgIdentifierCondition + noOfDeletedRows int64 + } + tests := []test{ + func() test { + organizationRepo := repository.OrganizationRepository(pool) + organizationId := gofakeit.Name() + var noOfOrganizations int64 = 1 + return test{ + name: "happy path delete organization filter id", + testFunc: func(ctx context.Context, t *testing.T) { + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: organizationId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + }, + orgIdentifierCondition: organizationRepo.IDCondition(organizationId), + noOfDeletedRows: noOfOrganizations, + } + }(), + func() test { + organizationRepo := repository.OrganizationRepository(pool) + organizationName := gofakeit.Name() + var noOfOrganizations int64 = 1 + return test{ + name: "happy path delete organization filter name", + testFunc: func(ctx context.Context, t *testing.T) { + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + }, + orgIdentifierCondition: organizationRepo.NameCondition(organizationName), + noOfDeletedRows: noOfOrganizations, + } + }(), + func() test { + organizationRepo := repository.OrganizationRepository(pool) + non_existent_organization_name := gofakeit.Name() + return test{ + name: "delete non existent organization", + orgIdentifierCondition: organizationRepo.NameCondition(non_existent_organization_name), + } + }(), + func() test { + organizationRepo := repository.OrganizationRepository(pool) + organizationName := gofakeit.Name() + return test{ + name: "deleted already deleted organization", + testFunc: func(ctx context.Context, t *testing.T) { + noOfOrganizations := 1 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + // delete organization + affectedRows, err := organizationRepo.Delete(ctx, + organizationRepo.NameCondition(organizationName), + organizations[0].InstanceID, + ) + assert.Equal(t, int64(1), affectedRows) + require.NoError(t, err) + }, + orgIdentifierCondition: organizationRepo.NameCondition(organizationName), + // this test should return 0 affected rows as the org was already deleted + noOfDeletedRows: 0, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + organizationRepo := repository.OrganizationRepository(pool) + + if tt.testFunc != nil { + tt.testFunc(ctx, t) + } + + // delete organization + noOfDeletedRows, err := organizationRepo.Delete(ctx, + tt.orgIdentifierCondition, + instanceId, + ) + require.NoError(t, err) + assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows) + + // check organization was deleted + organization, err := organizationRepo.Get(ctx, + tt.orgIdentifierCondition, + instanceId, + ) + require.Equal(t, err, repository.ErrResourceDoesNotExist) + assert.Nil(t, organization) + }) + } +} diff --git a/backend/v3/storage/database/repository/repository.go b/backend/v3/storage/database/repository/repository.go index 8181883bdd..9abf656ccc 100644 --- a/backend/v3/storage/database/repository/repository.go +++ b/backend/v3/storage/database/repository/repository.go @@ -1,7 +1,24 @@ package repository -import "github.com/zitadel/zitadel/backend/v3/storage/database" +import ( + "errors" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var ErrResourceDoesNotExist = errors.New("resource does not exist") type repository struct { client database.QueryExecutor } + +func writeCondition( + builder *database.StatementBuilder, + condition database.Condition, +) { + if condition == nil { + return + } + builder.WriteString(" WHERE ") + condition.Write(builder) +} diff --git a/backend/v3/storage/database/repository/user.go b/backend/v3/storage/database/repository/user.go index 1adc22c3d6..737b845a10 100644 --- a/backend/v3/storage/database/repository/user.go +++ b/backend/v3/storage/database/repository/user.go @@ -123,7 +123,7 @@ func (u *user) Create(ctx context.Context, user *domain.User) error { func (u *user) Delete(ctx context.Context, condition database.Condition) error { builder := database.StatementBuilder{} builder.WriteString("DELETE FROM users") - u.writeCondition(builder, condition) + writeCondition(&builder, condition) _, err := u.client.Exec(ctx, builder.String(), builder.Args()...) return err } @@ -223,17 +223,6 @@ func (user) DeletedAtColumn() database.Column { return database.NewColumn("deleted_at") } -func (u *user) writeCondition( - builder database.StatementBuilder, - condition database.Condition, -) { - if condition == nil { - return - } - builder.WriteString(" WHERE ") - condition.Write(&builder) -} - func (u user) columns() database.Columns { return database.Columns{ u.InstanceIDColumn(), diff --git a/backend/v3/storage/database/repository/user_human.go b/backend/v3/storage/database/repository/user_human.go index 1bef85cfee..0dbe31fc47 100644 --- a/backend/v3/storage/database/repository/user_human.go +++ b/backend/v3/storage/database/repository/user_human.go @@ -26,7 +26,7 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) builder := database.StatementBuilder{} builder.WriteString(userEmailQuery) - u.writeCondition(builder, condition) + writeCondition(&builder, condition) err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan( &email.Address, @@ -43,7 +43,7 @@ func (h userHuman) Update(ctx context.Context, condition database.Condition, cha builder := database.StatementBuilder{} builder.WriteString(`UPDATE human_users SET `) database.Changes(changes).Write(&builder) - h.writeCondition(builder, condition) + writeCondition(&builder, condition) stmt := builder.String() diff --git a/backend/v3/storage/database/repository/user_machine.go b/backend/v3/storage/database/repository/user_machine.go index 766f76a46d..9233942375 100644 --- a/backend/v3/storage/database/repository/user_machine.go +++ b/backend/v3/storage/database/repository/user_machine.go @@ -22,7 +22,7 @@ func (m userMachine) Update(ctx context.Context, condition database.Condition, c builder := database.StatementBuilder{} builder.WriteString("UPDATE user_machines SET ") database.Changes(changes).Write(&builder) - m.writeCondition(builder, condition) + writeCondition(&builder, condition) m.writeReturning() _, err := m.client.Exec(ctx, builder.String(), builder.Args()...) diff --git a/go.mod b/go.mod index 09dd32e822..a5917e1f72 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/fatih/color v1.18.0 github.com/fergusstrange/embedded-postgres v1.30.0 github.com/gabriel-vasile/mimetype v1.4.9 + github.com/georgysavva/scany/v2 v2.1.4 github.com/go-chi/chi/v5 v5.2.1 github.com/go-jose/go-jose/v4 v4.1.0 github.com/go-ldap/ldap/v3 v3.4.11 diff --git a/go.sum b/go.sum index d511119796..0b4cbef44e 100644 --- a/go.sum +++ b/go.sum @@ -233,6 +233,8 @@ github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= +github.com/georgysavva/scany/v2 v2.1.4 h1:nrzHEJ4oQVRoiKmocRqA1IyGOmM/GQOEsg9UjMR5Ip4= +github.com/georgysavva/scany/v2 v2.1.4/go.mod h1:fqp9yHZzM/PFVa3/rYEC57VmDx+KDch0LoqrJzkvtos= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= @@ -308,8 +310,9 @@ github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXe github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= github.com/golang/geo v0.0.0-20200319012246-673a6f80352d/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= @@ -764,6 +767,8 @@ github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3 github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/internal/query/projection/org_relational.go b/internal/query/projection/org_relational.go new file mode 100644 index 0000000000..2250ab9351 --- /dev/null +++ b/internal/query/projection/org_relational.go @@ -0,0 +1,185 @@ +package projection + +import ( + "context" + + repoDomain "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/org" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const ( + OrgRelationProjectionTable = "zitadel.organizations" +) + +type orgRelationalProjection struct{} + +func (*orgRelationalProjection) Name() string { + return OrgRelationProjectionTable +} + +func newOrgRelationalProjection(ctx context.Context, config handler.Config) *handler.Handler { + return handler.NewHandler(ctx, &config, new(orgRelationalProjection)) +} + +func (p *orgRelationalProjection) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: org.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: org.OrgAddedEventType, + Reduce: p.reduceOrgRelationalAdded, + }, + { + Event: org.OrgChangedEventType, + Reduce: p.reduceOrgRelationalChanged, + }, + { + Event: org.OrgDeactivatedEventType, + Reduce: p.reduceOrgRelationalDeactivated, + }, + { + Event: org.OrgReactivatedEventType, + Reduce: p.reduceOrgRelationalReactivated, + }, + { + Event: org.OrgRemovedEventType, + Reduce: p.reduceOrgRelationalRemoved, + }, + // TODO + // { + // Event: org.OrgDomainPrimarySetEventType, + // Reduce: p.reducePrimaryDomainSetRelational, + // }, + }, + }, + { + Aggregate: instance.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: instance.InstanceRemovedEventType, + Reduce: reduceInstanceRemovedHelper(OrgColumnInstanceID), + }, + }, + }, + } +} + +func (p *orgRelationalProjection) reduceOrgRelationalAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-uYq5R", "reduce.wrong.event.type %s", org.OrgAddedEventType) + } + + return handler.NewCreateStatement( + e, + []handler.Column{ + handler.NewCol(OrgColumnID, e.Aggregate().ID), + handler.NewCol(OrgColumnName, e.Name), + handler.NewCol(OrgColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCol(State, repoDomain.OrgStateActive.String()), + handler.NewCol(CreatedAt, e.CreationDate()), + handler.NewCol(UpdatedAt, e.CreationDate()), + }, + ), nil +} + +func (p *orgRelationalProjection) reduceOrgRelationalChanged(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgChangedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Bg9om", "reduce.wrong.event.type %s", org.OrgChangedEventType) + } + if e.Name == "" { + return handler.NewNoOpStatement(e), nil + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(OrgColumnName, e.Name), + handler.NewCol(UpdatedAt, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(OrgColumnID, e.Aggregate().ID), + handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +func (p *orgRelationalProjection) reduceOrgRelationalDeactivated(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgDeactivatedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-BApK5", "reduce.wrong.event.type %s", org.OrgDeactivatedEventType) + } + + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(State, repoDomain.OrgStateInactive.String()), + handler.NewCol(UpdatedAt, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(OrgColumnID, e.Aggregate().ID), + handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +func (p *orgRelationalProjection) reduceOrgRelationalReactivated(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgReactivatedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-o38DE", "reduce.wrong.event.type %s", org.OrgReactivatedEventType) + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(State, repoDomain.OrgStateActive.String()), + handler.NewCol(UpdatedAt, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(OrgColumnID, e.Aggregate().ID), + handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +// TODO +// func (p *orgRelationalProjection) reducePrimaryDomainSetRelational(event eventstore.Event) (*handler.Statement, error) { +// e, ok := event.(*org.DomainPrimarySetEvent) +// if !ok { +// return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-3Tbkt", "reduce.wrong.event.type %s", org.OrgDomainPrimarySetEventType) +// } +// return handler.NewUpdateStatement( +// e, +// []handler.Column{ +// handler.NewCol(OrgColumnChangeDate, e.CreationDate()), +// handler.NewCol(OrgColumnSequence, e.Sequence()), +// handler.NewCol(OrgColumnDomain, e.Domain), +// }, +// []handler.Condition{ +// handler.NewCond(OrgColumnID, e.Aggregate().ID), +// handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), +// }, +// ), nil +// } + +func (p *orgRelationalProjection) reduceOrgRelationalRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-DGm9g", "reduce.wrong.event.type %s", org.OrgRemovedEventType) + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(UpdatedAt, e.CreationDate()), + handler.NewCol(DeletedAt, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(OrgColumnID, e.Aggregate().ID), + handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index f05864396d..c4d82d843f 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -61,6 +61,7 @@ var ( UserAuthMethodProjection *handler.Handler InstanceProjection *handler.Handler InstanceRelationalProjection *handler.Handler + OrganizationRelationalProjection *handler.Handler SecretGeneratorProjection *handler.Handler SMTPConfigProjection *handler.Handler SMSConfigProjection *handler.Handler @@ -157,7 +158,8 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore, UserMetadataProjection = newUserMetadataProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["user_metadata"])) UserAuthMethodProjection = newUserAuthMethodProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["user_auth_method"])) InstanceProjection = newInstanceProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["instances"])) - InstanceRelationalProjection = newInstanceRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["instances_relational"])) + InstanceRelationalProjection = newInstanceRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["organizations_relational"])) + OrganizationRelationalProjection = newOrgRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["instances_relational"])) SecretGeneratorProjection = newSecretGeneratorProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["secret_generators"])) SMTPConfigProjection = newSMTPConfigProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["smtp_configs"])) SMSConfigProjection = newSMSConfigProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sms_config"])) @@ -337,6 +339,7 @@ func newProjectionsList() { UserAuthMethodProjection, InstanceProjection, InstanceRelationalProjection, + OrganizationRelationalProjection, SecretGeneratorProjection, SMTPConfigProjection, SMSConfigProjection, diff --git a/internal/query/projection/relational_common.go b/internal/query/projection/relational_common.go index 0140ea3559..7b0d957661 100644 --- a/internal/query/projection/relational_common.go +++ b/internal/query/projection/relational_common.go @@ -1,6 +1,7 @@ package projection const ( + State = "state" CreatedAt = "created_at" UpdatedAt = "updated_at" DeletedAt = "deleted_at" diff --git a/internal/repository/org/domain.go b/internal/repository/org/domain.go index 0b722b3ca0..85b8a939f6 100644 --- a/internal/repository/org/domain.go +++ b/internal/repository/org/domain.go @@ -117,7 +117,8 @@ func NewDomainVerificationAddedEvent( aggregate *eventstore.Aggregate, domain string, validationType domain.OrgDomainValidationType, - validationCode *crypto.CryptoValue) *DomainVerificationAddedEvent { + validationCode *crypto.CryptoValue, +) *DomainVerificationAddedEvent { return &DomainVerificationAddedEvent{ BaseEvent: *eventstore.NewBaseEventForPush( ctx,