feat(idp_table_relational): adding inital idp tables for relational repository

This commit is contained in:
Iraq Jaber
2025-07-27 12:13:35 +01:00
parent 13b772aa8c
commit 9fd4f6f2b5
19 changed files with 3154 additions and 52 deletions

View File

@@ -0,0 +1,164 @@
package domain
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/internal/crypto"
)
//go:generate enumer -type IDPType -transform lower -trimprefix IDPType
type IDPType uint8
const (
IDPTypeUnspecified IDPType = iota
IDPTypeOIDC
IDPTypeJWT
IDPTypeOAuth
IDPTypeLDAP
IDPTypeAzureAD
IDPTypeGitHub
IDPTypeGitHubEnterprise
IDPTypeGitLab
IDPTypeGitLabSelfHosted
IDPTypeGoogle
IDPTypeApple
IDPTypeSAML
)
//go:generate enumer -type IDPState -transform lower -trimprefix IDPState
type IDPState uint8
const (
IDPStateActive IDPState = iota
IDPStateInactive
)
type OIDCMappingField int8
const (
OIDCMappingFieldUnspecified OIDCMappingField = iota
OIDCMappingFieldPreferredLoginName
OIDCMappingFieldEmail
// count is for validation purposes
oidcMappingFieldCount
)
type IdentityProvider struct {
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
OrgID *string `json:"orgId,omitempty" db:"org_id"`
ID string `json:"id,omitempty" db:"id"`
State string `json:"state,omitempty" db:"state"`
Name string `json:"name,omitempty" db:"name"`
Type string `json:"type,omitempty" db:"type"`
AllowCreation bool `json:"allowCreation,omitempty" db:"allow_creation"`
AutoRegister bool `json:"autoRegister,omitempty" db:"auto_register"`
AllowAutoCreation bool `json:"allowAutoCreation,omitempty" db:"allow_auto_creation"`
AllowAutoUpdate bool `json:"allowAutoUpdate,omitempty" db:"allow_auto_update"`
AllowLinking bool `json:"allowLinking,omitempty" db:"allow_linking"`
AllowAutoLinking bool `json:"allowAutoLinking,omitempty" db:"allow_auto_linking"`
StylingType int16 `json:"stylingType,omitempty" db:"styling_type"`
Payload *string `json:"payload,omitempty" db:"payload"`
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"`
}
type OIDC struct {
IDPConfigID string `json:"idpConfigId"`
ClientID string `json:"clientId,omitempty"`
ClientSecret crypto.CryptoValue `json:"clientSecret,omitempty"`
Issuer string `json:"issuer,omitempty"`
AuthorizationEndpoint string `json:"authorizationEndpoint,omitempty"`
TokenEndpoint string `json:"tokenEndpoint,omitempty"`
Scopes []string `json:"scopes,omitempty"`
IDPDisplayNameMapping OIDCMappingField `json:"IDPDisplayNameMapping,omitempty"`
UserNameMapping OIDCMappingField `json:"usernameMapping,omitempty"`
}
type IDPOIDC struct {
*IdentityProvider
OIDC
}
type JWT struct {
IDPConfigID string `json:"idpConfigId"`
JWTEndpoint string `json:"jwtEndpoint,omitempty"`
Issuer string `json:"issuer,omitempty"`
KeysEndpoint string `json:"keysEndpoint,omitempty"`
HeaderName string `json:"headerName,omitempty"`
}
type IDPJWT struct {
*IdentityProvider
JWT
}
// IDPIdentifierCondition is used to help specify a single identity_provider,
// it will either be used as the identity_provider ID or identity_provider name,
// as identity_provider can be identified either using (instanceID + OrgID + ID) OR (instanceID + OrgID + name)
type IDPIdentifierCondition interface {
database.Condition
}
type idProviderColumns interface {
InstanceIDColumn() database.Column
OrgIDColumn() database.Column
IDColumn() database.Column
StateColumn() database.Column
NameColumn() database.Column
TypeColumn() database.Column
AllowCreationColumn() database.Column
AutoRegisterColumn() database.Column
AllowAutoCreationColumn() database.Column
AllowAutoUpdateColumn() database.Column
AllowLinkingColumn() database.Column
AllowAutoLinkingColumn() database.Column
StylingTypeColumn() database.Column
PayloadColumn() database.Column
CreatedAtColumn() database.Column
UpdatedAtColumn() database.Column
}
type idProviderConditions interface {
InstanceIDCondition(id string) database.Condition
OrgIDCondition(id *string) database.Condition
IDCondition(id string) IDPIdentifierCondition
StateCondition(state IDPState) database.Condition
NameCondition(name string) IDPIdentifierCondition
TypeCondition(typee IDPType) database.Condition
AutoRegisterCondition(allow bool) database.Condition
AllowCreationCondition(allow bool) database.Condition
AllowAutoCreationCondition(allow bool) database.Condition
AllowAutoUpdateCondition(allow bool) database.Condition
AllowLinkingCondition(allow bool) database.Condition
AllowAutoLinkingCondition(allow bool) database.Condition
StylingTypeCondition(style int16) database.Condition
PayloadCondition(payload string) database.Condition
}
type idProviderChanges interface {
SetName(name string) database.Change
SetState(state IDPState) database.Change
SetAllowCreation(allow bool) database.Change
SetAutoRegister(allow bool) database.Change
SetAllowAutoCreation(allow bool) database.Change
SetAllowAutoUpdate(allow bool) database.Change
SetAllowLinking(allow bool) database.Change
SetAutoAllowLinking(allow bool) database.Change
SetStylingType(stylingType int16) database.Change
SetPayload(payload string) database.Change
}
type IDProviderRepository interface {
idProviderColumns
idProviderConditions
idProviderChanges
Get(ctx context.Context, id IDPIdentifierCondition, instanceID string, orgID *string) (*IdentityProvider, error)
List(ctx context.Context, conditions ...database.Condition) ([]*IdentityProvider, error)
Create(ctx context.Context, idp *IdentityProvider) error
Update(ctx context.Context, id IDPIdentifierCondition, instanceID string, orgID *string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id IDPIdentifierCondition, instanceID string, orgID *string) (int64, error)
}

