feat(idp_table_relational): adding inital idp tables for relational repository

This commit is contained in:
Iraq Jaber
2025-07-27 12:13:35 +01:00
parent 13b772aa8c
commit 9fd4f6f2b5
19 changed files with 3154 additions and 52 deletions

View File

@@ -16,11 +16,16 @@ func (a *and) Write(builder *StatementBuilder) {
builder.WriteString("(")
defer builder.WriteString(")")
}
for i, condition := range a.conditions {
if i > 0 {
firstCondition := true
for _, condition := range a.conditions {
if condition == nil {
continue
}
if !firstCondition {
builder.WriteString(" AND ")
}
condition.Write(builder)
firstCondition = false
}
}

View File

@@ -0,0 +1,16 @@
package migration
import (
_ "embed"
)
var (
//go:embed 003_identity_providers_table/up.sql
up003IdentityProvidersTable string
//go:embed 003_identity_providers_table/down.sql
down003IdentityProvidersTable string
)
func init() {
registerSQLMigration(3, up003IdentityProvidersTable, down003IdentityProvidersTable)
}

View File

@@ -0,0 +1,3 @@
DROP TABLE zitadel.identity_providers;
DROP Type zitadel.idp_state;
DROP Type zitadel.idp_type;

View File

@@ -0,0 +1,55 @@
CREATE TYPE zitadel.idp_state AS ENUM (
'active',
'inactive'
);
CREATE TYPE zitadel.idp_type AS ENUM (
'oidc',
'jwt',
'oauth',
'saml',
'ldap',
'github',
'google',
'microsoft',
'apple'
);
CREATE TABLE zitadel.identity_providers (
instance_id TEXT NOT NULL
, org_id TEXT
, id TEXT NOT NULL CHECK (id <> '')
, state zitadel.idp_state NOT NULL DEFAULT 'active'
, name TEXT NOT NULL CHECK (name <> '')
, type zitadel.idp_type -- NOT NULL
, auto_register BOOLEAN NOT NULL DEFAULT TRUE
, allow_creation BOOLEAN NOT NULL DEFAULT TRUE
, allow_auto_creation BOOLEAN NOT NULL DEFAULT TRUE
, allow_auto_update BOOLEAN NOT NULL DEFAULT TRUE
, allow_linking BOOLEAN NOT NULL DEFAULT TRUE
, allow_auto_linking BOOLEAN NOT NULL DEFAULT TRUE
, styling_type SMALLINT
, payload JSONB
, created_at TIMESTAMPTZ NOT NULL DEFAULT now()
, updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
, PRIMARY KEY (instance_id, id)
, CONSTRAINT identity_providers_id_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, id)
, CONSTRAINT identity_providers_name_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, name)
, FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id)
, FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id)
);
-- CREATE INDEX idx_identity_providers_org_id ON identity_providers(instance_id, org_id) WHERE org_id IS NOT NULL;
CREATE INDEX idx_identity_providers_state ON zitadel.identity_providers(instance_id, state);
CREATE INDEX idx_identity_providers_type ON zitadel.identity_providers(instance_id, type);
-- CREATE INDEX idx_identity_providers_created_at ON identity_providers(created_at);
-- CREATE INDEX idx_identity_providers_deleted_at ON identity_providers(deleted_at) WHERE deleted_at IS NOT NULL;
CREATE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.identity_providers
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -0,0 +1,485 @@
//go:build integration
package events_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/admin"
"github.com/zitadel/zitadel/pkg/grpc/idp"
idp_grpc "github.com/zitadel/zitadel/pkg/grpc/idp"
)
func TestServer_TestIDProviderReduces(t *testing.T) {
instanceID := Instance.ID()
t.Run("test idp add reduces", func(t *testing.T) {
name := gofakeit.Name()
beforeCreate := time.Now()
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
afterCreate := time.Now()
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.NameCondition(name),
instanceID,
nil,
)
require.NoError(t, err)
// event iam.idp.config.added
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, name, idp.Name)
assert.Equal(t, instanceID, idp.InstanceID)
assert.Equal(t, domain.IDPStateActive.String(), idp.State)
assert.Equal(t, true, idp.AutoRegister)
assert.Equal(t, int16(idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE), idp.StylingType)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
assert.WithinRange(t, idp.CreatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test idp update reduces", func(t *testing.T) {
name := gofakeit.Name()
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
name = "new_" + name
beforeCreate := time.Now()
_, err = AdminClient.UpdateIDP(CTX, &admin.UpdateIDPRequest{
IdpId: addOIDC.IdpId,
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_UNSPECIFIED,
AutoRegister: false,
})
afterCreate := time.Now()
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.NameCondition(name),
instanceID,
nil,
)
require.NoError(t, err)
// event iam.idp.config.changed
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, name, idp.Name)
assert.Equal(t, false, idp.AutoRegister)
assert.Equal(t, int16(idp_grpc.IDPStylingType_STYLING_TYPE_UNSPECIFIED), idp.StylingType)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test idp deactivate reduces", func(t *testing.T) {
name := gofakeit.Name()
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
// deactivate idp
beforeCreate := time.Now()
_, err = AdminClient.DeactivateIDP(CTX, &admin.DeactivateIDPRequest{
IdpId: addOIDC.IdpId,
})
afterCreate := time.Now()
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event iam.idp.config.deactivated
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, domain.IDPStateInactive.String(), idp.State)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test idp reactivate reduces", func(t *testing.T) {
name := gofakeit.Name()
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
// deactivate idp
_, err = AdminClient.DeactivateIDP(CTX, &admin.DeactivateIDPRequest{
IdpId: addOIDC.IdpId,
})
require.NoError(t, err)
// wait for idp to be deactivated
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, domain.IDPStateInactive.String(), idp.State)
}, retryDuration, tick)
// reactivate idp
beforeCreate := time.Now()
_, err = AdminClient.ReactivateIDP(CTX, &admin.ReactivateIDPRequest{
IdpId: addOIDC.IdpId,
})
afterCreate := time.Now()
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
idp, err := idpRepo.Get(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event iam.idp.config.reactivated
assert.Equal(t, addOIDC.IdpId, idp.ID)
assert.Equal(t, domain.IDPStateActive.String(), idp.State)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test idp remove reduces", func(t *testing.T) {
name := gofakeit.Name()
// add idp
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
// remove idp
_, err = AdminClient.RemoveIDP(CTX, &admin.RemoveIDPRequest{
IdpId: addOIDC.IdpId,
})
require.NoError(t, err)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*20)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
_, err := idpRepo.Get(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
// event iam.idp.config.remove
require.ErrorIs(t, &database.NoRowFoundError{}, err)
}, retryDuration, tick)
})
t.Run("test idp oidc addded reduces", func(t *testing.T) {
name := gofakeit.Name()
// add oidc
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
oidc, err := idpRepo.GetOIDC(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event org.idp.oidc.config.added
// idp
assert.Equal(t, addOIDC.IdpId, oidc.ID)
assert.Equal(t, domain.IDPTypeOIDC.String(), oidc.Type)
// oidc
assert.Equal(t, addOIDC.IdpId, oidc.IDPConfigID)
assert.Equal(t, "issuer", oidc.Issuer)
assert.Equal(t, "clientID", oidc.ClientID)
assert.Equal(t, []string{"scope"}, oidc.Scopes)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.IDPDisplayNameMapping)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.UserNameMapping)
}, retryDuration, tick)
})
t.Run("test idp oidc changed reduces", func(t *testing.T) {
name := gofakeit.Name()
// add oidc
addOIDC, err := AdminClient.AddOIDCIDP(CTX, &admin.AddOIDCIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
ClientId: "clientID",
ClientSecret: "clientSecret",
Issuer: "issuer",
Scopes: []string{"scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL,
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
// check original values for OCID
var oidc *domain.IDPOIDC
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
oidc, err = idpRepo.GetOIDC(CTX, idpRepo.IDCondition(addOIDC.IdpId), instanceID, nil)
require.NoError(t, err)
}, retryDuration, tick)
// idp
assert.Equal(t, addOIDC.IdpId, oidc.ID)
assert.Equal(t, domain.IDPTypeOIDC.String(), oidc.Type)
// oidc
assert.Equal(t, addOIDC.IdpId, oidc.IDPConfigID)
assert.Equal(t, "issuer", oidc.Issuer)
assert.Equal(t, "clientID", oidc.ClientID)
assert.Equal(t, []string{"scope"}, oidc.Scopes)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.IDPDisplayNameMapping)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_EMAIL), oidc.UserNameMapping)
beforeCreate := time.Now()
_, err = AdminClient.UpdateIDPOIDCConfig(CTX, &admin.UpdateIDPOIDCConfigRequest{
IdpId: addOIDC.IdpId,
ClientId: "new_clientID",
ClientSecret: "new_clientSecret",
Issuer: "new_issuer",
Scopes: []string{"new_scope"},
DisplayNameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME,
UsernameMapping: idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME,
})
afterCreate := time.Now()
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
updateOIDC, err := idpRepo.GetOIDC(CTX,
idpRepo.IDCondition(addOIDC.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event org.idp.oidc.config.changed
// idp
assert.Equal(t, addOIDC.IdpId, updateOIDC.ID)
assert.Equal(t, domain.IDPTypeOIDC.String(), updateOIDC.Type)
assert.WithinRange(t, updateOIDC.UpdatedAt, beforeCreate, afterCreate)
// oidc
assert.Equal(t, addOIDC.IdpId, updateOIDC.IDPConfigID)
assert.Equal(t, "new_issuer", updateOIDC.Issuer)
assert.Equal(t, "new_clientID", updateOIDC.ClientID)
assert.NotEqual(t, oidc.ClientSecret, updateOIDC.ClientSecret)
assert.Equal(t, []string{"new_scope"}, updateOIDC.Scopes)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME), updateOIDC.IDPDisplayNameMapping)
assert.Equal(t, domain.OIDCMappingField(idp.OIDCMappingField_OIDC_MAPPING_FIELD_PREFERRED_USERNAME), updateOIDC.UserNameMapping)
}, retryDuration, tick)
})
t.Run("test idp jwt addded reduces", func(t *testing.T) {
name := gofakeit.Name()
// add jwt
addJWT, err := AdminClient.AddJWTIDP(CTX, &admin.AddJWTIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
JwtEndpoint: "jwtEndpoint",
Issuer: "issuer",
KeysEndpoint: "keyEndpoint",
HeaderName: "headerName",
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
jwt, err := idpRepo.GetJWT(CTX,
idpRepo.IDCondition(addJWT.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event org.idp.jwt.config.added
// idp
assert.Equal(t, addJWT.IdpId, jwt.ID)
assert.Equal(t, domain.IDPTypeJWT.String(), jwt.Type)
// jwt
assert.Equal(t, addJWT.IdpId, jwt.IDPConfigID)
assert.Equal(t, "jwtEndpoint", jwt.JWTEndpoint)
assert.Equal(t, "issuer", jwt.Issuer)
assert.Equal(t, "keyEndpoint", jwt.KeysEndpoint)
assert.Equal(t, "headerName", jwt.HeaderName)
}, retryDuration, tick)
})
t.Run("test idp jwt changed reduces", func(t *testing.T) {
name := gofakeit.Name()
// add jwt
addJWT, err := AdminClient.AddJWTIDP(CTX, &admin.AddJWTIDPRequest{
Name: name,
StylingType: idp_grpc.IDPStylingType_STYLING_TYPE_GOOGLE,
JwtEndpoint: "jwtEndpoint",
Issuer: "issuer",
KeysEndpoint: "keyEndpoint",
HeaderName: "headerName",
AutoRegister: true,
})
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
// check original values for jwt
var jwt *domain.IDPJWT
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
jwt, err = idpRepo.GetJWT(CTX, idpRepo.IDCondition(addJWT.IdpId), instanceID, nil)
require.NoError(t, err)
}, retryDuration, tick)
// idp
assert.Equal(t, addJWT.IdpId, jwt.ID)
assert.Equal(t, domain.IDPTypeJWT.String(), jwt.Type)
// jwt
assert.Equal(t, addJWT.IdpId, jwt.IDPConfigID)
assert.Equal(t, "jwtEndpoint", jwt.JWTEndpoint)
assert.Equal(t, "issuer", jwt.Issuer)
assert.Equal(t, "keyEndpoint", jwt.KeysEndpoint)
assert.Equal(t, "headerName", jwt.HeaderName)
beforeCreate := time.Now()
_, err = AdminClient.UpdateIDPJWTConfig(CTX, &admin.UpdateIDPJWTConfigRequest{
IdpId: addJWT.IdpId,
JwtEndpoint: "new_jwtEndpoint",
Issuer: "new_issuer",
KeysEndpoint: "new_keyEndpoint",
HeaderName: "new_headerName",
})
afterCreate := time.Now()
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Second*5)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
updateJWT, err := idpRepo.GetJWT(CTX,
idpRepo.IDCondition(addJWT.IdpId),
instanceID,
nil,
)
require.NoError(t, err)
// event org.idp.jwt.config.changed
// idp
assert.Equal(t, addJWT.IdpId, updateJWT.ID)
assert.Equal(t, domain.IDPTypeJWT.String(), updateJWT.Type)
assert.WithinRange(t, updateJWT.UpdatedAt, beforeCreate, afterCreate)
// jwt
assert.Equal(t, addJWT.IdpId, updateJWT.IDPConfigID)
assert.Equal(t, "new_jwtEndpoint", updateJWT.JWTEndpoint)
assert.Equal(t, "new_issuer", updateJWT.Issuer)
assert.Equal(t, "new_keyEndpoint", updateJWT.KeysEndpoint)
}, retryDuration, tick)
})
}

