diff --git a/backend/v3/domain/id_provider.go b/backend/v3/domain/id_provider.go new file mode 100644 index 0000000000..3740ccad50 --- /dev/null +++ b/backend/v3/domain/id_provider.go @@ -0,0 +1,164 @@ +package domain + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/internal/crypto" +) + +//go:generate enumer -type IDPType -transform lower -trimprefix IDPType +type IDPType uint8 + +const ( + IDPTypeUnspecified IDPType = iota + IDPTypeOIDC + IDPTypeJWT + IDPTypeOAuth + IDPTypeLDAP + IDPTypeAzureAD + IDPTypeGitHub + IDPTypeGitHubEnterprise + IDPTypeGitLab + IDPTypeGitLabSelfHosted + IDPTypeGoogle + IDPTypeApple + IDPTypeSAML +) + +//go:generate enumer -type IDPState -transform lower -trimprefix IDPState +type IDPState uint8 + +const ( + IDPStateActive IDPState = iota + IDPStateInactive +) + +type OIDCMappingField int8 + +const ( + OIDCMappingFieldUnspecified OIDCMappingField = iota + OIDCMappingFieldPreferredLoginName + OIDCMappingFieldEmail + // count is for validation purposes + oidcMappingFieldCount +) + +type IdentityProvider struct { + InstanceID string `json:"instanceId,omitempty" db:"instance_id"` + OrgID *string `json:"orgId,omitempty" db:"org_id"` + ID string `json:"id,omitempty" db:"id"` + State string `json:"state,omitempty" db:"state"` + Name string `json:"name,omitempty" db:"name"` + Type string `json:"type,omitempty" db:"type"` + AllowCreation bool `json:"allowCreation,omitempty" db:"allow_creation"` + AutoRegister bool `json:"autoRegister,omitempty" db:"auto_register"` + AllowAutoCreation bool `json:"allowAutoCreation,omitempty" db:"allow_auto_creation"` + AllowAutoUpdate bool `json:"allowAutoUpdate,omitempty" db:"allow_auto_update"` + AllowLinking bool `json:"allowLinking,omitempty" db:"allow_linking"` + AllowAutoLinking bool `json:"allowAutoLinking,omitempty" db:"allow_auto_linking"` + StylingType int16 `json:"stylingType,omitempty" db:"styling_type"` + Payload *string `json:"payload,omitempty" db:"payload"` + CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"` +} + +type OIDC struct { + IDPConfigID string `json:"idpConfigId"` + ClientID string `json:"clientId,omitempty"` + ClientSecret crypto.CryptoValue `json:"clientSecret,omitempty"` + Issuer string `json:"issuer,omitempty"` + AuthorizationEndpoint string `json:"authorizationEndpoint,omitempty"` + TokenEndpoint string `json:"tokenEndpoint,omitempty"` + Scopes []string `json:"scopes,omitempty"` + IDPDisplayNameMapping OIDCMappingField `json:"IDPDisplayNameMapping,omitempty"` + UserNameMapping OIDCMappingField `json:"usernameMapping,omitempty"` +} + +type IDPOIDC struct { + *IdentityProvider + OIDC +} + +type JWT struct { + IDPConfigID string `json:"idpConfigId"` + JWTEndpoint string `json:"jwtEndpoint,omitempty"` + Issuer string `json:"issuer,omitempty"` + KeysEndpoint string `json:"keysEndpoint,omitempty"` + HeaderName string `json:"headerName,omitempty"` +} + +type IDPJWT struct { + *IdentityProvider + JWT +} + +// IDPIdentifierCondition is used to help specify a single identity_provider, +// it will either be used as the identity_provider ID or identity_provider name, +// as identity_provider can be identified either using (instanceID + OrgID + ID) OR (instanceID + OrgID + name) +type IDPIdentifierCondition interface { + database.Condition +} + +type idProviderColumns interface { + InstanceIDColumn() database.Column + OrgIDColumn() database.Column + IDColumn() database.Column + StateColumn() database.Column + NameColumn() database.Column + TypeColumn() database.Column + AllowCreationColumn() database.Column + AutoRegisterColumn() database.Column + AllowAutoCreationColumn() database.Column + AllowAutoUpdateColumn() database.Column + AllowLinkingColumn() database.Column + AllowAutoLinkingColumn() database.Column + StylingTypeColumn() database.Column + PayloadColumn() database.Column + CreatedAtColumn() database.Column + UpdatedAtColumn() database.Column +} + +type idProviderConditions interface { + InstanceIDCondition(id string) database.Condition + OrgIDCondition(id *string) database.Condition + IDCondition(id string) IDPIdentifierCondition + StateCondition(state IDPState) database.Condition + NameCondition(name string) IDPIdentifierCondition + TypeCondition(typee IDPType) database.Condition + AutoRegisterCondition(allow bool) database.Condition + AllowCreationCondition(allow bool) database.Condition + AllowAutoCreationCondition(allow bool) database.Condition + AllowAutoUpdateCondition(allow bool) database.Condition + AllowLinkingCondition(allow bool) database.Condition + AllowAutoLinkingCondition(allow bool) database.Condition + StylingTypeCondition(style int16) database.Condition + PayloadCondition(payload string) database.Condition +} + +type idProviderChanges interface { + SetName(name string) database.Change + SetState(state IDPState) database.Change + SetAllowCreation(allow bool) database.Change + SetAutoRegister(allow bool) database.Change + SetAllowAutoCreation(allow bool) database.Change + SetAllowAutoUpdate(allow bool) database.Change + SetAllowLinking(allow bool) database.Change + SetAutoAllowLinking(allow bool) database.Change + SetStylingType(stylingType int16) database.Change + SetPayload(payload string) database.Change +} + +type IDProviderRepository interface { + idProviderColumns + idProviderConditions + idProviderChanges + + Get(ctx context.Context, id IDPIdentifierCondition, instanceID string, orgID *string) (*IdentityProvider, error) + List(ctx context.Context, conditions ...database.Condition) ([]*IdentityProvider, error) + + Create(ctx context.Context, idp *IdentityProvider) error + Update(ctx context.Context, id IDPIdentifierCondition, instanceID string, orgID *string, changes ...database.Change) (int64, error) + Delete(ctx context.Context, id IDPIdentifierCondition, instanceID string, orgID *string) (int64, error) +} diff --git a/backend/v3/domain/idpstate_enumer.go b/backend/v3/domain/idpstate_enumer.go new file mode 100644 index 0000000000..899192ae2b --- /dev/null +++ b/backend/v3/domain/idpstate_enumer.go @@ -0,0 +1,78 @@ +// Code generated by "enumer -type IDPState -transform lower -trimprefix IDPState"; DO NOT EDIT. + +package domain + +import ( + "fmt" + "strings" +) + +const _IDPStateName = "activeinactive" + +var _IDPStateIndex = [...]uint8{0, 6, 14} + +const _IDPStateLowerName = "activeinactive" + +func (i IDPState) String() string { + if i >= IDPState(len(_IDPStateIndex)-1) { + return fmt.Sprintf("IDPState(%d)", i) + } + return _IDPStateName[_IDPStateIndex[i]:_IDPStateIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _IDPStateNoOp() { + var x [1]struct{} + _ = x[IDPStateActive-(0)] + _ = x[IDPStateInactive-(1)] +} + +var _IDPStateValues = []IDPState{IDPStateActive, IDPStateInactive} + +var _IDPStateNameToValueMap = map[string]IDPState{ + _IDPStateName[0:6]: IDPStateActive, + _IDPStateLowerName[0:6]: IDPStateActive, + _IDPStateName[6:14]: IDPStateInactive, + _IDPStateLowerName[6:14]: IDPStateInactive, +} + +var _IDPStateNames = []string{ + _IDPStateName[0:6], + _IDPStateName[6:14], +} + +// IDPStateString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func IDPStateString(s string) (IDPState, error) { + if val, ok := _IDPStateNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _IDPStateNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to IDPState values", s) +} + +// IDPStateValues returns all values of the enum +func IDPStateValues() []IDPState { + return _IDPStateValues +} + +// IDPStateStrings returns a slice of all String values of the enum +func IDPStateStrings() []string { + strs := make([]string, len(_IDPStateNames)) + copy(strs, _IDPStateNames) + return strs +} + +// IsAIDPState returns "true" if the value is listed in the enum definition. "false" otherwise +func (i IDPState) IsAIDPState() bool { + for _, v := range _IDPStateValues { + if i == v { + return true + } + } + return false +} diff --git a/backend/v3/domain/idptype_enumer.go b/backend/v3/domain/idptype_enumer.go new file mode 100644 index 0000000000..881fd893f8 --- /dev/null +++ b/backend/v3/domain/idptype_enumer.go @@ -0,0 +1,122 @@ +// Code generated by "enumer -type IDPType -transform lower -trimprefix IDPType"; DO NOT EDIT. + +package domain + +import ( + "fmt" + "strings" +) + +const _IDPTypeName = "unspecifiedoidcjwtoauthldapazureadgithubgithubenterprisegitlabgitlabselfhostedgoogleapplesaml" + +var _IDPTypeIndex = [...]uint8{0, 11, 15, 18, 23, 27, 34, 40, 56, 62, 78, 84, 89, 93} + +const _IDPTypeLowerName = "unspecifiedoidcjwtoauthldapazureadgithubgithubenterprisegitlabgitlabselfhostedgoogleapplesaml" + +func (i IDPType) String() string { + if i >= IDPType(len(_IDPTypeIndex)-1) { + return fmt.Sprintf("IDPType(%d)", i) + } + return _IDPTypeName[_IDPTypeIndex[i]:_IDPTypeIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _IDPTypeNoOp() { + var x [1]struct{} + _ = x[IDPTypeUnspecified-(0)] + _ = x[IDPTypeOIDC-(1)] + _ = x[IDPTypeJWT-(2)] + _ = x[IDPTypeOAuth-(3)] + _ = x[IDPTypeLDAP-(4)] + _ = x[IDPTypeAzureAD-(5)] + _ = x[IDPTypeGitHub-(6)] + _ = x[IDPTypeGitHubEnterprise-(7)] + _ = x[IDPTypeGitLab-(8)] + _ = x[IDPTypeGitLabSelfHosted-(9)] + _ = x[IDPTypeGoogle-(10)] + _ = x[IDPTypeApple-(11)] + _ = x[IDPTypeSAML-(12)] +} + +var _IDPTypeValues = []IDPType{IDPTypeUnspecified, IDPTypeOIDC, IDPTypeJWT, IDPTypeOAuth, IDPTypeLDAP, IDPTypeAzureAD, IDPTypeGitHub, IDPTypeGitHubEnterprise, IDPTypeGitLab, IDPTypeGitLabSelfHosted, IDPTypeGoogle, IDPTypeApple, IDPTypeSAML} + +var _IDPTypeNameToValueMap = map[string]IDPType{ + _IDPTypeName[0:11]: IDPTypeUnspecified, + _IDPTypeLowerName[0:11]: IDPTypeUnspecified, + _IDPTypeName[11:15]: IDPTypeOIDC, + _IDPTypeLowerName[11:15]: IDPTypeOIDC, + _IDPTypeName[15:18]: IDPTypeJWT, + _IDPTypeLowerName[15:18]: IDPTypeJWT, + _IDPTypeName[18:23]: IDPTypeOAuth, + _IDPTypeLowerName[18:23]: IDPTypeOAuth, + _IDPTypeName[23:27]: IDPTypeLDAP, + _IDPTypeLowerName[23:27]: IDPTypeLDAP, + _IDPTypeName[27:34]: IDPTypeAzureAD, + _IDPTypeLowerName[27:34]: IDPTypeAzureAD, + _IDPTypeName[34:40]: IDPTypeGitHub, + _IDPTypeLowerName[34:40]: IDPTypeGitHub, + _IDPTypeName[40:56]: IDPTypeGitHubEnterprise, + _IDPTypeLowerName[40:56]: IDPTypeGitHubEnterprise, + _IDPTypeName[56:62]: IDPTypeGitLab, + _IDPTypeLowerName[56:62]: IDPTypeGitLab, + _IDPTypeName[62:78]: IDPTypeGitLabSelfHosted, + _IDPTypeLowerName[62:78]: IDPTypeGitLabSelfHosted, + _IDPTypeName[78:84]: IDPTypeGoogle, + _IDPTypeLowerName[78:84]: IDPTypeGoogle, + _IDPTypeName[84:89]: IDPTypeApple, + _IDPTypeLowerName[84:89]: IDPTypeApple, + _IDPTypeName[89:93]: IDPTypeSAML, + _IDPTypeLowerName[89:93]: IDPTypeSAML, +} + +var _IDPTypeNames = []string{ + _IDPTypeName[0:11], + _IDPTypeName[11:15], + _IDPTypeName[15:18], + _IDPTypeName[18:23], + _IDPTypeName[23:27], + _IDPTypeName[27:34], + _IDPTypeName[34:40], + _IDPTypeName[40:56], + _IDPTypeName[56:62], + _IDPTypeName[62:78], + _IDPTypeName[78:84], + _IDPTypeName[84:89], + _IDPTypeName[89:93], +} + +// IDPTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func IDPTypeString(s string) (IDPType, error) { + if val, ok := _IDPTypeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _IDPTypeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to IDPType values", s) +} + +// IDPTypeValues returns all values of the enum +func IDPTypeValues() []IDPType { + return _IDPTypeValues +} + +// IDPTypeStrings returns a slice of all String values of the enum +func IDPTypeStrings() []string { + strs := make([]string, len(_IDPTypeNames)) + copy(strs, _IDPTypeNames) + return strs +} + +// IsAIDPType returns "true" if the value is listed in the enum definition. "false" otherwise +func (i IDPType) IsAIDPType() bool { + for _, v := range _IDPTypeValues { + if i == v { + return true + } + } + return false +} diff --git a/backend/v3/domain/instance.go b/backend/v3/domain/instance.go index 10a3620788..9b95d111c0 100644 --- a/backend/v3/domain/instance.go +++ b/backend/v3/domain/instance.go @@ -84,7 +84,7 @@ type InstanceRepository interface { // Member() MemberRepository Get(ctx context.Context, id string) (*Instance, error) - List(ctx context.Context, opts ...database.Condition) ([]*Instance, error) + List(ctx context.Context, conditions ...database.Condition) ([]*Instance, error) Create(ctx context.Context, instance *Instance) error Update(ctx context.Context, id string, changes ...database.Change) (int64, error) diff --git a/backend/v3/domain/organization.go b/backend/v3/domain/organization.go index 870cc77838..5f4ab169d3 100644 --- a/backend/v3/domain/organization.go +++ b/backend/v3/domain/organization.go @@ -37,7 +37,7 @@ type organizationColumns interface { IDColumn() database.Column // NameColumn returns the column for the name field. NameColumn() database.Column - // InstanceIDColumn returns the column for the default org id field + // InstanceIDColumn returns the column for the instance id field InstanceIDColumn() database.Column // StateColumn returns the column for the name field. StateColumn() database.Column diff --git a/backend/v3/domain/orgstate_enumer.go b/backend/v3/domain/orgstate_enumer.go index a5a1d0ca57..a63258969b 100644 --- a/backend/v3/domain/orgstate_enumer.go +++ b/backend/v3/domain/orgstate_enumer.go @@ -14,7 +14,7 @@ var _OrgStateIndex = [...]uint8{0, 6, 14} const _OrgStateLowerName = "activeinactive" func (i OrgState) String() string { - if i < 0 || i >= OrgState(len(_OrgStateIndex)-1) { + if i >= OrgState(len(_OrgStateIndex)-1) { return fmt.Sprintf("OrgState(%d)", i) } return _OrgStateName[_OrgStateIndex[i]:_OrgStateIndex[i+1]] diff --git a/backend/v3/storage/database/condition.go b/backend/v3/storage/database/condition.go index 55f1e862e6..29aaa7c4ff 100644 --- a/backend/v3/storage/database/condition.go +++ b/backend/v3/storage/database/condition.go @@ -16,11 +16,16 @@ func (a *and) Write(builder *StatementBuilder) { builder.WriteString("(") defer builder.WriteString(")") } - for i, condition := range a.conditions { - if i > 0 { + firstCondition := true + for _, condition := range a.conditions { + if condition == nil { + continue + } + if !firstCondition { builder.WriteString(" AND ") } condition.Write(builder) + firstCondition = false } } diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table.go b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table.go new file mode 100644 index 0000000000..56a14ffcd5 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table.go @@ -0,0 +1,16 @@ +package migration + +import ( + _ "embed" +) + +var ( + //go:embed 003_identity_providers_table/up.sql + up003IdentityProvidersTable string + //go:embed 003_identity_providers_table/down.sql + down003IdentityProvidersTable string +) + +func init() { + registerSQLMigration(3, up003IdentityProvidersTable, down003IdentityProvidersTable) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/down.sql b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/down.sql new file mode 100644 index 0000000000..0831a6ca4f --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/down.sql @@ -0,0 +1,3 @@ +DROP TABLE zitadel.identity_providers; +DROP Type zitadel.idp_state; +DROP Type zitadel.idp_type; diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/up.sql b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/up.sql new file mode 100644 index 0000000000..869abd97cd --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/003_identity_providers_table/up.sql @@ -0,0 +1,55 @@ +CREATE TYPE zitadel.idp_state AS ENUM ( + 'active', + 'inactive' +); + +CREATE TYPE zitadel.idp_type AS ENUM ( + 'oidc', + 'jwt', + 'oauth', + 'saml', + 'ldap', + 'github', + 'google', + 'microsoft', + 'apple' +); + +CREATE TABLE zitadel.identity_providers ( + instance_id TEXT NOT NULL + , org_id TEXT + , id TEXT NOT NULL CHECK (id <> '') + , state zitadel.idp_state NOT NULL DEFAULT 'active' + , name TEXT NOT NULL CHECK (name <> '') + , type zitadel.idp_type -- NOT NULL + , auto_register BOOLEAN NOT NULL DEFAULT TRUE + , allow_creation BOOLEAN NOT NULL DEFAULT TRUE + , allow_auto_creation BOOLEAN NOT NULL DEFAULT TRUE + , allow_auto_update BOOLEAN NOT NULL DEFAULT TRUE + , allow_linking BOOLEAN NOT NULL DEFAULT TRUE + , allow_auto_linking BOOLEAN NOT NULL DEFAULT TRUE + , styling_type SMALLINT + , payload JSONB + + , created_at TIMESTAMPTZ NOT NULL DEFAULT now() + , updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + + , PRIMARY KEY (instance_id, id) + , CONSTRAINT identity_providers_id_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, id) + , CONSTRAINT identity_providers_name_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, name) + , FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id) + , FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id) +); + +-- CREATE INDEX idx_identity_providers_org_id ON identity_providers(instance_id, org_id) WHERE org_id IS NOT NULL; +CREATE INDEX idx_identity_providers_state ON zitadel.identity_providers(instance_id, state); +CREATE INDEX idx_identity_providers_type ON zitadel.identity_providers(instance_id, type); +-- CREATE INDEX idx_identity_providers_created_at ON identity_providers(created_at); +-- CREATE INDEX idx_identity_providers_deleted_at ON identity_providers(deleted_at) WHERE deleted_at IS NOT NULL; + + +CREATE TRIGGER trigger_set_updated_at +BEFORE UPDATE ON zitadel.identity_providers +FOR EACH ROW +WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) +EXECUTE FUNCTION zitadel.set_updated_at(); diff --git a/backend/v3/storage/database/events_testing/id_provider_test.go b/backend/v3/storage/database/events_testing/id_provider_test.go new file mode 100644 index 0000000000..8d9cdf0f1e --- /dev/null +++ b/backend/v3/storage/database/events_testing/id_provider_test.go @@ -0,0 +1,485 @@ +//go:build integration + +package events_test + +import ( + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/admin" + "github.com/zitadel/zitadel/pkg/grpc/idp" + idp_grpc "github.com/zitadel/zitadel/pkg/grpc/idp" +) + +func TestServer_TestIDProviderReduces(t *testing.T) { + instanceID := Instance.ID() + + t.Run("test idp add reduces", func(t *testing.T) { + name := gofakeit.Name() + + beforeCreate := time.Now() + addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + ClientId: "clientID", + ClientSecret: "clientSecret", + Issuer: "issuer", + Scopes: []string{"scope"}, + DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + AutoRegister: true, + }) + require.NoError(t, err) + afterCreate := time.Now() + + idpRepo := repository.IDProviderRepository(pool) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + idp, err := idpRepo.Get(CTX, + idpRepo.NameCondition(name), + instanceID, + nil, + ) + require.NoError(t, err) + + // event iam.idp.config.added + assert.Equal(t, addOIDC.IdpId, idp.ID) + assert.Equal(t, name, idp.Name) + assert.Equal(t, instanceID, idp.InstanceID) + assert.Equal(t, domain.IDPStateActive.String(), idp.State) + assert.Equal(t, true, idp.AutoRegister) + assert.Equal(t, int16(idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE), idp.StylingType) + assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, idp.CreatedAt, beforeCreate, afterCreate) + }, retryDuration, tick) + }) + + t.Run("test idp update reduces", func(t *testing.T) { + name := gofakeit.Name() + + addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + ClientId: "clientID", + ClientSecret: "clientSecret", + Issuer: "issuer", + Scopes: []string{"scope"}, + DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + AutoRegister: true, + }) + require.NoError(t, err) + + name = "new_" + name + + beforeCreate := time.Now() + _, err = AdminClient.UpdateIDP(CTX, &admin.UpdateIDPRequest{ + IdpId: addOIDC.IdpId, + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_UNSPECIFIED, + AutoRegister: false, + }) + afterCreate := time.Now() + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + idp, err := idpRepo.Get(CTX, + idpRepo.NameCondition(name), + instanceID, + nil, + ) + require.NoError(t, err) + + // event iam.idp.config.changed + assert.Equal(t, addOIDC.IdpId, idp.ID) + assert.Equal(t, name, idp.Name) + assert.Equal(t, false, idp.AutoRegister) + assert.Equal(t, int16(idp_grpc.IDPStylingType_STYLING_TYPE_UNSPECIFIED), idp.StylingType) + assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate) + }, retryDuration, tick) + }) + + t.Run("test idp deactivate reduces", func(t *testing.T) { + name := gofakeit.Name() + + addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + ClientId: "clientID", + ClientSecret: "clientSecret", + Issuer: "issuer", + Scopes: []string{"scope"}, + DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + AutoRegister: true, + }) + require.NoError(t, err) + + // deactivate idp + beforeCreate := time.Now() + _, err = AdminClient.DeactivateIDP(CTX, &admin.DeactivateIDPRequest{ + IdpId: addOIDC.IdpId, + }) + afterCreate := time.Now() + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + idp, err := idpRepo.Get(CTX, + idpRepo.IDCondition(addOIDC.IdpId), + instanceID, + nil, + ) + require.NoError(t, err) + + // event iam.idp.config.deactivated + assert.Equal(t, addOIDC.IdpId, idp.ID) + assert.Equal(t, domain.IDPStateInactive.String(), idp.State) + assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate) + }, retryDuration, tick) + }) + + t.Run("test idp reactivate reduces", func(t *testing.T) { + name := gofakeit.Name() + + addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + ClientId: "clientID", + ClientSecret: "clientSecret", + Issuer: "issuer", + Scopes: []string{"scope"}, + DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + AutoRegister: true, + }) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + // deactivate idp + _, err = AdminClient.DeactivateIDP(CTX, &admin.DeactivateIDPRequest{ + IdpId: addOIDC.IdpId, + }) + require.NoError(t, err) + // wait for idp to be deactivated + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + idp, err := idpRepo.Get(CTX, + idpRepo.IDCondition(addOIDC.IdpId), + instanceID, + nil, + ) + require.NoError(t, err) + + assert.Equal(t, addOIDC.IdpId, idp.ID) + assert.Equal(t, domain.IDPStateInactive.String(), idp.State) + }, retryDuration, tick) + + // reactivate idp + beforeCreate := time.Now() + _, err = AdminClient.ReactivateIDP(CTX, &admin.ReactivateIDPRequest{ + IdpId: addOIDC.IdpId, + }) + afterCreate := time.Now() + require.NoError(t, err) + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + idp, err := idpRepo.Get(CTX, + idpRepo.IDCondition(addOIDC.IdpId), + instanceID, + nil, + ) + require.NoError(t, err) + + // event iam.idp.config.reactivated + assert.Equal(t, addOIDC.IdpId, idp.ID) + assert.Equal(t, domain.IDPStateActive.String(), idp.State) + assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate) + }, retryDuration, tick) + }) + + t.Run("test idp remove reduces", func(t *testing.T) { + name := gofakeit.Name() + + // add idp + addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + ClientId: "clientID", + ClientSecret: "clientSecret", + Issuer: "issuer", + Scopes: []string{"scope"}, + DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + AutoRegister: true, + }) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + // remove idp + _, err = AdminClient.RemoveIDP(CTX, &admin.RemoveIDPRequest{ + IdpId: addOIDC.IdpId, + }) + require.NoError(t, err) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*20) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + _, err := idpRepo.Get(CTX, + idpRepo.IDCondition(addOIDC.IdpId), + instanceID, + nil, + ) + + // event iam.idp.config.remove + require.ErrorIs(t, &database.NoRowFoundError{}, err) + }, retryDuration, tick) + }) + + t.Run("test idp oidc addded reduces", func(t *testing.T) { + name := gofakeit.Name() + + // add oidc + addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + ClientId: "clientID", + ClientSecret: "clientSecret", + Issuer: "issuer", + Scopes: []string{"scope"}, + DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + AutoRegister: true, + }) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + oidc, err := idpRepo.GetOIDC(CTX, + idpRepo.IDCondition(addOIDC.IdpId), + instanceID, + nil, + ) + require.NoError(t, err) + + // event org.idp.oidc.config.added + // idp + assert.Equal(t, addOIDC.IdpId, oidc.ID) + assert.Equal(t, domain.IDPTypeOIDC.String(), oidc.Type) + + // oidc + assert.Equal(t, addOIDC.IdpId, oidc.IDPConfigID) + assert.Equal(t, "issuer", oidc.Issuer) + assert.Equal(t, "clientID", oidc.ClientID) + assert.Equal(t, []string{"scope"}, oidc.Scopes) + assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.IDPDisplayNameMapping) + assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.UserNameMapping) + }, retryDuration, tick) + }) + + t.Run("test idp oidc changed reduces", func(t *testing.T) { + name := gofakeit.Name() + + // add oidc + addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + ClientId: "clientID", + ClientSecret: "clientSecret", + Issuer: "issuer", + Scopes: []string{"scope"}, + DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL, + AutoRegister: true, + }) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + // check original values for OCID + var oidc *domain.IDPOIDC + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + oidc, err = idpRepo.GetOIDC(CTX, idpRepo.IDCondition(addOIDC.IdpId), instanceID, nil) + require.NoError(t, err) + }, retryDuration, tick) + + // idp + assert.Equal(t, addOIDC.IdpId, oidc.ID) + assert.Equal(t, domain.IDPTypeOIDC.String(), oidc.Type) + + // oidc + assert.Equal(t, addOIDC.IdpId, oidc.IDPConfigID) + assert.Equal(t, "issuer", oidc.Issuer) + assert.Equal(t, "clientID", oidc.ClientID) + assert.Equal(t, []string{"scope"}, oidc.Scopes) + assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.IDPDisplayNameMapping) + assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.UserNameMapping) + + beforeCreate := time.Now() + _, err = AdminClient.UpdateIDPOIDCConfig(CTX, &admin.UpdateIDPOIDCConfigRequest{ + IdpId: addOIDC.IdpId, + ClientId: "new_clientID", + ClientSecret: "new_clientSecret", + Issuer: "new_issuer", + Scopes: []string{"new_scope"}, + DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME, + UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME, + }) + afterCreate := time.Now() + require.NoError(t, err) + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + updateOIDC, err := idpRepo.GetOIDC(CTX, + idpRepo.IDCondition(addOIDC.IdpId), + instanceID, + nil, + ) + require.NoError(t, err) + + // event org.idp.oidc.config.changed + // idp + assert.Equal(t, addOIDC.IdpId, updateOIDC.ID) + assert.Equal(t, domain.IDPTypeOIDC.String(), updateOIDC.Type) + assert.WithinRange(t, updateOIDC.UpdatedAt, beforeCreate, afterCreate) + + // oidc + assert.Equal(t, addOIDC.IdpId, updateOIDC.IDPConfigID) + assert.Equal(t, "new_issuer", updateOIDC.Issuer) + assert.Equal(t, "new_clientID", updateOIDC.ClientID) + assert.NotEqual(t, oidc.ClientSecret, updateOIDC.ClientSecret) + assert.Equal(t, []string{"new_scope"}, updateOIDC.Scopes) + assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME), updateOIDC.IDPDisplayNameMapping) + assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME), updateOIDC.UserNameMapping) + }, retryDuration, tick) + }) + + t.Run("test idp jwt addded reduces", func(t *testing.T) { + name := gofakeit.Name() + + // add jwt + addJWT, err := AdminClient.AddJWTIDP(CTX, &admin.AddJWTIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + JwtEndpoint: "jwtEndpoint", + Issuer: "issuer", + KeysEndpoint: "keyEndpoint", + HeaderName: "headerName", + AutoRegister: true, + }) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + jwt, err := idpRepo.GetJWT(CTX, + idpRepo.IDCondition(addJWT.IdpId), + instanceID, + nil, + ) + require.NoError(t, err) + + // event org.idp.jwt.config.added + // idp + assert.Equal(t, addJWT.IdpId, jwt.ID) + assert.Equal(t, domain.IDPTypeJWT.String(), jwt.Type) + + // jwt + assert.Equal(t, addJWT.IdpId, jwt.IDPConfigID) + assert.Equal(t, "jwtEndpoint", jwt.JWTEndpoint) + assert.Equal(t, "issuer", jwt.Issuer) + assert.Equal(t, "keyEndpoint", jwt.KeysEndpoint) + assert.Equal(t, "headerName", jwt.HeaderName) + }, retryDuration, tick) + }) + + t.Run("test idp jwt changed reduces", func(t *testing.T) { + name := gofakeit.Name() + + // add jwt + addJWT, err := AdminClient.AddJWTIDP(CTX, &admin.AddJWTIDPRequest{ + Name: name, + StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE, + JwtEndpoint: "jwtEndpoint", + Issuer: "issuer", + KeysEndpoint: "keyEndpoint", + HeaderName: "headerName", + AutoRegister: true, + }) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + // check original values for jwt + var jwt *domain.IDPJWT + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + jwt, err = idpRepo.GetJWT(CTX, idpRepo.IDCondition(addJWT.IdpId), instanceID, nil) + require.NoError(t, err) + }, retryDuration, tick) + + // idp + assert.Equal(t, addJWT.IdpId, jwt.ID) + assert.Equal(t, domain.IDPTypeJWT.String(), jwt.Type) + + // jwt + assert.Equal(t, addJWT.IdpId, jwt.IDPConfigID) + assert.Equal(t, "jwtEndpoint", jwt.JWTEndpoint) + assert.Equal(t, "issuer", jwt.Issuer) + assert.Equal(t, "keyEndpoint", jwt.KeysEndpoint) + assert.Equal(t, "headerName", jwt.HeaderName) + + beforeCreate := time.Now() + _, err = AdminClient.UpdateIDPJWTConfig(CTX, &admin.UpdateIDPJWTConfigRequest{ + IdpId: addJWT.IdpId, + JwtEndpoint: "new_jwtEndpoint", + Issuer: "new_issuer", + KeysEndpoint: "new_keyEndpoint", + HeaderName: "new_headerName", + }) + afterCreate := time.Now() + require.NoError(t, err) + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + updateJWT, err := idpRepo.GetJWT(CTX, + idpRepo.IDCondition(addJWT.IdpId), + instanceID, + nil, + ) + require.NoError(t, err) + + // event org.idp.jwt.config.changed + // idp + assert.Equal(t, addJWT.IdpId, updateJWT.ID) + assert.Equal(t, domain.IDPTypeJWT.String(), updateJWT.Type) + assert.WithinRange(t, updateJWT.UpdatedAt, beforeCreate, afterCreate) + + // jwt + assert.Equal(t, addJWT.IdpId, updateJWT.IDPConfigID) + assert.Equal(t, "new_jwtEndpoint", updateJWT.JWTEndpoint) + assert.Equal(t, "new_issuer", updateJWT.Issuer) + assert.Equal(t, "new_keyEndpoint", updateJWT.KeysEndpoint) + }, retryDuration, tick) + }) +} diff --git a/backend/v3/storage/database/operators.go b/backend/v3/storage/database/operators.go index a2949220e9..9f3e18aeab 100644 --- a/backend/v3/storage/database/operators.go +++ b/backend/v3/storage/database/operators.go @@ -137,6 +137,6 @@ const ( func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) { col.Write(builder) - builder.WriteString(" IS ") + builder.WriteString(" = ") builder.WriteArg(value) } diff --git a/backend/v3/storage/database/repository/id_provider.go b/backend/v3/storage/database/repository/id_provider.go new file mode 100644 index 0000000000..55c981a8d1 --- /dev/null +++ b/backend/v3/storage/database/repository/id_provider.go @@ -0,0 +1,312 @@ +package repository + +import ( + "context" + "errors" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var _ domain.IDProviderRepository = (*idProvider)(nil) + +type idProvider struct { + repository +} + +func IDProviderRepository(client database.QueryExecutor) domain.IDProviderRepository { + return &idProvider{ + repository: repository{ + client: client, + }, + } +} + +const queryIDProviderStmt = `SELECT instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` + + ` allow_auto_update, allow_linking, styling_type, payload, created_at, updated_at` + + ` FROM zitadel.identity_providers` + +func (i *idProvider) Get(ctx context.Context, id domain.IDPIdentifierCondition, instanceID string, orgID *string) (*domain.IdentityProvider, error) { + builder := database.StatementBuilder{} + + builder.WriteString(queryIDProviderStmt) + + conditions := []database.Condition{id, i.InstanceIDCondition(instanceID), i.OrgIDCondition(orgID)} + + writeCondition(&builder, database.And(conditions...)) + + return scanIDProvider(ctx, i.client, &builder) +} + +func (i *idProvider) List(ctx context.Context, conditions ...database.Condition) ([]*domain.IdentityProvider, error) { + builder := database.StatementBuilder{} + + builder.WriteString(queryIDProviderStmt) + + if conditions != nil { + writeCondition(&builder, database.And(conditions...)) + } + + orderBy := database.OrderBy(i.CreatedAtColumn()) + orderBy.Write(&builder) + + return scanIDProviders(ctx, i.client, &builder) +} + +const createIDProviderStmt = `INSERT INTO zitadel.identity_providers` + + ` (instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` + + ` allow_auto_update, allow_linking, styling_type, payload)` + + ` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)` + + ` RETURNING created_at, updated_at` + +func (i *idProvider) Create(ctx context.Context, idp *domain.IdentityProvider) error { + builder := database.StatementBuilder{} + builder.AppendArgs( + idp.InstanceID, + idp.OrgID, + idp.ID, + idp.State, + idp.Name, + idp.Type, + idp.AllowCreation, + idp.AllowAutoCreation, + idp.AllowAutoUpdate, + idp.AllowLinking, + idp.StylingType, + idp.Payload) + builder.WriteString(createIDProviderStmt) + + err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&idp.CreatedAt, &idp.UpdatedAt) + if err != nil { + return checkCreateOrgErr(err) + } + return nil +} + +func (i *idProvider) Update(ctx context.Context, id domain.IDPIdentifierCondition, instnaceID string, orgID *string, changes ...database.Change) (int64, error) { + if changes == nil { + return 0, errors.New("Update must contain at least one change") + } + builder := database.StatementBuilder{} + builder.WriteString(`UPDATE zitadel.identity_providers SET `) + + conditions := []database.Condition{ + id, + i.InstanceIDCondition(instnaceID), + i.OrgIDCondition(orgID), + } + database.Changes(changes).Write(&builder) + writeCondition(&builder, database.And(conditions...)) + + stmt := builder.String() + + return i.client.Exec(ctx, stmt, builder.Args()...) +} + +func (i *idProvider) Delete(ctx context.Context, id domain.IDPIdentifierCondition, instnaceID string, orgID *string) (int64, error) { + builder := database.StatementBuilder{} + + builder.WriteString(`DELETE FROM zitadel.identity_providers`) + + conditions := []database.Condition{ + id, + i.InstanceIDCondition(instnaceID), + i.OrgIDCondition(orgID), + } + writeCondition(&builder, database.And(conditions...)) + + return i.client.Exec(ctx, builder.String(), builder.Args()...) +} + +// ------------------------------------------------------------- +// columns +// ------------------------------------------------------------- + +func (idProvider) InstanceIDColumn() database.Column { + return database.NewColumn("instance_id") +} + +func (idProvider) OrgIDColumn() database.Column { + return database.NewColumn("org_id") +} + +func (idProvider) IDColumn() database.Column { + return database.NewColumn("id") +} + +func (idProvider) StateColumn() database.Column { + return database.NewColumn("state") +} + +func (idProvider) NameColumn() database.Column { + return database.NewColumn("name") +} + +func (idProvider) TypeColumn() database.Column { + return database.NewColumn("type") +} + +func (idProvider) AutoRegisterColumn() database.Column { + return database.NewColumn("auto_register") +} + +func (idProvider) AllowCreationColumn() database.Column { + return database.NewColumn("allow_creation") +} + +func (idProvider) AllowAutoCreationColumn() database.Column { + return database.NewColumn("allow_auto_creation") +} + +func (idProvider) AllowAutoUpdateColumn() database.Column { + return database.NewColumn("allow_auto_update") +} + +func (idProvider) AllowLinkingColumn() database.Column { + return database.NewColumn("allow_linking") +} + +func (idProvider) AllowAutoLinkingColumn() database.Column { + return database.NewColumn("allow_auto_linking") +} + +func (idProvider) StylingTypeColumn() database.Column { + return database.NewColumn("styling_type") +} + +func (idProvider) PayloadColumn() database.Column { + return database.NewColumn("payload") +} + +func (idProvider) CreatedAtColumn() database.Column { + return database.NewColumn("created_at") +} + +func (idProvider) UpdatedAtColumn() database.Column { + return database.NewColumn("updated_at") +} + +// ------------------------------------------------------------- +// conditions +// ------------------------------------------------------------- + +func (i idProvider) InstanceIDCondition(id string) database.Condition { + return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, id) +} + +func (i idProvider) OrgIDCondition(id *string) database.Condition { + if id == nil { + return nil + } + return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, *id) +} + +func (i idProvider) IDCondition(id string) domain.IDPIdentifierCondition { + return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id) +} + +func (i idProvider) StateCondition(state domain.IDPState) database.Condition { + return database.NewTextCondition(i.StateColumn(), database.TextOperationEqual, state.String()) +} + +func (i idProvider) NameCondition(name string) domain.IDPIdentifierCondition { + return database.NewTextCondition(i.NameColumn(), database.TextOperationEqual, name) +} + +func (i idProvider) TypeCondition(typee domain.IDPType) database.Condition { + return database.NewTextCondition(i.TypeColumn(), database.TextOperationEqual, typee.String()) +} + +func (i idProvider) AutoRegisterCondition(allow bool) database.Condition { + return database.NewBooleanCondition(i.AutoRegisterColumn(), allow) +} + +func (i idProvider) AllowCreationCondition(allow bool) database.Condition { + return database.NewBooleanCondition(i.AllowCreationColumn(), allow) +} + +func (i idProvider) AllowAutoCreationCondition(allow bool) database.Condition { + return database.NewBooleanCondition(i.AllowAutoCreationColumn(), allow) +} + +func (i idProvider) AllowAutoUpdateCondition(allow bool) database.Condition { + return database.NewBooleanCondition(i.AllowAutoUpdateColumn(), allow) +} + +func (i idProvider) AllowLinkingCondition(allow bool) database.Condition { + return database.NewBooleanCondition(i.AllowLinkingColumn(), allow) +} + +func (i idProvider) AllowAutoLinkingCondition(allow bool) database.Condition { + return database.NewBooleanCondition(i.AllowAutoLinkingColumn(), allow) +} + +func (i idProvider) StylingTypeCondition(style int16) database.Condition { + return database.NewNumberCondition(i.StylingTypeColumn(), database.NumberOperationEqual, style) +} + +func (i idProvider) PayloadCondition(payload string) database.Condition { + return database.NewTextCondition(i.PayloadColumn(), database.TextOperationEqual, payload) +} + +// ------------------------------------------------------------- +// changes +// ------------------------------------------------------------- + +func (i idProvider) SetName(name string) database.Change { + return database.NewChange(i.NameColumn(), name) +} + +func (i idProvider) SetState(state domain.IDPState) database.Change { + return database.NewChange(i.StateColumn(), state) +} + +func (i idProvider) SetAllowCreation(allow bool) database.Change { + return database.NewChange(i.AllowCreationColumn(), allow) +} + +func (i idProvider) SetAutoRegister(allow bool) database.Change { + return database.NewChange(i.AutoRegisterColumn(), allow) +} + +func (i idProvider) SetAllowAutoCreation(allow bool) database.Change { + return database.NewChange(i.AllowAutoCreationColumn(), allow) +} + +func (i idProvider) SetAllowAutoUpdate(allow bool) database.Change { + return database.NewChange(i.AllowAutoUpdateColumn(), allow) +} + +func (i idProvider) SetAllowLinking(allow bool) database.Change { + return database.NewChange(i.AllowLinkingColumn(), allow) +} + +func (i idProvider) SetAutoAllowLinking(allow bool) database.Change { + return database.NewChange(i.AllowAutoLinkingColumn(), allow) +} + +func (i idProvider) SetStylingType(stylingType int16) database.Change { + return database.NewChange(i.StylingTypeColumn(), stylingType) +} + +func (i idProvider) SetPayload(payload string) database.Change { + return database.NewChange(i.PayloadColumn(), payload) +} + +func scanIDProvider(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.IdentityProvider, error) { + idp := &domain.IdentityProvider{} + err := scanRow(ctx, querier, builder, idp) + if err != nil { + return nil, err + } + return idp, err +} + +func scanIDProviders(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.IdentityProvider, error) { + idps := []*domain.IdentityProvider{} + err := scanRows(ctx, querier, builder, &idps) + if err != nil { + return nil, err + } + return idps, nil +} diff --git a/backend/v3/storage/database/repository/id_provider_test.go b/backend/v3/storage/database/repository/id_provider_test.go new file mode 100644 index 0000000000..e4371f091b --- /dev/null +++ b/backend/v3/storage/database/repository/id_provider_test.go @@ -0,0 +1,1873 @@ +package repository_test + +import ( + "context" + "testing" + "time" + + "github.com/muhlemmer/gu" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestCreateIDProvider(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // create org + orgId := gofakeit.Name() + org := domain.Organization{ + ID: orgId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.IdentityProvider + idp domain.IdentityProvider + err error + } + + // TESTS + tests := []test{ + { + name: "happy path", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + }, + }, + { + name: "create organization without name", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + // Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + }, + err: new(database.CheckError), + }, + { + name: "adding org with same id twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idpRepo := repository.IDProviderRepository(pool) + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + // change the name to make sure same only the id clashes + org.Name = gofakeit.Name() + return &idp + }, + err: new(database.UniqueError), + }, + { + name: "adding org with same name twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idpRepo := repository.IDProviderRepository(pool) + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + // change the id to make sure same name causes an error + idp.ID = gofakeit.Name() + return &idp + }, + err: new(database.UniqueError), + }, + func() test { + id := gofakeit.Name() + name := gofakeit.Name() + return test{ + name: "adding org with same name, id, different instance", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + // create instance + newInstId := gofakeit.Name() + instance := domain.Instance{ + ID: newInstId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(ctx, &instance) + assert.Nil(t, err) + + // create org + newOrgId := gofakeit.Name() + org := domain.Organization{ + ID: newOrgId, + Name: gofakeit.Name(), + InstanceID: newInstId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + idp := domain.IdentityProvider{ + InstanceID: newInstId, + OrgID: &newOrgId, + ID: id, + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err = idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + // change the instanceID to a different instance + idp.InstanceID = instanceId + // change the OrgId to a different organization + idp.OrgID = &orgId + return &idp + }, + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: id, + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + }, + } + }(), + { + name: "adding idp with no id", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + // ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + }, + err: new(database.CheckError), + }, + { + name: "adding idp with no name", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + // Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + }, + err: new(database.CheckError), + }, + { + name: "adding idp with no instance id", + idp: domain.IdentityProvider{ + // InstanceID: instanceId, + OrgID: &orgId, + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + }, + err: new(database.IntegrityViolationError), + }, + { + name: "adding organization with non existent instance id", + idp: domain.IdentityProvider{ + InstanceID: gofakeit.Name(), + OrgID: &orgId, + State: domain.IDPStateActive.String(), + ID: gofakeit.Name(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + }, + err: new(database.ForeignKeyError), + }, + { + name: "adding organization with non existent org id", + idp: domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: gu.Ptr(gofakeit.Name()), + State: domain.IDPStateActive.String(), + ID: gofakeit.Name(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + }, + err: new(database.ForeignKeyError), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + var idp *domain.IdentityProvider + if tt.testFunc != nil { + idp = tt.testFunc(ctx, t) + } else { + idp = &tt.idp + } + idpRepo := repository.IDProviderRepository(pool) + + // create idp + beforeCreate := time.Now() + err = idpRepo.Create(ctx, idp) + assert.ErrorIs(t, err, tt.err) + if err != nil { + return + } + afterCreate := time.Now() + + // check organization values + idp, err = idpRepo.Get(ctx, + idpRepo.IDCondition(idp.ID), + idp.InstanceID, + idp.OrgID, + ) + require.NoError(t, err) + + assert.Equal(t, tt.idp.InstanceID, idp.InstanceID) + assert.Equal(t, tt.idp.OrgID, idp.OrgID) + assert.Equal(t, tt.idp.State, idp.State) + assert.Equal(t, tt.idp.ID, idp.ID) + assert.Equal(t, tt.idp.Name, idp.Name) + assert.Equal(t, tt.idp.Type, idp.Type) + assert.Equal(t, tt.idp.AllowCreation, idp.AllowCreation) + assert.Equal(t, tt.idp.AllowAutoCreation, idp.AllowAutoCreation) + assert.Equal(t, tt.idp.AllowAutoUpdate, idp.AllowAutoUpdate) + assert.Equal(t, tt.idp.AllowLinking, idp.AllowLinking) + assert.Equal(t, tt.idp.StylingType, idp.StylingType) + assert.Equal(t, tt.idp.Payload, idp.Payload) + assert.WithinRange(t, idp.CreatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate) + }) + } +} + +func TestUpdateIDProvider(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // create org + orgId := gofakeit.Name() + org := domain.Organization{ + ID: orgId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.IdentityProvider + update []database.Change + rowsAffected int64 + }{ + { + name: "happy path update name", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.Name = "new_name" + return &idp + }, + update: []database.Change{idpRepo.SetName("new_name")}, + rowsAffected: 1, + }, + { + name: "happy path update state", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.State = domain.IDPStateInactive.String() + return &idp + }, + update: []database.Change{idpRepo.SetState(domain.IDPStateInactive)}, + rowsAffected: 1, + }, + { + name: "happy path update AllowCreation", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.AllowCreation = false + return &idp + }, + update: []database.Change{idpRepo.SetAllowCreation(false)}, + rowsAffected: 1, + }, + { + name: "happy path update AllowAutoCreation", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.AllowAutoCreation = false + return &idp + }, + update: []database.Change{idpRepo.SetAllowAutoCreation(false)}, + rowsAffected: 1, + }, + { + name: "happy path update AllowLinking", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.AllowLinking = false + return &idp + }, + update: []database.Change{idpRepo.SetAllowLinking(false)}, + rowsAffected: 1, + }, + { + name: "happy path update StylingType", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.StylingType = 2 + return &idp + }, + update: []database.Change{idpRepo.SetStylingType(2)}, + rowsAffected: 1, + }, + { + name: "happy path update Payload", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.Payload = gu.Ptr(`{"json": {}}`) + return &idp + }, + // update: []database.Change{idpRepo.SetPayload("{{}}")}, + update: []database.Change{idpRepo.SetPayload(`{"json": {}}`)}, + rowsAffected: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + organizationRepo := repository.OrganizationRepository(pool) + idpRepo := repository.IDProviderRepository(pool) + + createdIDP := tt.testFunc(ctx, t) + + // update idp + beforeUpdate := time.Now() + rowsAffected, err := idpRepo.Update(ctx, + idpRepo.IDCondition(createdIDP.ID), + createdIDP.InstanceID, + createdIDP.OrgID, + tt.update..., + ) + afterUpdate := time.Now() + require.NoError(t, err) + + assert.Equal(t, tt.rowsAffected, rowsAffected) + + if rowsAffected == 0 { + return + } + + // check idp values + idp, err := idpRepo.Get(ctx, + organizationRepo.IDCondition(createdIDP.ID), + createdIDP.InstanceID, + createdIDP.OrgID, + ) + require.NoError(t, err) + + assert.Equal(t, createdIDP.InstanceID, idp.InstanceID) + assert.Equal(t, createdIDP.OrgID, idp.OrgID) + assert.Equal(t, createdIDP.State, idp.State) + assert.Equal(t, createdIDP.ID, idp.ID) + assert.Equal(t, createdIDP.Name, idp.Name) + assert.Equal(t, createdIDP.Type, idp.Type) + assert.Equal(t, createdIDP.AllowCreation, idp.AllowCreation) + assert.Equal(t, createdIDP.AllowAutoCreation, idp.AllowAutoCreation) + assert.Equal(t, createdIDP.AllowAutoUpdate, idp.AllowAutoUpdate) + assert.Equal(t, createdIDP.AllowLinking, idp.AllowLinking) + assert.Equal(t, createdIDP.StylingType, idp.StylingType) + assert.Equal(t, createdIDP.Payload, idp.Payload) + assert.WithinRange(t, idp.UpdatedAt, beforeUpdate, afterUpdate) + }) + } +} + +func TestGetIDProvider(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // create org + orgId := gofakeit.Name() + org := domain.Organization{ + ID: orgId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of the tests + preexistingOrg := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + err = organizationRepo.Create(t.Context(), &preexistingOrg) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.IdentityProvider + idpIdentifierCondition domain.OrgIdentifierCondition + err error + } + + tests := []test{ + func() test { + id := gofakeit.Name() + return test{ + name: "happy path get using id", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: id, + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + return &idp + }, + idpIdentifierCondition: idpRepo.IDCondition(id), + } + }(), + func() test { + name := gofakeit.Name() + return test{ + name: "happy path get using name", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + return &idp + }, + idpIdentifierCondition: idpRepo.NameCondition(name), + } + }(), + { + name: "get using non existent id", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + return &idp + }, + idpIdentifierCondition: idpRepo.IDCondition("non-existent-id"), + err: new(database.NoRowFoundError), + }, + { + name: "get using non existent name", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + return &idp + }, + idpIdentifierCondition: idpRepo.NameCondition("non-existent-name"), + err: new(database.NoRowFoundError), + }, + //////// + func() test { + id := gofakeit.Name() + return test{ + name: "non existent orgID", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: id, + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.OrgID = gu.Ptr("non-existent-orgID") + return &idp + }, + idpIdentifierCondition: idpRepo.IDCondition(id), + err: new(database.NoRowFoundError), + } + }(), + func() test { + name := gofakeit.Name() + return test{ + name: "non existent instanceID", + testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + idp.InstanceID = "non-existent-instnaceID" + return &idp + }, + idpIdentifierCondition: idpRepo.NameCondition(name), + err: new(database.NoRowFoundError), + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + + var idp *domain.IdentityProvider + if tt.testFunc != nil { + idp = tt.testFunc(ctx, t) + } + + // get idp + returnedIDP, err := idpRepo.Get(ctx, + tt.idpIdentifierCondition, + idp.InstanceID, + idp.OrgID, + ) + if err != nil { + require.ErrorIs(t, tt.err, err) + return + } + + assert.Equal(t, returnedIDP.InstanceID, idp.InstanceID) + assert.Equal(t, returnedIDP.OrgID, idp.OrgID) + assert.Equal(t, returnedIDP.State, idp.State) + assert.Equal(t, returnedIDP.ID, idp.ID) + assert.Equal(t, returnedIDP.Name, idp.Name) + assert.Equal(t, returnedIDP.Type, idp.Type) + assert.Equal(t, returnedIDP.AllowCreation, idp.AllowCreation) + assert.Equal(t, returnedIDP.AllowAutoCreation, idp.AllowAutoCreation) + assert.Equal(t, returnedIDP.AllowAutoUpdate, idp.AllowAutoUpdate) + assert.Equal(t, returnedIDP.AllowLinking, idp.AllowLinking) + assert.Equal(t, returnedIDP.StylingType, idp.StylingType) + assert.Equal(t, returnedIDP.Payload, idp.Payload) + }) + } +} + +func TestListIDProvider(t *testing.T) { + ctx := t.Context() + pool, stop, err := newEmbeddedDB(ctx) + require.NoError(t, err) + defer stop() + + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err = instanceRepo.Create(ctx, &instance) + require.NoError(t, err) + + // create org + orgId := gofakeit.Name() + org := domain.Organization{ + ID: orgId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) []*domain.IdentityProvider + conditionClauses []database.Condition + noIDPsReturned bool + } + tests := []test{ + { + name: "multiple idps filter on instance", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create instance + newInstanceId := gofakeit.Name() + instance := domain.Instance{ + ID: newInstanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + err = instanceRepo.Create(ctx, &instance) + require.NoError(t, err) + + // create org + newOrgId := gofakeit.Name() + org := domain.Organization{ + ID: newOrgId, + Name: gofakeit.Name(), + InstanceID: newInstanceId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: newInstanceId, + OrgID: &newOrgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.InstanceIDCondition(instanceId)}, + }, + { + name: "multiple idps filter on org", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create org + newOrgId := gofakeit.Name() + org := domain.Organization{ + ID: newOrgId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &newOrgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.OrgIDCondition(&orgId)}, + }, + { + name: "happy path single idp no filter", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + noOfIDPs := 1 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + }, + { + name: "happy path multiple idps no filter", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + }, + func() test { + id := gofakeit.Name() + return test{ + name: "idp filter on id", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 1 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: id, + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.IDCondition(id)}, + } + }(), + { + name: "multiple idps filter on state", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + // state inactive + State: domain.IDPStateInactive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + // state active + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.StateCondition(domain.IDPStateActive)}, + }, + func() test { + name := gofakeit.Name() + return test{ + name: "multiple idps filter on name", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 1 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.NameCondition(name)}, + } + }(), + { + name: "multiple idps filter on type", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateInactive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + // type LDAP + Type: domain.IDPTypeLDAP.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.TypeCondition(domain.IDPTypeLDAP)}, + }, + { + name: "multiple idps filter on AllowCreation", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateInactive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeLDAP.String(), + // AllowCreation set to false + AllowCreation: false, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.AllowCreationCondition(false)}, + }, + { + name: "multiple idps filter on AllowAutoCreation", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateInactive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeLDAP.String(), + AllowCreation: true, + // AllowAutoCreation set to false + AllowAutoCreation: false, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.AllowAutoCreationCondition(false)}, + }, + { + name: "multiple idps filter on AllowAutoUpdate", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + // state inactive + State: domain.IDPStateInactive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeLDAP.String(), + AllowCreation: true, + AllowAutoCreation: true, + // AllowAutoUpdate set to false + AllowAutoUpdate: false, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.AllowAutoUpdateCondition(false)}, + }, + { + name: "multiple idps filter on AllowLinking", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + // state inactive + State: domain.IDPStateInactive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 5 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeLDAP.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + // AllowLinking set to false + AllowLinking: false, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.AllowLinkingCondition(false)}, + }, + { + name: "multiple idps filter on StylingType", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 1 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + // StylingType set to 4 + StylingType: 4, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.StylingTypeCondition(4)}, + }, + func() test { + payload := `{"json": {}}` + return test{ + name: "multiple idps filter on Payload", + testFunc: func(ctx context.Context, t *testing.T) []*domain.IdentityProvider { + // create idp + // this idp is created as an additional idp which should NOT + // be returned in the results of this test case + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + noOfIDPs := 1 + idps := make([]*domain.IdentityProvider, noOfIDPs) + for i := range noOfIDPs { + + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: &payload, + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + idps[i] = &idp + } + + return idps + }, + conditionClauses: []database.Condition{idpRepo.PayloadCondition(payload)}, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + _, err := pool.Exec(ctx, "DELETE FROM zitadel.identity_providers") + require.NoError(t, err) + }) + + ctx := t.Context() + + idps := tt.testFunc(ctx, t) + + // check idp values + returnedIDPs, err := idpRepo.List(ctx, + tt.conditionClauses..., + ) + require.NoError(t, err) + if tt.noIDPsReturned { + assert.Nil(t, returnedIDPs) + return + } + + assert.Equal(t, len(idps), len(returnedIDPs)) + for i, idp := range idps { + + assert.Equal(t, returnedIDPs[i].InstanceID, idp.InstanceID) + assert.Equal(t, returnedIDPs[i].OrgID, idp.OrgID) + assert.Equal(t, returnedIDPs[i].State, idp.State) + assert.Equal(t, returnedIDPs[i].ID, idp.ID) + assert.Equal(t, returnedIDPs[i].Name, idp.Name) + assert.Equal(t, returnedIDPs[i].Type, idp.Type) + assert.Equal(t, returnedIDPs[i].AllowCreation, idp.AllowCreation) + assert.Equal(t, returnedIDPs[i].AllowAutoCreation, idp.AllowAutoCreation) + assert.Equal(t, returnedIDPs[i].AllowAutoUpdate, idp.AllowAutoUpdate) + assert.Equal(t, returnedIDPs[i].AllowLinking, idp.AllowLinking) + assert.Equal(t, returnedIDPs[i].StylingType, idp.StylingType) + assert.Equal(t, returnedIDPs[i].Payload, idp.Payload) + } + }) + } +} + +func TestDeleteIDProvider(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // create org + orgId := gofakeit.Name() + org := domain.Organization{ + ID: orgId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive.String(), + } + organizationRepo := repository.OrganizationRepository(pool) + err = organizationRepo.Create(t.Context(), &org) + require.NoError(t, err) + + idpRepo := repository.IDProviderRepository(pool) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) + idpIdentifierCondition domain.IDPIdentifierCondition + noOfDeletedRows int64 + } + tests := []test{ + func() test { + id := gofakeit.Name() + var noOfIDPs int64 = 1 + return test{ + name: "happy path delete idp filter id", + testFunc: func(ctx context.Context, t *testing.T) { + for range noOfIDPs { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: id, + State: domain.IDPStateActive.String(), + Name: gofakeit.Name(), + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + } + }, + idpIdentifierCondition: idpRepo.IDCondition(id), + noOfDeletedRows: noOfIDPs, + } + }(), + func() test { + name := gofakeit.Name() + var noOfIDPs int64 = 1 + return test{ + name: "happy path delete idp filter name", + testFunc: func(ctx context.Context, t *testing.T) { + for range noOfIDPs { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + + } + }, + idpIdentifierCondition: idpRepo.NameCondition(name), + noOfDeletedRows: noOfIDPs, + } + }(), + { + name: "delete non existent idp", + idpIdentifierCondition: idpRepo.NameCondition(gofakeit.Name()), + }, + func() test { + name := gofakeit.Name() + return test{ + name: "deleted already deleted idp", + testFunc: func(ctx context.Context, t *testing.T) { + noOfIDPs := 1 + for range noOfIDPs { + idp := domain.IdentityProvider{ + InstanceID: instanceId, + OrgID: &orgId, + ID: gofakeit.Name(), + State: domain.IDPStateActive.String(), + Name: name, + Type: domain.IDPTypeOIDC.String(), + AllowCreation: true, + AllowAutoCreation: true, + AllowAutoUpdate: true, + AllowLinking: true, + StylingType: 1, + Payload: gu.Ptr("{}"), + } + + err := idpRepo.Create(ctx, &idp) + require.NoError(t, err) + } + + // delete organization + affectedRows, err := idpRepo.Delete(ctx, + idpRepo.NameCondition(name), + instanceId, + &orgId, + ) + assert.Equal(t, int64(1), affectedRows) + require.NoError(t, err) + }, + idpIdentifierCondition: idpRepo.NameCondition(name), + // this test should return 0 affected rows as the idp was already deleted + noOfDeletedRows: 0, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + + if tt.testFunc != nil { + tt.testFunc(ctx, t) + } + + // delete idp + noOfDeletedRows, err := idpRepo.Delete(ctx, + tt.idpIdentifierCondition, + instanceId, + &orgId, + ) + require.NoError(t, err) + assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows) + + // check idp was deleted + organization, err := idpRepo.Get(ctx, + tt.idpIdentifierCondition, + instanceId, + &orgId, + ) + require.ErrorIs(t, err, new(database.NoRowFoundError)) + assert.Nil(t, organization) + }) + } +} diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go index 63f878574c..e34e9bc6e0 100644 --- a/backend/v3/storage/database/repository/instance.go +++ b/backend/v3/storage/database/repository/instance.go @@ -172,28 +172,18 @@ func (instance) UpdatedAtColumn() database.Column { } func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) { - rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + instance := new(domain.Instance) + err := scanRow(ctx, querier, builder, instance) if err != nil { return nil, err } - - instance := new(domain.Instance) - if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil { - return nil, err - } - - return instance, nil + return instance, err } func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) { - rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + err = scanRows(ctx, querier, builder, &instances) if err != nil { return nil, err } - - if err := rows.(database.CollectableRows).Collect(&instances); err != nil { - return nil, err - } - return instances, nil } diff --git a/backend/v3/storage/database/repository/org.go b/backend/v3/storage/database/repository/org.go index e8053aadd9..7268c50437 100644 --- a/backend/v3/storage/database/repository/org.go +++ b/backend/v3/storage/database/repository/org.go @@ -217,27 +217,19 @@ func (org) UpdatedAtColumn() database.Column { } func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) { - rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + organization := &domain.Organization{} + err := scanRow(ctx, querier, builder, organization) if err != nil { return nil, err } - - organization := &domain.Organization{} - if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil { - return nil, err - } - - return organization, nil + return organization, err } func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) { - rows, err := querier.Query(ctx, builder.String(), builder.Args()...) - if err != nil { - return nil, err - } - organizations := []*domain.Organization{} - if err := rows.(database.CollectableRows).Collect(&organizations); err != nil { + + err := scanRows(ctx, querier, builder, &organizations) + if err != nil { return nil, err } return organizations, nil diff --git a/backend/v3/storage/database/repository/org_test.go b/backend/v3/storage/database/repository/org_test.go index a6b5182f8c..d1d13d539e 100644 --- a/backend/v3/storage/database/repository/org_test.go +++ b/backend/v3/storage/database/repository/org_test.go @@ -519,11 +519,6 @@ func TestGetOrganization(t *testing.T) { return } - if org.Name == "non existent org" { - assert.Nil(t, returnedOrg) - return - } - assert.Equal(t, returnedOrg.ID, org.ID) assert.Equal(t, returnedOrg.Name, org.Name) assert.Equal(t, returnedOrg.InstanceID, org.InstanceID) @@ -815,8 +810,7 @@ func TestDeleteOrganization(t *testing.T) { return test{ name: "happy path delete organization filter id", testFunc: func(ctx context.Context, t *testing.T) { - organizations := make([]*domain.Organization, noOfOrganizations) - for i := range noOfOrganizations { + for range noOfOrganizations { org := domain.Organization{ ID: organizationId, @@ -829,7 +823,6 @@ func TestDeleteOrganization(t *testing.T) { err := organizationRepo.Create(ctx, &org) require.NoError(t, err) - organizations[i] = &org } }, orgIdentifierCondition: organizationRepo.IDCondition(organizationId), @@ -843,8 +836,7 @@ func TestDeleteOrganization(t *testing.T) { return test{ name: "happy path delete organization filter name", testFunc: func(ctx context.Context, t *testing.T) { - organizations := make([]*domain.Organization, noOfOrganizations) - for i := range noOfOrganizations { + for range noOfOrganizations { org := domain.Organization{ ID: gofakeit.Name(), @@ -857,7 +849,6 @@ func TestDeleteOrganization(t *testing.T) { err := organizationRepo.Create(ctx, &org) require.NoError(t, err) - organizations[i] = &org } }, orgIdentifierCondition: organizationRepo.NameCondition(organizationName), @@ -879,8 +870,7 @@ func TestDeleteOrganization(t *testing.T) { name: "deleted already deleted organization", testFunc: func(ctx context.Context, t *testing.T) { noOfOrganizations := 1 - organizations := make([]*domain.Organization, noOfOrganizations) - for i := range noOfOrganizations { + for range noOfOrganizations { org := domain.Organization{ ID: gofakeit.Name(), @@ -893,13 +883,12 @@ func TestDeleteOrganization(t *testing.T) { err := organizationRepo.Create(ctx, &org) require.NoError(t, err) - organizations[i] = &org } // delete organization affectedRows, err := organizationRepo.Delete(ctx, organizationRepo.NameCondition(organizationName), - organizations[0].InstanceID, + instanceId, ) assert.Equal(t, int64(1), affectedRows) require.NoError(t, err) diff --git a/backend/v3/storage/database/repository/repository.go b/backend/v3/storage/database/repository/repository.go index c5b9ff81f0..6713dc77ed 100644 --- a/backend/v3/storage/database/repository/repository.go +++ b/backend/v3/storage/database/repository/repository.go @@ -1,6 +1,8 @@ package repository import ( + "context" + "github.com/zitadel/zitadel/backend/v3/storage/database" ) @@ -18,3 +20,21 @@ func writeCondition( builder.WriteString(" WHERE ") condition.Write(builder) } + +func scanRow(ctx context.Context, querier database.Querier, builder *database.StatementBuilder, res any) error { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return err + } + + return rows.(database.CollectableRows).CollectExactlyOneRow(res) +} + +func scanRows(ctx context.Context, querier database.Querier, builder *database.StatementBuilder, res any) error { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return err + } + + return rows.(database.CollectableRows).Collect(res) +} diff --git a/backend/v3/storage/database/statement.go b/backend/v3/storage/database/statement.go index 7d779fe360..041a6d87b6 100644 --- a/backend/v3/storage/database/statement.go +++ b/backend/v3/storage/database/statement.go @@ -29,9 +29,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) { if b.existingArgs == nil { b.existingArgs = make(map[any]string) } - if placeholder, ok := b.existingArgs[arg]; ok { - return placeholder - } + if instruction, ok := arg.(Instruction); ok { return string(instruction) }