View File

@@ -0,0 +1,78 @@
// Code generated by "enumer -type IDPState -transform lower -trimprefix IDPState"; DO NOT EDIT.
package domain
import (
"fmt"
"strings"
)
const _IDPStateName = "activeinactive"
var _IDPStateIndex = [...]uint8{0, 6, 14}
const _IDPStateLowerName = "activeinactive"
func (i IDPState) String() string {
if i >= IDPState(len(_IDPStateIndex)-1) {
return fmt.Sprintf("IDPState(%d)", i)
}
return _IDPStateName[_IDPStateIndex[i]:_IDPStateIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _IDPStateNoOp() {
var x [1]struct{}
_ = x[IDPStateActive-(0)]
_ = x[IDPStateInactive-(1)]
}
var _IDPStateValues = []IDPState{IDPStateActive, IDPStateInactive}
var _IDPStateNameToValueMap = map[string]IDPState{
_IDPStateName[0:6]: IDPStateActive,
_IDPStateLowerName[0:6]: IDPStateActive,
_IDPStateName[6:14]: IDPStateInactive,
_IDPStateLowerName[6:14]: IDPStateInactive,
}
var _IDPStateNames = []string{
_IDPStateName[0:6],
_IDPStateName[6:14],
}
// IDPStateString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func IDPStateString(s string) (IDPState, error) {
if val, ok := _IDPStateNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _IDPStateNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to IDPState values", s)
}
// IDPStateValues returns all values of the enum
func IDPStateValues() []IDPState {
return _IDPStateValues
}
// IDPStateStrings returns a slice of all String values of the enum
func IDPStateStrings() []string {
strs := make([]string, len(_IDPStateNames))
copy(strs, _IDPStateNames)
return strs
}
// IsAIDPState returns "true" if the value is listed in the enum definition. "false" otherwise
func (i IDPState) IsAIDPState() bool {
for _, v := range _IDPStateValues {
if i == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,122 @@
// Code generated by "enumer -type IDPType -transform lower -trimprefix IDPType"; DO NOT EDIT.
package domain
import (
"fmt"
"strings"
)
const _IDPTypeName = "unspecifiedoidcjwtoauthldapazureadgithubgithubenterprisegitlabgitlabselfhostedgoogleapplesaml"
var _IDPTypeIndex = [...]uint8{0, 11, 15, 18, 23, 27, 34, 40, 56, 62, 78, 84, 89, 93}
const _IDPTypeLowerName = "unspecifiedoidcjwtoauthldapazureadgithubgithubenterprisegitlabgitlabselfhostedgoogleapplesaml"
func (i IDPType) String() string {
if i >= IDPType(len(_IDPTypeIndex)-1) {
return fmt.Sprintf("IDPType(%d)", i)
}
return _IDPTypeName[_IDPTypeIndex[i]:_IDPTypeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _IDPTypeNoOp() {
var x [1]struct{}
_ = x[IDPTypeUnspecified-(0)]
_ = x[IDPTypeOIDC-(1)]
_ = x[IDPTypeJWT-(2)]
_ = x[IDPTypeOAuth-(3)]
_ = x[IDPTypeLDAP-(4)]
_ = x[IDPTypeAzureAD-(5)]
_ = x[IDPTypeGitHub-(6)]
_ = x[IDPTypeGitHubEnterprise-(7)]
_ = x[IDPTypeGitLab-(8)]
_ = x[IDPTypeGitLabSelfHosted-(9)]
_ = x[IDPTypeGoogle-(10)]
_ = x[IDPTypeApple-(11)]
_ = x[IDPTypeSAML-(12)]
}
var _IDPTypeValues = []IDPType{IDPTypeUnspecified, IDPTypeOIDC, IDPTypeJWT, IDPTypeOAuth, IDPTypeLDAP, IDPTypeAzureAD, IDPTypeGitHub, IDPTypeGitHubEnterprise, IDPTypeGitLab, IDPTypeGitLabSelfHosted, IDPTypeGoogle, IDPTypeApple, IDPTypeSAML}
var _IDPTypeNameToValueMap = map[string]IDPType{
_IDPTypeName[0:11]: IDPTypeUnspecified,
_IDPTypeLowerName[0:11]: IDPTypeUnspecified,
_IDPTypeName[11:15]: IDPTypeOIDC,
_IDPTypeLowerName[11:15]: IDPTypeOIDC,
_IDPTypeName[15:18]: IDPTypeJWT,
_IDPTypeLowerName[15:18]: IDPTypeJWT,
_IDPTypeName[18:23]: IDPTypeOAuth,
_IDPTypeLowerName[18:23]: IDPTypeOAuth,
_IDPTypeName[23:27]: IDPTypeLDAP,
_IDPTypeLowerName[23:27]: IDPTypeLDAP,
_IDPTypeName[27:34]: IDPTypeAzureAD,
_IDPTypeLowerName[27:34]: IDPTypeAzureAD,
_IDPTypeName[34:40]: IDPTypeGitHub,
_IDPTypeLowerName[34:40]: IDPTypeGitHub,
_IDPTypeName[40:56]: IDPTypeGitHubEnterprise,
_IDPTypeLowerName[40:56]: IDPTypeGitHubEnterprise,
_IDPTypeName[56:62]: IDPTypeGitLab,
_IDPTypeLowerName[56:62]: IDPTypeGitLab,
_IDPTypeName[62:78]: IDPTypeGitLabSelfHosted,
_IDPTypeLowerName[62:78]: IDPTypeGitLabSelfHosted,
_IDPTypeName[78:84]: IDPTypeGoogle,
_IDPTypeLowerName[78:84]: IDPTypeGoogle,
_IDPTypeName[84:89]: IDPTypeApple,
_IDPTypeLowerName[84:89]: IDPTypeApple,
_IDPTypeName[89:93]: IDPTypeSAML,
_IDPTypeLowerName[89:93]: IDPTypeSAML,
}
var _IDPTypeNames = []string{
_IDPTypeName[0:11],
_IDPTypeName[11:15],
_IDPTypeName[15:18],
_IDPTypeName[18:23],
_IDPTypeName[23:27],
_IDPTypeName[27:34],
_IDPTypeName[34:40],
_IDPTypeName[40:56],
_IDPTypeName[56:62],
_IDPTypeName[62:78],
_IDPTypeName[78:84],
_IDPTypeName[84:89],
_IDPTypeName[89:93],
}
// IDPTypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func IDPTypeString(s string) (IDPType, error) {
if val, ok := _IDPTypeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _IDPTypeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to IDPType values", s)
}
// IDPTypeValues returns all values of the enum
func IDPTypeValues() []IDPType {
return _IDPTypeValues
}
// IDPTypeStrings returns a slice of all String values of the enum
func IDPTypeStrings() []string {
strs := make([]string, len(_IDPTypeNames))
copy(strs, _IDPTypeNames)
return strs
}
// IsAIDPType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i IDPType) IsAIDPType() bool {
for _, v := range _IDPTypeValues {
if i == v {
return true
}
}
return false
}

View File

@@ -84,7 +84,7 @@ type InstanceRepository interface {
// Member() MemberRepository // Member() MemberRepository
Get(ctx context.Context, id string) (*Instance, error) Get(ctx context.Context, id string) (*Instance, error)
List(ctx context.Context, opts ...database.Condition) ([]*Instance, error) List(ctx context.Context, conditions ...database.Condition) ([]*Instance, error)
Create(ctx context.Context, instance *Instance) error Create(ctx context.Context, instance *Instance) error
Update(ctx context.Context, id string, changes ...database.Change) (int64, error) Update(ctx context.Context, id string, changes ...database.Change) (int64, error)

View File

@@ -37,7 +37,7 @@ type organizationColumns interface {
IDColumn() database.Column IDColumn() database.Column
// NameColumn returns the column for the name field. // NameColumn returns the column for the name field.
NameColumn() database.Column NameColumn() database.Column
// InstanceIDColumn returns the column for the default org id field // InstanceIDColumn returns the column for the instance id field
InstanceIDColumn() database.Column InstanceIDColumn() database.Column
// StateColumn returns the column for the name field. // StateColumn returns the column for the name field.
StateColumn() database.Column StateColumn() database.Column

View File

@@ -14,7 +14,7 @@ var _OrgStateIndex = [...]uint8{0, 6, 14}
const _OrgStateLowerName = "activeinactive" const _OrgStateLowerName = "activeinactive"
func (i OrgState) String() string { func (i OrgState) String() string {
if i < 0 || i >= OrgState(len(_OrgStateIndex)-1) { if i >= OrgState(len(_OrgStateIndex)-1) {
return fmt.Sprintf("OrgState(%d)", i) return fmt.Sprintf("OrgState(%d)", i)
} }
return _OrgStateName[_OrgStateIndex[i]:_OrgStateIndex[i+1]] return _OrgStateName[_OrgStateIndex[i]:_OrgStateIndex[i+1]]

View File

@@ -16,11 +16,16 @@ func (a *and) Write(builder *StatementBuilder) {
builder.WriteString("(") builder.WriteString("(")
defer builder.WriteString(")") defer builder.WriteString(")")
} }
for i, condition := range a.conditions { firstCondition := true
if i > 0 { for _, condition := range a.conditions {
if condition == nil {
continue
}
if !firstCondition {
builder.WriteString(" AND ") builder.WriteString(" AND ")
} }
condition.Write(builder) condition.Write(builder)
firstCondition = false
} }
} }

View File

@@ -0,0 +1,16 @@
package migration
import (
_ "embed"
)
var (
//go:embed 003_identity_providers_table/up.sql
up003IdentityProvidersTable string
//go:embed 003_identity_providers_table/down.sql
down003IdentityProvidersTable string
)
func init() {
registerSQLMigration(3, up003IdentityProvidersTable, down003IdentityProvidersTable)
}

View File

@@ -0,0 +1,3 @@
DROP TABLE zitadel.identity_providers;
DROP Type zitadel.idp_state;
DROP Type zitadel.idp_type;

View File

@@ -0,0 +1,55 @@
CREATE TYPE zitadel.idp_state AS ENUM (
'active',
'inactive'
);
CREATE TYPE zitadel.idp_type AS ENUM (
'oidc',
'jwt',
'oauth',
'saml',
'ldap',
'github',
'google',
'microsoft',
'apple'
);
CREATE TABLE zitadel.identity_providers (
instance_id TEXT NOT NULL
, org_id TEXT
, id TEXT NOT NULL CHECK (id <> '')
, state zitadel.idp_state NOT NULL DEFAULT 'active'
, name TEXT NOT NULL CHECK (name <> '')
, type zitadel.idp_type -- NOT NULL
, auto_register BOOLEAN NOT NULL DEFAULT TRUE
, allow_creation BOOLEAN NOT NULL DEFAULT TRUE
, allow_auto_creation BOOLEAN NOT NULL DEFAULT TRUE
, allow_auto_update BOOLEAN NOT NULL DEFAULT TRUE
, allow_linking BOOLEAN NOT NULL DEFAULT TRUE
, allow_auto_linking BOOLEAN NOT NULL DEFAULT TRUE
, styling_type SMALLINT
, payload JSONB
, created_at TIMESTAMPTZ NOT NULL DEFAULT now()
, updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
, PRIMARY KEY (instance_id, id)
, CONSTRAINT identity_providers_id_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, id)
, CONSTRAINT identity_providers_name_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, name)
, FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id)
, FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id)
);
-- CREATE INDEX idx_identity_providers_org_id ON identity_providers(instance_id, org_id) WHERE org_id IS NOT NULL;
CREATE INDEX idx_identity_providers_state ON zitadel.identity_providers(instance_id, state);
CREATE INDEX idx_identity_providers_type ON zitadel.identity_providers(instance_id, type);
-- CREATE INDEX idx_identity_providers_created_at ON identity_providers(created_at);
-- CREATE INDEX idx_identity_providers_deleted_at ON identity_providers(deleted_at) WHERE deleted_at IS NOT NULL;
CREATE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.identity_providers
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -0,0 +1,485 @@
//go:build integration
package events_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/admin"
"github.com/zitadel/zitadel/pkg/grpc/idp"
idp_grpc "github.com/zitadel/zitadel/pkg/grpc/idp"
)
func TestServer_TestIDProviderReduces(t *testing.T) {
instanceID := Instance.ID()
t.Run("test idp add reduces", func(t *testing.T) {
name := gofakeit.Name()
beforeCreate := time.Now()
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
afterCreate := time.Now()
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.NameCondition(name),
instanceID,
nil,
)
require.NoError(t, err)
// event iam.idp.config.added
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, name, idp.Name)
assert.Equal(t, instanceID, idp.InstanceID)
assert.Equal(t, domain.IDPStateActive.String(), idp.State)
assert.Equal(t, true, idp.AutoRegister)
assert.Equal(t, int16(idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE), idp.StylingType)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
assert.WithinRange(t, idp.CreatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test idp update reduces", func(t *testing.T) {
name := gofakeit.Name()
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
name = "new_" + name
beforeCreate := time.Now()
_, err = AdminClient.UpdateIDP(CTX, &admin.UpdateIDPRequest{
IdpId: addOIDC.IdpId,
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_UNSPECIFIED,
AutoRegister: false,
})
afterCreate := time.Now()
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.NameCondition(name),
instanceID,
nil,
)
require.NoError(t, err)
// event iam.idp.config.changed
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, name, idp.Name)
assert.Equal(t, false, idp.AutoRegister)
assert.Equal(t, int16(idp_grpc.IDPStylingType_STYLING_TYPE_UNSPECIFIED), idp.StylingType)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test idp deactivate reduces", func(t *testing.T) {
name := gofakeit.Name()
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
// deactivate idp
beforeCreate := time.Now()
_, err = AdminClient.DeactivateIDP(CTX, &admin.DeactivateIDPRequest{
IdpId: addOIDC.IdpId,
})
afterCreate := time.Now()
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event iam.idp.config.deactivated
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, domain.IDPStateInactive.String(), idp.State)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test idp reactivate reduces", func(t *testing.T) {
name := gofakeit.Name()
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
// deactivate idp
_, err = AdminClient.DeactivateIDP(CTX, &admin.DeactivateIDPRequest{
IdpId: addOIDC.IdpId,
})
require.NoError(t, err)
// wait for idp to be deactivated
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, domain.IDPStateInactive.String(), idp.State)
}, retryDuration, tick)
// reactivate idp
beforeCreate := time.Now()
_, err = AdminClient.ReactivateIDP(CTX, &admin.ReactivateIDPRequest{
IdpId: addOIDC.IdpId,
})
afterCreate := time.Now()
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event iam.idp.config.reactivated
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, domain.IDPStateActive.String(), idp.State)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test idp remove reduces", func(t *testing.T) {
name := gofakeit.Name()
// add idp
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
// remove idp
_, err = AdminClient.RemoveIDP(CTX, &admin.RemoveIDPRequest{
IdpId: addOIDC.IdpId,
})
require.NoError(t, err)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*20)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
_, err := idpRepo.Get(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
// event iam.idp.config.remove
require.ErrorIs(t, &database.NoRowFoundError{}, err)
}, retryDuration, tick)
})
t.Run("test idp oidc addded reduces", func(t *testing.T) {
name := gofakeit.Name()
// add oidc
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
oidc, err := idpRepo.GetOIDC(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event org.idp.oidc.config.added
// idp
assert.Equal(t, addOIDC.IdpId, oidc.ID)
assert.Equal(t, domain.IDPTypeOIDC.String(), oidc.Type)
// oidc
assert.Equal(t, addOIDC.IdpId, oidc.IDPConfigID)
assert.Equal(t, "issuer", oidc.Issuer)
assert.Equal(t, "clientID", oidc.ClientID)
assert.Equal(t, []string{"scope"}, oidc.Scopes)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.IDPDisplayNameMapping)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.UserNameMapping)
}, retryDuration, tick)
})
t.Run("test idp oidc changed reduces", func(t *testing.T) {
name := gofakeit.Name()
// add oidc
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
// check original values for OCID
var oidc *domain.IDPOIDC
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
oidc, err = idpRepo.GetOIDC(CTX, idpRepo.IDCondition(addOIDC.IdpId), instanceID, nil)
require.NoError(t, err)
}, retryDuration, tick)
// idp
assert.Equal(t, addOIDC.IdpId, oidc.ID)
assert.Equal(t, domain.IDPTypeOIDC.String(), oidc.Type)
// oidc
assert.Equal(t, addOIDC.IdpId, oidc.IDPConfigID)
assert.Equal(t, "issuer", oidc.Issuer)
assert.Equal(t, "clientID", oidc.ClientID)
assert.Equal(t, []string{"scope"}, oidc.Scopes)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.IDPDisplayNameMapping)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.UserNameMapping)
beforeCreate := time.Now()
_, err = AdminClient.UpdateIDPOIDCConfig(CTX, &admin.UpdateIDPOIDCConfigRequest{
IdpId: addOIDC.IdpId,
ClientId: "new_clientID",
ClientSecret: "new_clientSecret",
Issuer: "new_issuer",
Scopes: []string{"new_scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME,
})
afterCreate := time.Now()
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
updateOIDC, err := idpRepo.GetOIDC(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event org.idp.oidc.config.changed
// idp
assert.Equal(t, addOIDC.IdpId, updateOIDC.ID)
assert.Equal(t, domain.IDPTypeOIDC.String(), updateOIDC.Type)
assert.WithinRange(t, updateOIDC.UpdatedAt, beforeCreate, afterCreate)
// oidc
assert.Equal(t, addOIDC.IdpId, updateOIDC.IDPConfigID)
assert.Equal(t, "new_issuer", updateOIDC.Issuer)
assert.Equal(t, "new_clientID", updateOIDC.ClientID)
assert.NotEqual(t, oidc.ClientSecret, updateOIDC.ClientSecret)
assert.Equal(t, []string{"new_scope"}, updateOIDC.Scopes)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME), updateOIDC.IDPDisplayNameMapping)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME), updateOIDC.UserNameMapping)
}, retryDuration, tick)
})
t.Run("test idp jwt addded reduces", func(t *testing.T) {
name := gofakeit.Name()
// add jwt
addJWT, err := AdminClient.AddJWTIDP(CTX, &admin.AddJWTIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
JwtEndpoint: "jwtEndpoint",
Issuer: "issuer",
KeysEndpoint: "keyEndpoint",
HeaderName: "headerName",
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
jwt, err := idpRepo.GetJWT(CTX,
idpRepo.IDCondition(addJWT.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event org.idp.jwt.config.added
// idp
assert.Equal(t, addJWT.IdpId, jwt.ID)
assert.Equal(t, domain.IDPTypeJWT.String(), jwt.Type)
// jwt
assert.Equal(t, addJWT.IdpId, jwt.IDPConfigID)
assert.Equal(t, "jwtEndpoint", jwt.JWTEndpoint)
assert.Equal(t, "issuer", jwt.Issuer)
assert.Equal(t, "keyEndpoint", jwt.KeysEndpoint)
assert.Equal(t, "headerName", jwt.HeaderName)
}, retryDuration, tick)
})
t.Run("test idp jwt changed reduces", func(t *testing.T) {
name := gofakeit.Name()
// add jwt
addJWT, err := AdminClient.AddJWTIDP(CTX, &admin.AddJWTIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
JwtEndpoint: "jwtEndpoint",
Issuer: "issuer",
KeysEndpoint: "keyEndpoint",
HeaderName: "headerName",
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
// check original values for jwt
var jwt *domain.IDPJWT
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
jwt, err = idpRepo.GetJWT(CTX, idpRepo.IDCondition(addJWT.IdpId), instanceID, nil)
require.NoError(t, err)
}, retryDuration, tick)
// idp
assert.Equal(t, addJWT.IdpId, jwt.ID)
assert.Equal(t, domain.IDPTypeJWT.String(), jwt.Type)
// jwt
assert.Equal(t, addJWT.IdpId, jwt.IDPConfigID)
assert.Equal(t, "jwtEndpoint", jwt.JWTEndpoint)
assert.Equal(t, "issuer", jwt.Issuer)
assert.Equal(t, "keyEndpoint", jwt.KeysEndpoint)
assert.Equal(t, "headerName", jwt.HeaderName)
beforeCreate := time.Now()
_, err = AdminClient.UpdateIDPJWTConfig(CTX, &admin.UpdateIDPJWTConfigRequest{
IdpId: addJWT.IdpId,
JwtEndpoint: "new_jwtEndpoint",
Issuer: "new_issuer",
KeysEndpoint: "new_keyEndpoint",
HeaderName: "new_headerName",
})
afterCreate := time.Now()
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
updateJWT, err := idpRepo.GetJWT(CTX,
idpRepo.IDCondition(addJWT.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event org.idp.jwt.config.changed
// idp
assert.Equal(t, addJWT.IdpId, updateJWT.ID)
assert.Equal(t, domain.IDPTypeJWT.String(), updateJWT.Type)
assert.WithinRange(t, updateJWT.UpdatedAt, beforeCreate, afterCreate)
// jwt
assert.Equal(t, addJWT.IdpId, updateJWT.IDPConfigID)
assert.Equal(t, "new_jwtEndpoint", updateJWT.JWTEndpoint)
assert.Equal(t, "new_issuer", updateJWT.Issuer)
assert.Equal(t, "new_keyEndpoint", updateJWT.KeysEndpoint)
}, retryDuration, tick)
})
}

