diff --git a/backend/v3/domain/id_provider.go b/backend/v3/domain/id_provider.go index 68ff8d242a..c5777f14dd 100644 --- a/backend/v3/domain/id_provider.go +++ b/backend/v3/domain/id_provider.go @@ -31,21 +31,28 @@ const ( type IdentityProvider struct { InstanceID string `json:"instanceId,omitempty" db:"instance_id"` - OrgID string `json:"org_id,omitempty" db:"org_id"` + OrgID string `json:"orgId,omitempty" db:"org_id"` ID string `json:"id,omitempty" db:"id"` State string `json:"state,omitempty" db:"state"` Name string `json:"name,omitempty" db:"name"` Type string `json:"type,omitempty" db:"type"` - AllowCreation bool `json:"allow_creation,omitempty" db:"allow_creation"` - AllowAutoCreation bool `json:"allow_auto_creation,omitempty" db:"allow_auto_creation"` - AllowAutoUpdate bool `json:"allow_auto_update,omitempty" db:"allow_auto_update"` - AllowLinking bool `json:"allow_linking,omitempty" db:"allow_linking"` - StylingType int16 `json:"styling_type,omitempty" db:"styling_type"` + AllowCreation bool `json:"allowCreation,omitempty" db:"allow_creation"` + AllowAutoCreation bool `json:"allowAutoCreation,omitempty" db:"allow_auto_creation"` + AllowAutoUpdate bool `json:"allowAutoUpdate,omitempty" db:"allow_auto_update"` + AllowLinking bool `json:"allowLinking,omitempty" db:"allow_linking"` + StylingType int16 `json:"stylingType,omitempty" db:"styling_type"` Payload string `json:"payload,omitempty" db:"payload"` CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"` UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"` } +// IDPIdentifierCondition is used to help specify a single identity_provider, +// it will either be used as the identity_provider ID or identity_provider name, +// as identity_provider can be identified either using (instnaceID + OrgID + ID) OR (instanceID + OrgID + name) +type IDPIdentifierCondition interface { + database.Condition +} + type idProviderColumns interface { InstanceIDColumn() database.Column OrgIDColumn() database.Column @@ -66,9 +73,9 @@ type idProviderColumns interface { type idProviderConditions interface { InstanceIDCondition(id string) database.Condition OrgIDCondition(id string) database.Condition - IDCondition(id string) database.Condition + IDCondition(id string) IDPIdentifierCondition StateCondition(state IDPState) database.Condition - NameCondition(name string) database.Condition + NameCondition(name string) IDPIdentifierCondition TypeCondition(typee IDPType) database.Condition AllowCreationCondition(allow bool) database.Condition AllowAutoCreationCondition(allow bool) database.Condition @@ -94,10 +101,10 @@ type IDProviderRepository interface { idProviderConditions idProviderChanges - Get(ctx context.Context, id string) (*IdentityProvider, error) + Get(ctx context.Context, id IDPIdentifierCondition, instnaceID string, orgID string) (*IdentityProvider, error) List(ctx context.Context, conditions ...database.Condition) ([]*IdentityProvider, error) Create(ctx context.Context, idp *IdentityProvider) error - Update(ctx context.Context, id string, changes ...database.Change) (int64, error) - Delete(ctx context.Context, id string) (int64, error) + Update(ctx context.Context, id IDPIdentifierCondition, changes ...database.Change) (int64, error) + Delete(ctx context.Context, id IDPIdentifierCondition) (int64, error) } diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table.go b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table.go new file mode 100644 index 0000000000..56a14ffcd5 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table.go @@ -0,0 +1,16 @@ +package migration + +import ( + _ "embed" +) + +var ( + //go:embed 003_identity_providers_table/up.sql + up003IdentityProvidersTable string + //go:embed 003_identity_providers_table/down.sql + down003IdentityProvidersTable string +) + +func init() { + registerSQLMigration(3, up003IdentityProvidersTable, down003IdentityProvidersTable) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_id_providers_table/down.sql b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/down.sql similarity index 100% rename from backend/v3/storage/database/dialect/postgres/migration/003_id_providers_table/down.sql rename to backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/down.sql diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_id_providers_table/up.sql b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/up.sql similarity index 65% rename from backend/v3/storage/database/dialect/postgres/migration/003_id_providers_table/up.sql rename to backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/up.sql index ee872b72fa..c1bddccb64 100644 --- a/backend/v3/storage/database/dialect/postgres/migration/003_id_providers_table/up.sql +++ b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/up.sql @@ -14,12 +14,12 @@ CREATE TYPE idp_type AS ENUM ( 'apple' ); -CREATE TABLE identity_providers ( +CREATE TABLE zitadel.identity_providers ( instance_id TEXT NOT NULL , org_id TEXT - , id TEXT NOT NULL + , id TEXT NOT NULL CHECK (id <> '') , state idp_state NOT NULL DEFAULT 'active' - , name TEXT + , name TEXT NOT NULL CHECK (name <> '') , type idp_type NOT NULL , allow_creation BOOLEAN NOT NULL DEFAULT TRUE , allow_auto_creation BOOLEAN NOT NULL DEFAULT TRUE @@ -32,14 +32,15 @@ CREATE TABLE identity_providers ( , updated_at TIMESTAMPTZ NOT NULL DEFAULT now() , PRIMARY KEY (instance_id, id) - , CONSTRAINT identity_providers_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, id) - , FOREIGN KEY (instance_id) REFERENCES instances(id) - , FOREIGN KEY (instance_id, org_id) REFERENCES organizations(instance_id, id) + , CONSTRAINT identity_providers_id_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, id) + , CONSTRAINT identity_providers_name_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, name) + , FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id) + , FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id) ); -- CREATE INDEX idx_identity_providers_org_id ON identity_providers(instance_id, org_id) WHERE org_id IS NOT NULL; -CREATE INDEX idx_identity_providers_state ON identity_providers(instance_id, state); -CREATE INDEX idx_identity_providers_type ON identity_providers(instance_id, type); +CREATE INDEX idx_identity_providers_state ON zitadel.identity_providers(instance_id, state); +CREATE INDEX idx_identity_providers_type ON zitadel.identity_providers(instance_id, type); -- CREATE INDEX idx_identity_providers_created_at ON identity_providers(created_at); -- CREATE INDEX idx_identity_providers_deleted_at ON identity_providers(deleted_at) WHERE deleted_at IS NOT NULL; diff --git a/backend/v3/storage/database/repository/id_provider.go b/backend/v3/storage/database/repository/id_provider.go index b8629a9a75..78096d7275 100644 --- a/backend/v3/storage/database/repository/id_provider.go +++ b/backend/v3/storage/database/repository/id_provider.go @@ -22,18 +22,18 @@ func IDProviderRepository(client database.QueryExecutor) domain.IDProviderReposi } } -const queryIDProviderStmt = `instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` + +const queryIDProviderStmt = `SELECT instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` + ` allow_auto_update, allow_linking, styling_type, payload, created_at, updated_at` + ` FROM zitadel.identity_providers` -func (i *idProvider) Get(ctx context.Context, id string) (*domain.IdentityProvider, error) { +func (i *idProvider) Get(ctx context.Context, id domain.IDPIdentifierCondition, instnaceID string, orgID string) (*domain.IdentityProvider, error) { builder := database.StatementBuilder{} builder.WriteString(queryIDProviderStmt) - idCondition := i.IDCondition(id) + conditions := []database.Condition{id, i.InstanceIDCondition(instnaceID), i.OrgIDCondition(orgID)} - writeCondition(&builder, idCondition) + writeCondition(&builder, database.And(conditions...)) return scanIDProvider(ctx, i.client, &builder) } @@ -56,13 +56,13 @@ func (i *idProvider) List(ctx context.Context, conditions ...database.Condition) const createIDProviderStmt = `INSERT INTO zitadel.identity_providers` + ` (instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` + ` allow_auto_update, allow_linking, styling_type, payload)` + - ` VALUES ($1, $2, $3, $4, $5, $6, $,7, $8, $9, $10, $11, $12)` + + ` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)` + ` RETURNING created_at, updated_at` func (i *idProvider) Create(ctx context.Context, idp *domain.IdentityProvider) error { builder := database.StatementBuilder{} builder.AppendArgs(idp.InstanceID, idp.OrgID, idp.ID, idp.State, idp.Name, idp.Type, idp.AllowCreation, - idp.AllowAutoCreation, idp.AllowLinking, idp.StylingType, idp.Payload) + idp.AllowAutoCreation, idp.AllowAutoUpdate, idp.AllowLinking, idp.StylingType, idp.Payload) builder.WriteString(createIDProviderStmt) err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&idp.CreatedAt, &idp.UpdatedAt) @@ -72,14 +72,15 @@ func (i *idProvider) Create(ctx context.Context, idp *domain.IdentityProvider) e return nil } -func (i *idProvider) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) { +func (i *idProvider) Update(ctx context.Context, id domain.IDPIdentifierCondition, changes ...database.Change) (int64, error) { if changes == nil { return 0, errors.New("Update must contain a condition") // (otherwise ALL identity_providers will be updated) } builder := database.StatementBuilder{} builder.WriteString(`UPDATE zitadel.identity_provider SET `) - conditions := []database.Condition{i.IDCondition(id)} + // conditions := []database.Condition{i.IDCondition(id)} + conditions := []database.Condition{id} database.Changes(changes).Write(&builder) writeCondition(&builder, database.And(conditions...)) @@ -89,13 +90,13 @@ func (i *idProvider) Update(ctx context.Context, id string, changes ...database. return rowsAffected, err } -func (i *idProvider) Delete(ctx context.Context, id string) (int64, error) { +func (i *idProvider) Delete(ctx context.Context, id domain.IDPIdentifierCondition) (int64, error) { builder := database.StatementBuilder{} builder.WriteString(`DELETE FROM zitadel.identity_providers`) - conditions := []database.Condition{i.IDCondition(id)} - writeCondition(&builder, database.And(conditions...)) + // conditions := []database.Condition{i.IDCondition(id)} + // writeCondition(&builder, database.And(conditions...)) return i.client.Exec(ctx, builder.String(), builder.Args()...) } @@ -172,7 +173,7 @@ func (i idProvider) OrgIDCondition(id string) database.Condition { return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, id) } -func (i idProvider) IDCondition(id string) database.Condition { +func (i idProvider) IDCondition(id string) domain.IDPIdentifierCondition { return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id) } @@ -180,7 +181,7 @@ func (i idProvider) StateCondition(state domain.IDPState) database.Condition { return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, state.String()) } -func (i idProvider) NameCondition(name string) database.Condition { +func (i idProvider) NameCondition(name string) domain.IDPIdentifierCondition { return database.NewTextCondition(i.NameColumn(), database.TextOperationEqual, name) } diff --git a/backend/v3/storage/database/repository/id_provider_test.go b/backend/v3/storage/database/repository/id_provider_test.go new file mode 100644 index 0000000000..ef63cbece9 --- /dev/null +++ b/backend/v3/storage/database/repository/id_provider_test.go @@ -0,0 +1,345 @@ +package repository_test + +import ( + "context" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestCreateIDProvider(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // create org + orgId := gofakeit.Name() + org := domain.Organization{ + ID: orgId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.IdentityProvider + idp domain.IdentityProvider + err error + } + + // TESTS + tests := []test{ + { + name: "happy path", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + }, + }, + { + name: "create organization without name", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + // Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + }, + err: new(database.CheckError), + }, + { + name: "adding org with same id twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idpRepo := repository.IDProviderRepository(pool) + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + // change the name to make sure same only the id clashes + org.Name = gofakeit.Name() + return &idp + }, + err: new(database.UniqueError), + }, + { + name: "adding org with same name twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idpRepo := repository.IDProviderRepository(pool) + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + // change the id to make sure same name causes an error + idp.ID = gofakeit.Name() + return &idp + }, + err: new(database.UniqueError), + }, + func() test { + id := gofakeit.Name() + name := gofakeit.Name() + return test{ + name: "adding org with same name, id, different instance", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + // create instance + newInstId := gofakeit.Name() + instance := domain.Instance{ + ID: newInstId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(ctx, &instance) + assert.Nil(t, err) + + // create org + newOrgId := gofakeit.Name() + org := domain.Organization{ + ID: newOrgId, + Name: gofakeit.Name(), + InstanceID: newInstId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + idp := domain.IdentityProvider{ + InstanceID: newInstId, + OrgID: newOrgId, + ID: id, + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + } + + err = idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + // change the instanceID to a different instance + idp.InstanceID = instanceId + // change the OrgId to a different organization + idp.OrgID = orgId + return &idp + }, + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: orgId, + ID: id, + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + }, + } + }(), + { + name: "adding idp with no id", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: orgId, + // ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + }, + err: new(database.CheckError), + }, + { + name: "adding idp with no name", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + // Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + }, + err: new(database.CheckError), + }, + { + name: "adding idp with no instance id", + idp: domain.IdentityProvider{ + // InstanceID: instanceId, + OrgID: orgId, + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + }, + err: new(database.IntegrityViolationError), + }, + { + name: "adding organization with non existent instance id", + idp: domain.IdentityProvider{ + InstanceID: gofakeit.Name(), + OrgID: orgId, + State: domain.IDPStateActive.String(), + ID: gofakeit.Name(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + }, + err: new(database.ForeignKeyError), + }, + { + name: "adding organization with non existent org id", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + ID: gofakeit.Name(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: "{}", + }, + err: new(database.ForeignKeyError), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + var idp *domain.IdentityProvider + if tt.testFunc != nil { + idp = tt.testFunc(ctx, t) + } else { + idp = &tt.idp + } + idpRepo := repository.IDProviderRepository(pool) + + // create idp + beforeCreate := time.Now() + err = idpRepo.Create(ctx, idp) + assert.ErrorIs(t, err, tt.err) + if err != nil { + return + } + afterCreate := time.Now() + + // check organization values + idp, err = idpRepo.Get(ctx, + idpRepo.IDCondition(idp.ID), + idp.InstanceID, + idp.OrgID, + ) + require.NoError(t, err) + + assert.Equal(t, tt.idp.ID, idp.ID) + assert.Equal(t, tt.idp.Name, idp.Name) + assert.Equal(t, tt.idp.InstanceID, idp.InstanceID) + assert.Equal(t, tt.idp.State, idp.State) + assert.WithinRange(t, idp.CreatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate) + }) + } +} diff --git a/backend/v3/storage/database/statement.go b/backend/v3/storage/database/statement.go index 7d779fe360..041a6d87b6 100644 --- a/backend/v3/storage/database/statement.go +++ b/backend/v3/storage/database/statement.go @@ -29,9 +29,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) { if b.existingArgs == nil { b.existingArgs = make(map[any]string) } - if placeholder, ok := b.existingArgs[arg]; ok { - return placeholder - } + if instruction, ok := arg.(Instruction); ok { return string(instruction) }