View File

@@ -137,6 +137,6 @@ const (
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) {
col.Write(builder)
builder.WriteString(" IS ")
builder.WriteString(" = ")
builder.WriteArg(value)
}

View File

@@ -0,0 +1,312 @@
package repository
import (
"context"
"errors"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ domain.IDProviderRepository = (*idProvider)(nil)
type idProvider struct {
repository
}
func IDProviderRepository(client database.QueryExecutor) domain.IDProviderRepository {
return &idProvider{
repository: repository{
client: client,
},
}
}
const queryIDProviderStmt = `SELECT instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` +
` allow_auto_update, allow_linking, styling_type, payload, created_at, updated_at` +
` FROM zitadel.identity_providers`
func (i *idProvider) Get(ctx context.Context, id domain.IDPIdentifierCondition, instanceID string, orgID *string) (*domain.IdentityProvider, error) {
builder := database.StatementBuilder{}
builder.WriteString(queryIDProviderStmt)
conditions := []database.Condition{id, i.InstanceIDCondition(instanceID), i.OrgIDCondition(orgID)}
writeCondition(&builder, database.And(conditions...))
return scanIDProvider(ctx, i.client, &builder)
}
func (i *idProvider) List(ctx context.Context, conditions ...database.Condition) ([]*domain.IdentityProvider, error) {
builder := database.StatementBuilder{}
builder.WriteString(queryIDProviderStmt)
if conditions != nil {
writeCondition(&builder, database.And(conditions...))
}
orderBy := database.OrderBy(i.CreatedAtColumn())
orderBy.Write(&builder)
return scanIDProviders(ctx, i.client, &builder)
}
const createIDProviderStmt = `INSERT INTO zitadel.identity_providers` +
` (instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` +
` allow_auto_update, allow_linking, styling_type, payload)` +
` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)` +
` RETURNING created_at, updated_at`
func (i *idProvider) Create(ctx context.Context, idp *domain.IdentityProvider) error {
builder := database.StatementBuilder{}
builder.AppendArgs(
idp.InstanceID,
idp.OrgID,
idp.ID,
idp.State,
idp.Name,
idp.Type,
idp.AllowCreation,
idp.AllowAutoCreation,
idp.AllowAutoUpdate,
idp.AllowLinking,
idp.StylingType,
idp.Payload)
builder.WriteString(createIDProviderStmt)
err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&idp.CreatedAt, &idp.UpdatedAt)
if err != nil {
return checkCreateOrgErr(err)
}
return nil
}
func (i *idProvider) Update(ctx context.Context, id domain.IDPIdentifierCondition, instnaceID string, orgID *string, changes ...database.Change) (int64, error) {
if changes == nil {
return 0, errors.New("Update must contain at least one change")
}
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE zitadel.identity_providers SET `)
conditions := []database.Condition{
id,
i.InstanceIDCondition(instnaceID),
i.OrgIDCondition(orgID),
}
database.Changes(changes).Write(&builder)
writeCondition(&builder, database.And(conditions...))
stmt := builder.String()
return i.client.Exec(ctx, stmt, builder.Args()...)
}
func (i *idProvider) Delete(ctx context.Context, id domain.IDPIdentifierCondition, instnaceID string, orgID *string) (int64, error) {
builder := database.StatementBuilder{}
builder.WriteString(`DELETE FROM zitadel.identity_providers`)
conditions := []database.Condition{
id,
i.InstanceIDCondition(instnaceID),
i.OrgIDCondition(orgID),
}
writeCondition(&builder, database.And(conditions...))
return i.client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
func (idProvider) InstanceIDColumn() database.Column {
return database.NewColumn("instance_id")
}
func (idProvider) OrgIDColumn() database.Column {
return database.NewColumn("org_id")
}
func (idProvider) IDColumn() database.Column {
return database.NewColumn("id")
}
func (idProvider) StateColumn() database.Column {
return database.NewColumn("state")
}
func (idProvider) NameColumn() database.Column {
return database.NewColumn("name")
}
func (idProvider) TypeColumn() database.Column {
return database.NewColumn("type")
}
func (idProvider) AutoRegisterColumn() database.Column {
return database.NewColumn("auto_register")
}
func (idProvider) AllowCreationColumn() database.Column {
return database.NewColumn("allow_creation")
}
func (idProvider) AllowAutoCreationColumn() database.Column {
return database.NewColumn("allow_auto_creation")
}
func (idProvider) AllowAutoUpdateColumn() database.Column {
return database.NewColumn("allow_auto_update")
}
func (idProvider) AllowLinkingColumn() database.Column {
return database.NewColumn("allow_linking")
}
func (idProvider) AllowAutoLinkingColumn() database.Column {
return database.NewColumn("allow_auto_linking")
}
func (idProvider) StylingTypeColumn() database.Column {
return database.NewColumn("styling_type")
}
func (idProvider) PayloadColumn() database.Column {
return database.NewColumn("payload")
}
func (idProvider) CreatedAtColumn() database.Column {
return database.NewColumn("created_at")
}
func (idProvider) UpdatedAtColumn() database.Column {
return database.NewColumn("updated_at")
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
func (i idProvider) InstanceIDCondition(id string) database.Condition {
return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, id)
}
func (i idProvider) OrgIDCondition(id *string) database.Condition {
if id == nil {
return nil
}
return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, *id)
}
func (i idProvider) IDCondition(id string) domain.IDPIdentifierCondition {
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
}
func (i idProvider) StateCondition(state domain.IDPState) database.Condition {
return database.NewTextCondition(i.StateColumn(), database.TextOperationEqual, state.String())
}
func (i idProvider) NameCondition(name string) domain.IDPIdentifierCondition {
return database.NewTextCondition(i.NameColumn(), database.TextOperationEqual, name)
}
func (i idProvider) TypeCondition(typee domain.IDPType) database.Condition {
return database.NewTextCondition(i.TypeColumn(), database.TextOperationEqual, typee.String())
}
func (i idProvider) AutoRegisterCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AutoRegisterColumn(), allow)
}
func (i idProvider) AllowCreationCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowCreationColumn(), allow)
}
func (i idProvider) AllowAutoCreationCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowAutoCreationColumn(), allow)
}
func (i idProvider) AllowAutoUpdateCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowAutoUpdateColumn(), allow)
}
func (i idProvider) AllowLinkingCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowLinkingColumn(), allow)
}
func (i idProvider) AllowAutoLinkingCondition(allow bool) database.Condition {
return database.NewBooleanCondition(i.AllowAutoLinkingColumn(), allow)
}
func (i idProvider) StylingTypeCondition(style int16) database.Condition {
return database.NewNumberCondition(i.StylingTypeColumn(), database.NumberOperationEqual, style)
}
func (i idProvider) PayloadCondition(payload string) database.Condition {
return database.NewTextCondition(i.PayloadColumn(), database.TextOperationEqual, payload)
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
func (i idProvider) SetName(name string) database.Change {
return database.NewChange(i.NameColumn(), name)
}
func (i idProvider) SetState(state domain.IDPState) database.Change {
return database.NewChange(i.StateColumn(), state)
}
func (i idProvider) SetAllowCreation(allow bool) database.Change {
return database.NewChange(i.AllowCreationColumn(), allow)
}
func (i idProvider) SetAutoRegister(allow bool) database.Change {
return database.NewChange(i.AutoRegisterColumn(), allow)
}
func (i idProvider) SetAllowAutoCreation(allow bool) database.Change {
return database.NewChange(i.AllowAutoCreationColumn(), allow)
}
func (i idProvider) SetAllowAutoUpdate(allow bool) database.Change {
return database.NewChange(i.AllowAutoUpdateColumn(), allow)
}
func (i idProvider) SetAllowLinking(allow bool) database.Change {
return database.NewChange(i.AllowLinkingColumn(), allow)
}
func (i idProvider) SetAutoAllowLinking(allow bool) database.Change {
return database.NewChange(i.AllowAutoLinkingColumn(), allow)
}
func (i idProvider) SetStylingType(stylingType int16) database.Change {
return database.NewChange(i.StylingTypeColumn(), stylingType)
}
func (i idProvider) SetPayload(payload string) database.Change {
return database.NewChange(i.PayloadColumn(), payload)
}
func scanIDProvider(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.IdentityProvider, error) {
idp := &domain.IdentityProvider{}
err := scanRow(ctx, querier, builder, idp)
if err != nil {
return nil, err
}
return idp, err
}
func scanIDProviders(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.IdentityProvider, error) {
idps := []*domain.IdentityProvider{}
err := scanRows(ctx, querier, builder, &idps)
if err != nil {
return nil, err
}
return idps, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -172,28 +172,18 @@ func (instance) UpdatedAtColumn() database.Column {
}
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
instance := new(domain.Instance)
err := scanRow(ctx, querier, builder, instance)
if err != nil {
return nil, err
}
instance := new(domain.Instance)
if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil {
return nil, err
}
return instance, nil
return instance, err
}
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (instances []*domain.Instance, err error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
err = scanRows(ctx, querier, builder, &instances)
if err != nil {
return nil, err
}
if err := rows.(database.CollectableRows).Collect(&instances); err != nil {
return nil, err
}
return instances, nil
}

View File

@@ -217,27 +217,19 @@ func (org) UpdatedAtColumn() database.Column {
}
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
organization := &domain.Organization{}
err := scanRow(ctx, querier, builder, organization)
if err != nil {
return nil, err
}
organization := &domain.Organization{}
if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil {
return nil, err
}
return organization, nil
return organization, err
}
func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
organizations := []*domain.Organization{}
if err := rows.(database.CollectableRows).Collect(&organizations); err != nil {
err := scanRows(ctx, querier, builder, &organizations)
if err != nil {
return nil, err
}
return organizations, nil

View File

@@ -519,11 +519,6 @@ func TestGetOrganization(t *testing.T) {
return
}
if org.Name == "non existent org" {
assert.Nil(t, returnedOrg)
return
}
assert.Equal(t, returnedOrg.ID, org.ID)
assert.Equal(t, returnedOrg.Name, org.Name)
assert.Equal(t, returnedOrg.InstanceID, org.InstanceID)
@@ -815,8 +810,7 @@ func TestDeleteOrganization(t *testing.T) {
return test{
name: "happy path delete organization filter id",
testFunc: func(ctx context.Context, t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
for range noOfOrganizations {
org := domain.Organization{
ID: organizationId,
@@ -829,7 +823,6 @@ func TestDeleteOrganization(t *testing.T) {
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
},
orgIdentifierCondition: organizationRepo.IDCondition(organizationId),
@@ -843,8 +836,7 @@ func TestDeleteOrganization(t *testing.T) {
return test{
name: "happy path delete organization filter name",
testFunc: func(ctx context.Context, t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
for range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
@@ -857,7 +849,6 @@ func TestDeleteOrganization(t *testing.T) {
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
},
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
@@ -879,8 +870,7 @@ func TestDeleteOrganization(t *testing.T) {
name: "deleted already deleted organization",
testFunc: func(ctx context.Context, t *testing.T) {
noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
for range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
@@ -893,13 +883,12 @@ func TestDeleteOrganization(t *testing.T) {
err := organizationRepo.Create(ctx, &org)
require.NoError(t, err)
organizations[i] = &org
}
// delete organization
affectedRows, err := organizationRepo.Delete(ctx,
organizationRepo.NameCondition(organizationName),
organizations[0].InstanceID,
instanceId,
)
assert.Equal(t, int64(1), affectedRows)
require.NoError(t, err)

View File

@@ -1,6 +1,8 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
@@ -18,3 +20,21 @@ func writeCondition(
builder.WriteString(" WHERE ")
condition.Write(builder)
}
func scanRow(ctx context.Context, querier database.Querier, builder *database.StatementBuilder, res any) error {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return err
}
return rows.(database.CollectableRows).CollectExactlyOneRow(res)
}
func scanRows(ctx context.Context, querier database.Querier, builder *database.StatementBuilder, res any) error {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return err
}
return rows.(database.CollectableRows).Collect(res)
}

View File

@@ -29,9 +29,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
if b.existingArgs == nil {
b.existingArgs = make(map[any]string)
}
if placeholder, ok := b.existingArgs[arg]; ok {
return placeholder
}
if instruction, ok := arg.(Instruction); ok {
return string(instruction)
}