View File

@@ -137,6 +137,6 @@ const (
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) { func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) {
col.Write(builder) col.Write(builder)
builder.WriteString(" IS ") builder.WriteString(" = ")
builder.WriteArg(value) builder.WriteArg(value)
} }

View File

@@ -0,0 +1,312 @@
package repository
import (
"context"
"errors"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ domain.IDProviderRepository = (*idProvider)(nil)
type idProvider struct {
repository
}
func IDProviderRepository(client database.QueryExecutor) domain.IDProviderRepository {
return &idProvider{
repository: repository{
client: client,
},
}
}
const queryIDProviderStmt = `SELECT instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` +
` allow_auto_update, allow_linking, styling_type, payload, created_at, updated_at` +
` FROM zitadel.identity_providers`
func (i *idProvider) Get(ctx context.Context, id domain.IDPIdentifierCondition, instanceID string, orgID *string) (*domain.IdentityProvider, error) {
builder := database.StatementBuilder{}
builder.WriteString(queryIDProviderStmt)
conditions := []database.Condition{id, i.InstanceIDCondition(instanceID), i.OrgIDCondition(orgID)}
writeCondition(&builder, database.And(conditions...))
return scanIDProvider(ctx, i.client, &builder)
}
func (i *idProvider) List(ctx context.Context, conditions ...database.Condition) ([]*domain.IdentityProvider, error) {
builder := database.StatementBuilder{}
builder.WriteString(queryIDProviderStmt)
if conditions != nil {
writeCondition(&builder, database.And(conditions...))
}
orderBy := database.OrderBy(i.CreatedAtColumn())
orderBy.Write(&builder)
return scanIDProviders(ctx, i.client, &builder)
}
const createIDProviderStmt = `INSERT INTO zitadel.identity_providers` +
` (instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` +
` allow_auto_update, allow_linking, styling_type, payload)` +
` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)` +
` RETURNING created_at, updated_at`
func (i *idProvider) Create(ctx context.Context, idp *domain.IdentityProvider) error {
builder := database.StatementBuilder{}
builder.AppendArgs(
idp.InstanceID,
idp.OrgID,
idp.ID,
idp.State,
idp.Name,
idp.Type,
idp.AllowCreation,
idp.AllowAutoCreation,
idp.AllowAutoUpdate,
idp.AllowLinking,
idp.StylingType,
idp.Payload)
builder.WriteString(createIDProviderStmt)
err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&idp.CreatedAt, &idp.UpdatedAt)
if err != nil {
return checkCreateOrgErr(err)
}
return nil
}
func (i *idProvider) Update(ctx context.Context, id domain.IDPIdentifierCondition, instnaceID string, orgID *string, changes ...database.Change) (int64, error) {
if changes == nil {
return 0, errors.New("Update must contain at least one change")
}
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE zitadel.identity_providers SET `)
conditions := []database.Condition{
id,
i.InstanceIDCondition(instnaceID),
i.OrgIDCondition(orgID),
}
database.Changes(changes).Write(&builder)
writeCondition(&builder, database.And(conditions...))
stmt := builder.String()
return i.client.Exec(ctx, stmt, builder.Args()...)
}
func (i *idProvider) Delete(ctx context.Context, id domain.IDPIdentifierCondition, instnaceID string, orgID *string) (int64, error) {
builder := database.StatementBuilder{}
builder.WriteString(`DELETE FROM zitadel.identity_providers`)
conditions := []database.Condition{
id,
i.InstanceIDCondition(instnaceID),
i.OrgIDCondition(orgID),
}
writeCondition(&builder, database.And(conditions...))
return i.client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
func (idProvider) InstanceIDColumn() database.Column {
return database.NewColumn("instance_id")
}
func (idProvider) OrgIDColumn() database.Column {
return database.NewColumn("org_id")
}
func (idProvider) IDColumn() database.Column {
return database.NewColumn("id")
}
func (idProvider) StateColumn() database.Column {
return database.NewColumn("state")
}
func (idProvider) NameColumn() database.Column {
return database.NewColumn("name")
}
func (idProvider) TypeColumn() database.Column {
return database.NewColumn("type")
}
func (idProvider) AutoRegisterColumn() database.Column {
return database.NewColumn("auto_register")
}
func (idProvider) AllowCreationColumn() database.Column {
return database.NewColumn("allow_creation")
}
func (idProvider) AllowAutoCreationColumn() database.Column {
return database.NewColumn("allow_auto_creation")
}
func (idProvider) AllowAutoUpdateColumn() database.Column {
return database.NewColumn("allow_auto_update")
}
func (idProvider) AllowLinkingColumn() database.Column {
return database.NewColumn("allow_linking")
}
func (idProvider) AllowAutoLinkingColumn() database.Column {
return database.NewColumn("allow_auto_linking")
}
func (idProvider) StylingTypeColumn() database.Column {
return database.NewColumn("styling_type")
}
func (idProvider) PayloadColumn() database.Column {
return database.NewColumn("payload")
}
func (idProvider) CreatedAtColumn() database.Column {
return database.NewColumn("created_at")
}
func (idProvider) UpdatedAtColumn() database.Column {
return database.NewColumn("updated_at")
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
func (i idProvider) InstanceIDCondition(id string) database.Condition {
return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, id)
}
func (i idProvider) OrgIDCondition(id *string) database.Condition {
if id == nil {
return nil
}
return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, *id)
}
func (i idProvider) IDCondition(id string) domain.IDPIdentifierCondition {
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
}
func (i idProvider) StateCondition(state domain.IDPState) database.Condition {
return database.NewTextCondition(i.StateColumn(), database.TextOperationEqual, state.String())
}
func (i idProvider) NameCondition(name string) domain.IDPIdentifierCondition {
return database.NewTextCondition(i.NameColumn(), database.TextOperationEqual, name)
}
func (i idProvider) TypeCondition(typee domain.IDPType) database.Condition {
return database.NewTextCondition(i.TypeColumn(), database.TextOperationEqual, typee.String())
}
func (i idProvider) AutoRegisterCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AutoRegisterColumn(), allow)
}
func (i idProvider) AllowCreationCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowCreationColumn(), allow)
}
func (i idProvider) AllowAutoCreationCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowAutoCreationColumn(), allow)
}
func (i idProvider) AllowAutoUpdateCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowAutoUpdateColumn(), allow)
}
func (i idProvider) AllowLinkingCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowLinkingColumn(), allow)
}
func (i idProvider) AllowAutoLinkingCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowAutoLinkingColumn(), allow)
}
func (i idProvider) StylingTypeCondition(style int16) database.Condition {
return database.NewNumberCondition(i.StylingTypeColumn(), database.NumberOperationEqual, style)
}
func (i idProvider) PayloadCondition(payload string) database.Condition {
return database.NewTextCondition(i.PayloadColumn(), database.TextOperationEqual, payload)
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
func (i idProvider) SetName(name string) database.Change {
return database.NewChange(i.NameColumn(), name)
}
func (i idProvider) SetState(state domain.IDPState) database.Change {
return database.NewChange(i.StateColumn(), state)
}
func (i idProvider) SetAllowCreation(allow bool) database.Change {
return database.NewChange(i.AllowCreationColumn(), allow)
}
func (i idProvider) SetAutoRegister(allow bool) database.Change {
return database.NewChange(i.AutoRegisterColumn(), allow)
}
func (i idProvider) SetAllowAutoCreation(allow bool) database.Change {
return database.NewChange(i.AllowAutoCreationColumn(), allow)
}
func (i idProvider) SetAllowAutoUpdate(allow bool) database.Change {
return database.NewChange(i.AllowAutoUpdateColumn(), allow)
}
func (i idProvider) SetAllowLinking(allow bool) database.Change {
return database.NewChange(i.AllowLinkingColumn(), allow)
}
func (i idProvider) SetAutoAllowLinking(allow bool) database.Change {
return database.NewChange(i.AllowAutoLinkingColumn(), allow)
}
func (i idProvider) SetStylingType(stylingType int16) database.Change {
return database.NewChange(i.StylingTypeColumn(), stylingType)
}
func (i idProvider) SetPayload(payload string) database.Change {
return database.NewChange(i.PayloadColumn(), payload)
}
func scanIDProvider(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.IdentityProvider, error) {
idp := &domain.IdentityProvider{}
err := scanRow(ctx, querier, builder, idp)
if err != nil {
return nil, err
}
return idp, err
}
func scanIDProviders(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.IdentityProvider, error) {
idps := []*domain.IdentityProvider{}
err := scanRows(ctx, querier, builder, &idps)
if err != nil {
return nil, err
}
return idps, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -172,28 +172,18 @@ func (instance) UpdatedAtColumn() database.Column {
} }
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) { func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...) instance := new(domain.Instance)
err := scanRow(ctx, querier, builder, instance)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return instance, err
instance := new(domain.Instance)
if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil {
return nil, err
}
return instance, nil
} }
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) { func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...) err = scanRows(ctx, querier, builder, &instances)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := rows.(database.CollectableRows).Collect(&instances); err != nil {
return nil, err
}
return instances, nil return instances, nil
} }

View File

@@ -217,27 +217,19 @@ func (org) UpdatedAtColumn() database.Column {
} }
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) { func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...) organization := &domain.Organization{}
err := scanRow(ctx, querier, builder, organization)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return organization, err
organization := &domain.Organization{}
if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil {
return nil, err
}
return organization, nil
} }
func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) { func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
organizations := []*domain.Organization{} organizations := []*domain.Organization{}
if err := rows.(database.CollectableRows).Collect(&organizations); err != nil {
err := scanRows(ctx, querier, builder, &organizations)
if err != nil {
return nil, err return nil, err
} }
return organizations, nil return organizations, nil

View File

@@ -519,11 +519,6 @@ func TestGetOrganization(t *testing.T) {
return return
} }
if org.Name == "non existent org" {
assert.Nil(t, returnedOrg)
return
}
assert.Equal(t, returnedOrg.ID, org.ID) assert.Equal(t, returnedOrg.ID, org.ID)
assert.Equal(t, returnedOrg.Name, org.Name) assert.Equal(t, returnedOrg.Name, org.Name)
assert.Equal(t, returnedOrg.InstanceID, org.InstanceID) assert.Equal(t, returnedOrg.InstanceID, org.InstanceID)
@@ -815,8 +810,7 @@ func TestDeleteOrganization(t *testing.T) {
return test{ return test{
name: "happy path delete organization filter id", name: "happy path delete organization filter id",
testFunc: func(ctx context.Context, t *testing.T) { testFunc: func(ctx context.Context, t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations) for range noOfOrganizations {
for i := range noOfOrganizations {
org := domain.Organization{ org := domain.Organization{
ID: organizationId, ID: organizationId,
@@ -829,7 +823,6 @@ func TestDeleteOrganization(t *testing.T) {
err := organizationRepo.Create(ctx, &org) err := organizationRepo.Create(ctx, &org)
require.NoError(t, err) require.NoError(t, err)
organizations[i] = &org
} }
}, },
orgIdentifierCondition: organizationRepo.IDCondition(organizationId), orgIdentifierCondition: organizationRepo.IDCondition(organizationId),
@@ -843,8 +836,7 @@ func TestDeleteOrganization(t *testing.T) {
return test{ return test{
name: "happy path delete organization filter name", name: "happy path delete organization filter name",
testFunc: func(ctx context.Context, t *testing.T) { testFunc: func(ctx context.Context, t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations) for range noOfOrganizations {
for i := range noOfOrganizations {
org := domain.Organization{ org := domain.Organization{
ID: gofakeit.Name(), ID: gofakeit.Name(),
@@ -857,7 +849,6 @@ func TestDeleteOrganization(t *testing.T) {
err := organizationRepo.Create(ctx, &org) err := organizationRepo.Create(ctx, &org)
require.NoError(t, err) require.NoError(t, err)
organizations[i] = &org
} }
}, },
orgIdentifierCondition: organizationRepo.NameCondition(organizationName), orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
@@ -879,8 +870,7 @@ func TestDeleteOrganization(t *testing.T) {
name: "deleted already deleted organization", name: "deleted already deleted organization",
testFunc: func(ctx context.Context, t *testing.T) { testFunc: func(ctx context.Context, t *testing.T) {
noOfOrganizations := 1 noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations) for range noOfOrganizations {
for i := range noOfOrganizations {
org := domain.Organization{ org := domain.Organization{
ID: gofakeit.Name(), ID: gofakeit.Name(),
@@ -893,13 +883,12 @@ func TestDeleteOrganization(t *testing.T) {
err := organizationRepo.Create(ctx, &org) err := organizationRepo.Create(ctx, &org)
require.NoError(t, err) require.NoError(t, err)
organizations[i] = &org
} }
// delete organization // delete organization
affectedRows, err := organizationRepo.Delete(ctx, affectedRows, err := organizationRepo.Delete(ctx,
organizationRepo.NameCondition(organizationName), organizationRepo.NameCondition(organizationName),
organizations[0].InstanceID, instanceId,
) )
assert.Equal(t, int64(1), affectedRows) assert.Equal(t, int64(1), affectedRows)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -1,6 +1,8 @@
package repository package repository
import ( import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
) )
@@ -18,3 +20,21 @@ func writeCondition(
builder.WriteString(" WHERE ") builder.WriteString(" WHERE ")
condition.Write(builder) condition.Write(builder)
} }
func scanRow(ctx context.Context, querier database.Querier, builder *database.StatementBuilder, res any) error {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return err
}
return rows.(database.CollectableRows).CollectExactlyOneRow(res)
}
func scanRows(ctx context.Context, querier database.Querier, builder *database.StatementBuilder, res any) error {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return err
}
return rows.(database.CollectableRows).Collect(res)
}

View File

@@ -29,9 +29,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
if b.existingArgs == nil { if b.existingArgs == nil {
b.existingArgs = make(map[any]string) b.existingArgs = make(map[any]string)
} }
if placeholder, ok := b.existingArgs[arg]; ok {
return placeholder
}
if instruction, ok := arg.(Instruction); ok { if instruction, ok := arg.(Instruction); ok {
return string(instruction) return string(instruction)
} }