refactor: use 1 db pool instead of 3

This commit is contained in:
adlerhurst 2024-12-22 11:40:55 +01:00
parent bcf416d4cf
commit 8f5bdff131
28 changed files with 478 additions and 535 deletions

View File

@ -9,7 +9,6 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
) )
var ( var (
@ -79,7 +78,7 @@ func initialise(ctx context.Context, config database.Config, steps ...func(conte
return err return err
} }
db, err := database.Connect(config, true, dialect.DBPurposeQuery) db, err := database.Connect(config, true)
if err != nil { if err != nil {
return err return err
} }

View File

@ -11,7 +11,6 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
es_v3 "github.com/zitadel/zitadel/internal/eventstore/v3" es_v3 "github.com/zitadel/zitadel/internal/eventstore/v3"
) )
@ -85,7 +84,7 @@ func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config)
func verifyZitadel(ctx context.Context, config database.Config) error { func verifyZitadel(ctx context.Context, config database.Config) error {
logging.WithFields("database", config.DatabaseName()).Info("verify zitadel") logging.WithFields("database", config.DatabaseName()).Info("verify zitadel")
db, err := database.Connect(config, false, dialect.DBPurposeQuery) db, err := database.Connect(config, false)
if err != nil { if err != nil {
return err return err
} }

View File

@ -12,7 +12,6 @@ import (
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
cryptoDB "github.com/zitadel/zitadel/internal/crypto/database" cryptoDB "github.com/zitadel/zitadel/internal/crypto/database"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@ -124,7 +123,7 @@ func openFile(fileName string) (io.Reader, error) {
} }
func keyStorage(config database.Config, masterKey string) (crypto.KeyStorage, error) { func keyStorage(config database.Config, masterKey string) (crypto.KeyStorage, error) {
db, err := database.Connect(config, false, dialect.DBPurposeQuery) db, err := database.Connect(config, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -12,7 +12,6 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
) )
func authCmd() *cobra.Command { func authCmd() *cobra.Command {
@ -34,11 +33,11 @@ Only auth requests are mirrored`,
} }
func copyAuth(ctx context.Context, config *Migration) { func copyAuth(ctx context.Context, config *Migration) {
sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) sourceClient, err := database.Connect(config.Source, false)
logging.OnError(err).Fatal("unable to connect to source database") logging.OnError(err).Fatal("unable to connect to source database")
defer sourceClient.Close() defer sourceClient.Close()
destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) destClient, err := database.Connect(config.Destination, false)
logging.OnError(err).Fatal("unable to connect to destination database") logging.OnError(err).Fatal("unable to connect to destination database")
defer destClient.Close() defer destClient.Close()

View File

@ -14,7 +14,6 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
db "github.com/zitadel/zitadel/internal/database" db "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/v2/database" "github.com/zitadel/zitadel/internal/v2/database"
"github.com/zitadel/zitadel/internal/v2/eventstore" "github.com/zitadel/zitadel/internal/v2/eventstore"
@ -44,11 +43,11 @@ Migrate only copies events2 and unique constraints`,
} }
func copyEventstore(ctx context.Context, config *Migration) { func copyEventstore(ctx context.Context, config *Migration) {
sourceClient, err := db.Connect(config.Source, false, dialect.DBPurposeEventPusher) sourceClient, err := db.Connect(config.Source, false)
logging.OnError(err).Fatal("unable to connect to source database") logging.OnError(err).Fatal("unable to connect to source database")
defer sourceClient.Close() defer sourceClient.Close()
destClient, err := db.Connect(config.Destination, false, dialect.DBPurposeEventPusher) destClient, err := db.Connect(config.Destination, false)
logging.OnError(err).Fatal("unable to connect to destination database") logging.OnError(err).Fatal("unable to connect to destination database")
defer destClient.Close() defer destClient.Close()

View File

@ -30,7 +30,6 @@ import (
"github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/config/systemdefaults"
crypto_db "github.com/zitadel/zitadel/internal/crypto/database" crypto_db "github.com/zitadel/zitadel/internal/crypto/database"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql" old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql"
@ -106,7 +105,7 @@ func projections(
) { ) {
start := time.Now() start := time.Now()
client, err := database.Connect(config.Destination, false, dialect.DBPurposeQuery) client, err := database.Connect(config.Destination, false)
logging.OnError(err).Fatal("unable to connect to database") logging.OnError(err).Fatal("unable to connect to database")
keyStorage, err := crypto_db.NewKeyStorage(client, masterKey) keyStorage, err := crypto_db.NewKeyStorage(client, masterKey)
@ -119,9 +118,7 @@ func projections(
logging.OnError(err).Fatal("unable create static storage") logging.OnError(err).Fatal("unable create static storage")
config.Eventstore.Querier = old_es.NewCRDB(client) config.Eventstore.Querier = old_es.NewCRDB(client)
esPusherDBClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) config.Eventstore.Pusher = new_es.NewEventstore(client)
logging.OnError(err).Fatal("unable to connect eventstore push client")
config.Eventstore.Pusher = new_es.NewEventstore(esPusherDBClient)
es := eventstore.NewEventstore(config.Eventstore) es := eventstore.NewEventstore(config.Eventstore)
esV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(client, &es_v4_pg.Config{ esV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(client, &es_v4_pg.Config{
MaxRetries: config.Eventstore.MaxRetries, MaxRetries: config.Eventstore.MaxRetries,

View File

@ -12,7 +12,6 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
) )
func systemCmd() *cobra.Command { func systemCmd() *cobra.Command {
@ -34,11 +33,11 @@ Only keys and assets are mirrored`,
} }
func copySystem(ctx context.Context, config *Migration) { func copySystem(ctx context.Context, config *Migration) {
sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) sourceClient, err := database.Connect(config.Source, false)
logging.OnError(err).Fatal("unable to connect to source database") logging.OnError(err).Fatal("unable to connect to source database")
defer sourceClient.Close() defer sourceClient.Close()
destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) destClient, err := database.Connect(config.Destination, false)
logging.OnError(err).Fatal("unable to connect to destination database") logging.OnError(err).Fatal("unable to connect to destination database")
defer destClient.Close() defer destClient.Close()

View File

@ -13,7 +13,6 @@ import (
cryptoDatabase "github.com/zitadel/zitadel/internal/crypto/database" cryptoDatabase "github.com/zitadel/zitadel/internal/crypto/database"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/query/projection"
) )
@ -37,11 +36,11 @@ var schemas = []string{
} }
func verifyMigration(ctx context.Context, config *Migration) { func verifyMigration(ctx context.Context, config *Migration) {
sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) sourceClient, err := database.Connect(config.Source, false)
logging.OnError(err).Fatal("unable to connect to source database") logging.OnError(err).Fatal("unable to connect to source database")
defer sourceClient.Close() defer sourceClient.Close()
destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) destClient, err := database.Connect(config.Destination, false)
logging.OnError(err).Fatal("unable to connect to destination database") logging.OnError(err).Fatal("unable to connect to destination database")
defer destClient.Close() defer destClient.Close()

View File

@ -8,7 +8,6 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql" old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql"
new_es "github.com/zitadel/zitadel/internal/eventstore/v3" new_es "github.com/zitadel/zitadel/internal/eventstore/v3"
@ -32,13 +31,11 @@ func Cleanup(config *Config) {
logging.Info("cleanup started") logging.Info("cleanup started")
queryDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeQuery) dbClient, err := database.Connect(config.Database, false)
logging.OnError(err).Fatal("unable to connect to database")
esPusherDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeEventPusher)
logging.OnError(err).Fatal("unable to connect to database") logging.OnError(err).Fatal("unable to connect to database")
config.Eventstore.Pusher = new_es.NewEventstore(esPusherDBClient) config.Eventstore.Pusher = new_es.NewEventstore(dbClient)
config.Eventstore.Querier = old_es.NewCRDB(queryDBClient) config.Eventstore.Querier = old_es.NewCRDB(dbClient)
es := eventstore.NewEventstore(config.Eventstore) es := eventstore.NewEventstore(config.Eventstore)
step, err := migration.LastStuckStep(ctx, es) step, err := migration.LastStuckStep(ctx, es)

View File

@ -26,7 +26,6 @@ import (
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
cryptoDB "github.com/zitadel/zitadel/internal/crypto/database" cryptoDB "github.com/zitadel/zitadel/internal/crypto/database"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql" old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql"
@ -102,26 +101,22 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
i18n.MustLoadSupportedLanguagesFromDir() i18n.MustLoadSupportedLanguagesFromDir()
queryDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeQuery) dbClient, err := database.Connect(config.Database, false)
logging.OnError(err).Fatal("unable to connect to database")
esPusherDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeEventPusher)
logging.OnError(err).Fatal("unable to connect to database")
projectionDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeProjectionSpooler)
logging.OnError(err).Fatal("unable to connect to database") logging.OnError(err).Fatal("unable to connect to database")
config.Eventstore.Querier = old_es.NewCRDB(queryDBClient) config.Eventstore.Querier = old_es.NewCRDB(dbClient)
esV3 := new_es.NewEventstore(esPusherDBClient) esV3 := new_es.NewEventstore(dbClient)
config.Eventstore.Pusher = esV3 config.Eventstore.Pusher = esV3
config.Eventstore.Searcher = esV3 config.Eventstore.Searcher = esV3
eventstoreClient := eventstore.NewEventstore(config.Eventstore) eventstoreClient := eventstore.NewEventstore(config.Eventstore)
logging.OnError(err).Fatal("unable to start eventstore") logging.OnError(err).Fatal("unable to start eventstore")
eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(queryDBClient, &es_v4_pg.Config{ eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(dbClient, &es_v4_pg.Config{
MaxRetries: config.Eventstore.MaxRetries, MaxRetries: config.Eventstore.MaxRetries,
})) }))
steps.s1ProjectionTable = &ProjectionTable{dbClient: queryDBClient.DB} steps.s1ProjectionTable = &ProjectionTable{dbClient: dbClient.DB}
steps.s2AssetsTable = &AssetTable{dbClient: queryDBClient.DB} steps.s2AssetsTable = &AssetTable{dbClient: dbClient.DB}
steps.FirstInstance.Skip = config.ForMirror || steps.FirstInstance.Skip steps.FirstInstance.Skip = config.ForMirror || steps.FirstInstance.Skip
steps.FirstInstance.instanceSetup = config.DefaultInstance steps.FirstInstance.instanceSetup = config.DefaultInstance
@ -129,7 +124,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.FirstInstance.smtpEncryptionKey = config.EncryptionKeys.SMTP steps.FirstInstance.smtpEncryptionKey = config.EncryptionKeys.SMTP
steps.FirstInstance.oidcEncryptionKey = config.EncryptionKeys.OIDC steps.FirstInstance.oidcEncryptionKey = config.EncryptionKeys.OIDC
steps.FirstInstance.masterKey = masterKey steps.FirstInstance.masterKey = masterKey
steps.FirstInstance.db = queryDBClient steps.FirstInstance.db = dbClient
steps.FirstInstance.es = eventstoreClient steps.FirstInstance.es = eventstoreClient
steps.FirstInstance.defaults = config.SystemDefaults steps.FirstInstance.defaults = config.SystemDefaults
steps.FirstInstance.zitadelRoles = config.InternalAuthZ.RolePermissionMappings steps.FirstInstance.zitadelRoles = config.InternalAuthZ.RolePermissionMappings
@ -137,42 +132,42 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.FirstInstance.externalSecure = config.ExternalSecure steps.FirstInstance.externalSecure = config.ExternalSecure
steps.FirstInstance.externalPort = config.ExternalPort steps.FirstInstance.externalPort = config.ExternalPort
steps.s5LastFailed = &LastFailed{dbClient: queryDBClient.DB} steps.s5LastFailed = &LastFailed{dbClient: dbClient.DB}
steps.s6OwnerRemoveColumns = &OwnerRemoveColumns{dbClient: queryDBClient.DB} steps.s6OwnerRemoveColumns = &OwnerRemoveColumns{dbClient: dbClient.DB}
steps.s7LogstoreTables = &LogstoreTables{dbClient: queryDBClient.DB, username: config.Database.Username(), dbType: config.Database.Type()} steps.s7LogstoreTables = &LogstoreTables{dbClient: dbClient.DB, username: config.Database.Username(), dbType: config.Database.Type()}
steps.s8AuthTokens = &AuthTokenIndexes{dbClient: queryDBClient} steps.s8AuthTokens = &AuthTokenIndexes{dbClient: dbClient}
steps.CorrectCreationDate.dbClient = esPusherDBClient steps.CorrectCreationDate.dbClient = dbClient
steps.s12AddOTPColumns = &AddOTPColumns{dbClient: queryDBClient} steps.s12AddOTPColumns = &AddOTPColumns{dbClient: dbClient}
steps.s13FixQuotaProjection = &FixQuotaConstraints{dbClient: queryDBClient} steps.s13FixQuotaProjection = &FixQuotaConstraints{dbClient: dbClient}
steps.s14NewEventsTable = &NewEventsTable{dbClient: esPusherDBClient} steps.s14NewEventsTable = &NewEventsTable{dbClient: dbClient}
steps.s15CurrentStates = &CurrentProjectionState{dbClient: queryDBClient} steps.s15CurrentStates = &CurrentProjectionState{dbClient: dbClient}
steps.s16UniqueConstraintsLower = &UniqueConstraintToLower{dbClient: queryDBClient} steps.s16UniqueConstraintsLower = &UniqueConstraintToLower{dbClient: dbClient}
steps.s17AddOffsetToUniqueConstraints = &AddOffsetToCurrentStates{dbClient: queryDBClient} steps.s17AddOffsetToUniqueConstraints = &AddOffsetToCurrentStates{dbClient: dbClient}
steps.s18AddLowerFieldsToLoginNames = &AddLowerFieldsToLoginNames{dbClient: queryDBClient} steps.s18AddLowerFieldsToLoginNames = &AddLowerFieldsToLoginNames{dbClient: dbClient}
steps.s19AddCurrentStatesIndex = &AddCurrentSequencesIndex{dbClient: queryDBClient} steps.s19AddCurrentStatesIndex = &AddCurrentSequencesIndex{dbClient: dbClient}
steps.s20AddByUserSessionIndex = &AddByUserIndexToSession{dbClient: queryDBClient} steps.s20AddByUserSessionIndex = &AddByUserIndexToSession{dbClient: dbClient}
steps.s21AddBlockFieldToLimits = &AddBlockFieldToLimits{dbClient: queryDBClient} steps.s21AddBlockFieldToLimits = &AddBlockFieldToLimits{dbClient: dbClient}
steps.s22ActiveInstancesIndex = &ActiveInstanceEvents{dbClient: queryDBClient} steps.s22ActiveInstancesIndex = &ActiveInstanceEvents{dbClient: dbClient}
steps.s23CorrectGlobalUniqueConstraints = &CorrectGlobalUniqueConstraints{dbClient: esPusherDBClient} steps.s23CorrectGlobalUniqueConstraints = &CorrectGlobalUniqueConstraints{dbClient: dbClient}
steps.s24AddActorToAuthTokens = &AddActorToAuthTokens{dbClient: queryDBClient} steps.s24AddActorToAuthTokens = &AddActorToAuthTokens{dbClient: dbClient}
steps.s25User11AddLowerFieldsToVerifiedEmail = &User11AddLowerFieldsToVerifiedEmail{dbClient: esPusherDBClient} steps.s25User11AddLowerFieldsToVerifiedEmail = &User11AddLowerFieldsToVerifiedEmail{dbClient: dbClient}
steps.s26AuthUsers3 = &AuthUsers3{dbClient: esPusherDBClient} steps.s26AuthUsers3 = &AuthUsers3{dbClient: dbClient}
steps.s27IDPTemplate6SAMLNameIDFormat = &IDPTemplate6SAMLNameIDFormat{dbClient: esPusherDBClient} steps.s27IDPTemplate6SAMLNameIDFormat = &IDPTemplate6SAMLNameIDFormat{dbClient: dbClient}
steps.s28AddFieldTable = &AddFieldTable{dbClient: esPusherDBClient} steps.s28AddFieldTable = &AddFieldTable{dbClient: dbClient}
steps.s29FillFieldsForProjectGrant = &FillFieldsForProjectGrant{eventstore: eventstoreClient} steps.s29FillFieldsForProjectGrant = &FillFieldsForProjectGrant{eventstore: eventstoreClient}
steps.s30FillFieldsForOrgDomainVerified = &FillFieldsForOrgDomainVerified{eventstore: eventstoreClient} steps.s30FillFieldsForOrgDomainVerified = &FillFieldsForOrgDomainVerified{eventstore: eventstoreClient}
steps.s31AddAggregateIndexToFields = &AddAggregateIndexToFields{dbClient: esPusherDBClient} steps.s31AddAggregateIndexToFields = &AddAggregateIndexToFields{dbClient: dbClient}
steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: esPusherDBClient} steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: dbClient}
steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient} steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: dbClient}
steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient} steps.s34AddCacheSchema = &AddCacheSchema{dbClient: dbClient}
steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient} steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: dbClient}
steps.s36FillV2Milestones = &FillV3Milestones{dbClient: queryDBClient, eventstore: eventstoreClient} steps.s36FillV2Milestones = &FillV3Milestones{dbClient: dbClient, eventstore: eventstoreClient}
steps.s37Apps7OIDConfigsBackChannelLogoutURI = &Apps7OIDConfigsBackChannelLogoutURI{dbClient: esPusherDBClient} steps.s37Apps7OIDConfigsBackChannelLogoutURI = &Apps7OIDConfigsBackChannelLogoutURI{dbClient: dbClient}
steps.s38BackChannelLogoutNotificationStart = &BackChannelLogoutNotificationStart{dbClient: esPusherDBClient, esClient: eventstoreClient} steps.s38BackChannelLogoutNotificationStart = &BackChannelLogoutNotificationStart{dbClient: dbClient, esClient: eventstoreClient}
steps.s40InitPushFunc = &InitPushFunc{dbClient: esPusherDBClient} steps.s40InitPushFunc = &InitPushFunc{dbClient: dbClient}
steps.s42Apps7OIDCConfigsLoginVersion = &Apps7OIDCConfigsLoginVersion{dbClient: esPusherDBClient} steps.s42Apps7OIDCConfigsLoginVersion = &Apps7OIDCConfigsLoginVersion{dbClient: dbClient}
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil) err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil)
logging.OnError(err).Fatal("unable to start projections") logging.OnError(err).Fatal("unable to start projections")
repeatableSteps := []migration.RepeatableMigration{ repeatableSteps := []migration.RepeatableMigration{
@ -252,8 +247,8 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
ctx, ctx,
eventstoreClient, eventstoreClient,
eventstoreV4, eventstoreV4,
queryDBClient, dbClient,
projectionDBClient, dbClient,
masterKey, masterKey,
config, config,
) )

View File

@ -75,7 +75,6 @@ import (
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
cryptoDB "github.com/zitadel/zitadel/internal/crypto/database" cryptoDB "github.com/zitadel/zitadel/internal/crypto/database"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql" old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql"
@ -148,20 +147,12 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
i18n.MustLoadSupportedLanguagesFromDir() i18n.MustLoadSupportedLanguagesFromDir()
queryDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeQuery) dbClient, err := database.Connect(config.Database, false)
if err != nil { if err != nil {
return fmt.Errorf("cannot start DB client for queries: %w", err) return fmt.Errorf("cannot start DB client for queries: %w", err)
} }
esPusherDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeEventPusher)
if err != nil {
return fmt.Errorf("cannot start client for event store pusher: %w", err)
}
projectionDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeProjectionSpooler)
if err != nil {
return fmt.Errorf("cannot start client for projection spooler: %w", err)
}
keyStorage, err := cryptoDB.NewKeyStorage(queryDBClient, masterKey) keyStorage, err := cryptoDB.NewKeyStorage(dbClient, masterKey)
if err != nil { if err != nil {
return fmt.Errorf("cannot start key storage: %w", err) return fmt.Errorf("cannot start key storage: %w", err)
} }
@ -170,16 +161,16 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
return err return err
} }
config.Eventstore.Pusher = new_es.NewEventstore(esPusherDBClient) config.Eventstore.Pusher = new_es.NewEventstore(dbClient)
config.Eventstore.Searcher = new_es.NewEventstore(queryDBClient) config.Eventstore.Searcher = new_es.NewEventstore(dbClient)
config.Eventstore.Querier = old_es.NewCRDB(queryDBClient) config.Eventstore.Querier = old_es.NewCRDB(dbClient)
eventstoreClient := eventstore.NewEventstore(config.Eventstore) eventstoreClient := eventstore.NewEventstore(config.Eventstore)
eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(queryDBClient, &es_v4_pg.Config{ eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(dbClient, &es_v4_pg.Config{
MaxRetries: config.Eventstore.MaxRetries, MaxRetries: config.Eventstore.MaxRetries,
})) }))
sessionTokenVerifier := internal_authz.SessionTokenVerifier(keys.OIDC) sessionTokenVerifier := internal_authz.SessionTokenVerifier(keys.OIDC)
cacheConnectors, err := connector.StartConnectors(config.Caches, queryDBClient) cacheConnectors, err := connector.StartConnectors(config.Caches, dbClient)
if err != nil { if err != nil {
return fmt.Errorf("unable to start caches: %w", err) return fmt.Errorf("unable to start caches: %w", err)
} }
@ -188,8 +179,8 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
ctx, ctx,
eventstoreClient, eventstoreClient,
eventstoreV4.Querier, eventstoreV4.Querier,
queryDBClient, dbClient,
projectionDBClient, dbClient,
cacheConnectors, cacheConnectors,
config.Projections, config.Projections,
config.SystemDefaults, config.SystemDefaults,
@ -213,7 +204,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
return fmt.Errorf("cannot start queries: %w", err) return fmt.Errorf("cannot start queries: %w", err)
} }
authZRepo, err := authz.Start(queries, eventstoreClient, queryDBClient, keys.OIDC, config.ExternalSecure) authZRepo, err := authz.Start(queries, eventstoreClient, dbClient, keys.OIDC, config.ExternalSecure)
if err != nil { if err != nil {
return fmt.Errorf("error starting authz repo: %w", err) return fmt.Errorf("error starting authz repo: %w", err)
} }
@ -221,7 +212,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
return internal_authz.CheckPermission(ctx, authZRepo, config.InternalAuthZ.RolePermissionMappings, permission, orgID, resourceID) return internal_authz.CheckPermission(ctx, authZRepo, config.InternalAuthZ.RolePermissionMappings, permission, orgID, resourceID)
} }
storage, err := config.AssetStorage.NewStorage(queryDBClient.DB) storage, err := config.AssetStorage.NewStorage(dbClient.DB)
if err != nil { if err != nil {
return fmt.Errorf("cannot start asset storage client: %w", err) return fmt.Errorf("cannot start asset storage client: %w", err)
} }
@ -266,7 +257,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
if err != nil { if err != nil {
return err return err
} }
actionsExecutionDBEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, config.Quotas.Execution, execution.NewDatabaseLogStorage(queryDBClient, commands, queries)) actionsExecutionDBEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, config.Quotas.Execution, execution.NewDatabaseLogStorage(dbClient, commands, queries))
if err != nil { if err != nil {
return err return err
} }
@ -295,7 +286,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
keys.SMS, keys.SMS,
keys.OIDC, keys.OIDC,
config.OIDC.DefaultBackChannelLogoutLifetime, config.OIDC.DefaultBackChannelLogoutLifetime,
queryDBClient, dbClient,
) )
notification.Start(ctx) notification.Start(ctx)
@ -311,7 +302,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
commands, commands,
queries, queries,
eventstoreClient, eventstoreClient,
queryDBClient, dbClient,
config, config,
storage, storage,
authZRepo, authZRepo,
@ -330,7 +321,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
if server != nil { if server != nil {
server <- &Server{ server <- &Server{
Config: config, Config: config,
DB: queryDBClient, DB: dbClient,
KeyStorage: keyStorage, KeyStorage: keyStorage,
Keys: keys, Keys: keys,
Eventstore: eventstoreClient, Eventstore: eventstoreClient,

View File

@ -3,7 +3,6 @@ package cockroach
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -14,7 +13,6 @@ import (
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/database/dialect"
) )
@ -74,19 +72,19 @@ func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) {
return connector, nil return connector, nil
} }
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) { func (c *Config) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) {
dialect.RegisterAfterConnect(func(ctx context.Context, c *pgx.Conn) error { dialect.RegisterAfterConnect(func(ctx context.Context, c *pgx.Conn) error {
// CockroachDB by default does not allow multiple modifications of the same table using ON CONFLICT // CockroachDB by default does not allow multiple modifications of the same table using ON CONFLICT
// This is needed to fill the fields table of the eventstore during eventstore.Push. // This is needed to fill the fields table of the eventstore during eventstore.Push.
_, err := c.Exec(ctx, "SET enable_multiple_modifications_of_table = on") _, err := c.Exec(ctx, "SET enable_multiple_modifications_of_table = on")
return err return err
}) })
connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose) connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
config, err := pgxpool.ParseConfig(c.String(useAdmin, purpose.AppName())) config, err := pgxpool.ParseConfig(c.String(useAdmin))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -102,18 +100,6 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo
} }
} }
// For the pusher we set the app name with the instance ID
if purpose == dialect.DBPurposeEventPusher {
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
return setAppNameWithID(ctx, conn, purpose, authz.GetInstance(ctx).InstanceID())
}
config.AfterRelease = func(conn *pgx.Conn) bool {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return setAppNameWithID(ctx, conn, purpose, "IDLE")
}
}
if connConfig.MaxOpenConns != 0 { if connConfig.MaxOpenConns != 0 {
config.MaxConns = int32(connConfig.MaxOpenConns) config.MaxConns = int32(connConfig.MaxOpenConns)
} }
@ -195,7 +181,7 @@ func (c *Config) checkSSL(user User) {
} }
} }
func (c Config) String(useAdmin bool, appName string) string { func (c Config) String(useAdmin bool) string {
user := c.User user := c.User
if useAdmin { if useAdmin {
user = c.Admin.User user = c.Admin.User
@ -206,7 +192,7 @@ func (c Config) String(useAdmin bool, appName string) string {
"port=" + strconv.Itoa(int(c.Port)), "port=" + strconv.Itoa(int(c.Port)),
"user=" + user.Username, "user=" + user.Username,
"dbname=" + c.Database, "dbname=" + c.Database,
"application_name=" + appName, "application_name=zitadel",
"sslmode=" + user.SSL.Mode, "sslmode=" + user.SSL.Mode,
} }
if c.Options != "" { if c.Options != "" {
@ -232,11 +218,3 @@ func (c Config) String(useAdmin bool, appName string) string {
return strings.Join(fields, " ") return strings.Join(fields, " ")
} }
func setAppNameWithID(ctx context.Context, conn *pgx.Conn, purpose dialect.DBPurpose, id string) bool {
// needs to be set like this because psql complains about parameters in the SET statement
query := fmt.Sprintf("SET application_name = '%s_%s'", purpose.AppName(), id)
_, err := conn.Exec(ctx, query)
logging.OnError(err).Warn("failed to set application name")
return err == nil
}

View File

@ -134,8 +134,8 @@ func QueryJSONObject[T any](ctx context.Context, db *DB, query string, args ...a
return obj, nil return obj, nil
} }
func Connect(config Config, useAdmin bool, purpose dialect.DBPurpose) (*DB, error) { func Connect(config Config, useAdmin bool) (*DB, error) {
client, pool, err := config.connector.Connect(useAdmin, config.EventPushConnRatio, config.ProjectionSpoolerConnRatio, purpose) client, pool, err := config.connector.Connect(useAdmin)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -26,36 +26,36 @@ type Matcher interface {
} }
const ( const (
QueryAppName = "zitadel_queries" // QueryAppName = "zitadel_queries"
EventstorePusherAppName = "zitadel_es_pusher" // EventstorePusherAppName = "zitadel_es_pusher"
ProjectionSpoolerAppName = "zitadel_projection_spooler" // ProjectionSpoolerAppName = "zitadel_projection_spooler"
defaultAppName = "zitadel" defaultAppName = "zitadel"
) )
// DBPurpose is what the resulting connection pool is used for. // DBPurpose is what the resulting connection pool is used for.
type DBPurpose int // type DBPurpose int
const ( // const (
DBPurposeQuery DBPurpose = iota // DBPurposeQuery DBPurpose = iota
DBPurposeEventPusher // DBPurposeEventPusher
DBPurposeProjectionSpooler // DBPurposeProjectionSpooler
) // )
func (p DBPurpose) AppName() string { // func (p DBPurpose) AppName() string {
switch p { // switch p {
case DBPurposeQuery: // case DBPurposeQuery:
return QueryAppName // return QueryAppName
case DBPurposeEventPusher: // case DBPurposeEventPusher:
return EventstorePusherAppName // return EventstorePusherAppName
case DBPurposeProjectionSpooler: // case DBPurposeProjectionSpooler:
return ProjectionSpoolerAppName // return ProjectionSpoolerAppName
default: // default:
return defaultAppName // return defaultAppName
} // }
} // }
type Connector interface { type Connector interface {
Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose DBPurpose) (*sql.DB, *pgxpool.Pool, error) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error)
Password() string Password() string
Database Database
} }

View File

@ -1,36 +0,0 @@
package dialect
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDBPurpose_AppName(t *testing.T) {
tests := []struct {
p DBPurpose
want string
}{
{
p: DBPurposeQuery,
want: QueryAppName,
},
{
p: DBPurposeEventPusher,
want: EventstorePusherAppName,
},
{
p: DBPurposeProjectionSpooler,
want: ProjectionSpoolerAppName,
},
{
p: 99,
want: defaultAppName,
},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
assert.Equal(t, tt.want, tt.p.AppName())
})
}
}

View File

@ -3,7 +3,6 @@ package dialect
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"reflect" "reflect"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
@ -11,11 +10,11 @@ import (
) )
var ( var (
ErrNegativeRatio = errors.New("ratio cannot be negative") // ErrNegativeRatio = errors.New("ratio cannot be negative")
ErrHighSumRatio = errors.New("sum of pusher and projection ratios must be < 1") // ErrHighSumRatio = errors.New("sum of pusher and projection ratios must be < 1")
ErrIllegalMaxOpenConns = errors.New("MaxOpenConns of the database must be higher than 3 or 0 for unlimited") ErrIllegalMaxOpenConns = errors.New("MaxOpenConns of the database must be higher than 3 or 0 for unlimited")
ErrIllegalMaxIdleConns = errors.New("MaxIdleConns of the database must be higher than 3 or 0 for unlimited") ErrIllegalMaxIdleConns = errors.New("MaxIdleConns of the database must be higher than 3 or 0 for unlimited")
ErrInvalidPurpose = errors.New("DBPurpose out of range") // ErrInvalidPurpose = errors.New("DBPurpose out of range")
) )
// ConnectionConfig defines the Max Open and Idle connections for a DB connection pool. // ConnectionConfig defines the Max Open and Idle connections for a DB connection pool.
@ -25,28 +24,6 @@ type ConnectionConfig struct {
AfterConnect []func(ctx context.Context, c *pgx.Conn) error AfterConnect []func(ctx context.Context, c *pgx.Conn) error
} }
// takeRatio of MaxOpenConns and MaxIdleConns from config and returns
// a new ConnectionConfig with the resulting values.
func (c *ConnectionConfig) takeRatio(ratio float64) (*ConnectionConfig, error) {
if ratio < 0 {
return nil, ErrNegativeRatio
}
out := &ConnectionConfig{
MaxOpenConns: uint32(ratio * float64(c.MaxOpenConns)),
MaxIdleConns: uint32(ratio * float64(c.MaxIdleConns)),
AfterConnect: c.AfterConnect,
}
if c.MaxOpenConns != 0 && out.MaxOpenConns < 1 && ratio > 0 {
out.MaxOpenConns = 1
}
if c.MaxIdleConns != 0 && out.MaxIdleConns < 1 && ratio > 0 {
out.MaxIdleConns = 1
}
return out, nil
}
var afterConnectFuncs []func(ctx context.Context, c *pgx.Conn) error var afterConnectFuncs []func(ctx context.Context, c *pgx.Conn) error
func RegisterAfterConnect(f func(ctx context.Context, c *pgx.Conn) error) { func RegisterAfterConnect(f func(ctx context.Context, c *pgx.Conn) error) {
@ -82,48 +59,10 @@ func RegisterDefaultPgTypeVariants[T any](m *pgtype.Map, name, arrayName string)
// //
// openConns and idleConns must be at least 3 or 0, which means no limit. // openConns and idleConns must be at least 3 or 0, which means no limit.
// The pusherRatio and spoolerRatio must be between 0 and 1. // The pusherRatio and spoolerRatio must be between 0 and 1.
func NewConnectionConfig(openConns, idleConns uint32, pusherRatio, projectionRatio float64, purpose DBPurpose) (*ConnectionConfig, error) { func NewConnectionConfig(openConns, idleConns uint32) (*ConnectionConfig, error) {
if openConns != 0 && openConns < 3 { return &ConnectionConfig{
return nil, ErrIllegalMaxOpenConns
}
if idleConns != 0 && idleConns < 3 {
return nil, ErrIllegalMaxIdleConns
}
if pusherRatio+projectionRatio >= 1 {
return nil, ErrHighSumRatio
}
queryConfig := &ConnectionConfig{
MaxOpenConns: openConns, MaxOpenConns: openConns,
MaxIdleConns: idleConns, MaxIdleConns: idleConns,
AfterConnect: afterConnectFuncs, AfterConnect: afterConnectFuncs,
} }, nil
pusherConfig, err := queryConfig.takeRatio(pusherRatio)
if err != nil {
return nil, fmt.Errorf("event pusher: %w", err)
}
spoolerConfig, err := queryConfig.takeRatio(projectionRatio)
if err != nil {
return nil, fmt.Errorf("projection spooler: %w", err)
}
// subtract the claimed amount
if queryConfig.MaxOpenConns > 0 {
queryConfig.MaxOpenConns -= pusherConfig.MaxOpenConns + spoolerConfig.MaxOpenConns
}
if queryConfig.MaxIdleConns > 0 {
queryConfig.MaxIdleConns -= pusherConfig.MaxIdleConns + spoolerConfig.MaxIdleConns
}
switch purpose {
case DBPurposeQuery:
return queryConfig, nil
case DBPurposeEventPusher:
return pusherConfig, nil
case DBPurposeProjectionSpooler:
return spoolerConfig, nil
default:
return nil, fmt.Errorf("%w: %v", ErrInvalidPurpose, purpose)
}
} }

View File

@ -1,252 +0,0 @@
package dialect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConnectionConfig_takeRatio(t *testing.T) {
type fields struct {
MaxOpenConns uint32
MaxIdleConns uint32
}
tests := []struct {
name string
fields fields
ratio float64
wantOut *ConnectionConfig
wantErr error
}{
{
name: "ratio less than 0 error",
ratio: -0.1,
wantErr: ErrNegativeRatio,
},
{
name: "zero values",
fields: fields{
MaxOpenConns: 0,
MaxIdleConns: 0,
},
ratio: 0,
wantOut: &ConnectionConfig{
MaxOpenConns: 0,
MaxIdleConns: 0,
},
},
{
name: "max conns, ratio 0",
fields: fields{
MaxOpenConns: 10,
MaxIdleConns: 5,
},
ratio: 0,
wantOut: &ConnectionConfig{
MaxOpenConns: 0,
MaxIdleConns: 0,
},
},
{
name: "half ratio",
fields: fields{
MaxOpenConns: 10,
MaxIdleConns: 5,
},
ratio: 0.5,
wantOut: &ConnectionConfig{
MaxOpenConns: 5,
MaxIdleConns: 2,
},
},
{
name: "minimal 1",
fields: fields{
MaxOpenConns: 2,
MaxIdleConns: 2,
},
ratio: 0.1,
wantOut: &ConnectionConfig{
MaxOpenConns: 1,
MaxIdleConns: 1,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
in := &ConnectionConfig{
MaxOpenConns: tt.fields.MaxOpenConns,
MaxIdleConns: tt.fields.MaxIdleConns,
}
got, err := in.takeRatio(tt.ratio)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantOut, got)
})
}
}
func TestNewConnectionConfig(t *testing.T) {
type args struct {
openConns uint32
idleConns uint32
pusherRatio float64
projectionRatio float64
purpose DBPurpose
}
tests := []struct {
name string
args args
want *ConnectionConfig
wantErr error
}{
{
name: "illegal open conns error",
args: args{
openConns: 2,
idleConns: 3,
},
wantErr: ErrIllegalMaxOpenConns,
},
{
name: "illegal idle conns error",
args: args{
openConns: 3,
idleConns: 2,
},
wantErr: ErrIllegalMaxIdleConns,
},
{
name: "high ration sum error",
args: args{
openConns: 3,
idleConns: 3,
pusherRatio: 0.5,
projectionRatio: 0.5,
},
wantErr: ErrHighSumRatio,
},
{
name: "illegal pusher ratio error",
args: args{
openConns: 3,
idleConns: 3,
pusherRatio: -0.1,
projectionRatio: 0.5,
},
wantErr: ErrNegativeRatio,
},
{
name: "illegal projection ratio error",
args: args{
openConns: 3,
idleConns: 3,
pusherRatio: 0.5,
projectionRatio: -0.1,
},
wantErr: ErrNegativeRatio,
},
{
name: "invalid purpose error",
args: args{
openConns: 3,
idleConns: 3,
pusherRatio: 0.4,
projectionRatio: 0.4,
purpose: 99,
},
wantErr: ErrInvalidPurpose,
},
{
name: "min values, query purpose",
args: args{
openConns: 3,
idleConns: 3,
pusherRatio: 0.2,
projectionRatio: 0.2,
purpose: DBPurposeQuery,
},
want: &ConnectionConfig{
MaxOpenConns: 1,
MaxIdleConns: 1,
},
},
{
name: "min values, pusher purpose",
args: args{
openConns: 3,
idleConns: 3,
pusherRatio: 0.2,
projectionRatio: 0.2,
purpose: DBPurposeEventPusher,
},
want: &ConnectionConfig{
MaxOpenConns: 1,
MaxIdleConns: 1,
},
},
{
name: "min values, projection purpose",
args: args{
openConns: 3,
idleConns: 3,
pusherRatio: 0.2,
projectionRatio: 0.2,
purpose: DBPurposeProjectionSpooler,
},
want: &ConnectionConfig{
MaxOpenConns: 1,
MaxIdleConns: 1,
},
},
{
name: "high values, query purpose",
args: args{
openConns: 10,
idleConns: 5,
pusherRatio: 0.2,
projectionRatio: 0.2,
purpose: DBPurposeQuery,
},
want: &ConnectionConfig{
MaxOpenConns: 6,
MaxIdleConns: 3,
},
},
{
name: "high values, pusher purpose",
args: args{
openConns: 10,
idleConns: 5,
pusherRatio: 0.2,
projectionRatio: 0.2,
purpose: DBPurposeEventPusher,
},
want: &ConnectionConfig{
MaxOpenConns: 2,
MaxIdleConns: 1,
},
},
{
name: "high values, projection purpose",
args: args{
openConns: 10,
idleConns: 5,
pusherRatio: 0.2,
projectionRatio: 0.2,
purpose: DBPurposeProjectionSpooler,
},
want: &ConnectionConfig{
MaxOpenConns: 2,
MaxIdleConns: 1,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewConnectionConfig(tt.args.openConns, tt.args.idleConns, tt.args.pusherRatio, tt.args.projectionRatio, tt.args.purpose)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View File

@ -3,7 +3,6 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -14,7 +13,6 @@ import (
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/database/dialect"
) )
@ -75,13 +73,13 @@ func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) {
return connector, nil return connector, nil
} }
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) { func (c *Config) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) {
connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose) connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
config, err := pgxpool.ParseConfig(c.String(useAdmin, purpose.AppName())) config, err := pgxpool.ParseConfig(c.String(useAdmin))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -95,18 +93,6 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo
return nil return nil
} }
// For the pusher we set the app name with the instance ID
if purpose == dialect.DBPurposeEventPusher {
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
return setAppNameWithID(ctx, conn, purpose, authz.GetInstance(ctx).InstanceID())
}
config.AfterRelease = func(conn *pgx.Conn) bool {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return setAppNameWithID(ctx, conn, purpose, "IDLE")
}
}
if connConfig.MaxOpenConns != 0 { if connConfig.MaxOpenConns != 0 {
config.MaxConns = int32(connConfig.MaxOpenConns) config.MaxConns = int32(connConfig.MaxOpenConns)
} }
@ -191,7 +177,7 @@ func (s *Config) checkSSL(user User) {
} }
} }
func (c Config) String(useAdmin bool, appName string) string { func (c Config) String(useAdmin bool) string {
user := c.User user := c.User
if useAdmin { if useAdmin {
user = c.Admin.User user = c.Admin.User
@ -201,7 +187,7 @@ func (c Config) String(useAdmin bool, appName string) string {
"host=" + c.Host, "host=" + c.Host,
"port=" + strconv.Itoa(int(c.Port)), "port=" + strconv.Itoa(int(c.Port)),
"user=" + user.Username, "user=" + user.Username,
"application_name=" + appName, "application_name=zitadel",
"sslmode=" + user.SSL.Mode, "sslmode=" + user.SSL.Mode,
} }
if c.Options != "" { if c.Options != "" {
@ -233,11 +219,3 @@ func (c Config) String(useAdmin bool, appName string) string {
return strings.Join(fields, " ") return strings.Join(fields, " ")
} }
func setAppNameWithID(ctx context.Context, conn *pgx.Conn, purpose dialect.DBPurpose, id string) bool {
// needs to be set like this because psql complains about parameters in the SET statement
query := fmt.Sprintf("SET application_name = '%s_%s'", purpose.AppName(), id)
_, err := conn.Exec(ctx, query)
logging.OnError(err).Warn("failed to set application name")
return err == nil
}

View File

@ -309,7 +309,7 @@ func prepareConditions(criteria querier, query *repository.SearchQuery, useV1 bo
} }
for i := range instanceIDs { for i := range instanceIDs {
instanceIDs[i] = dialect.DBPurposeEventPusher.AppName() + "_" + instanceIDs[i] instanceIDs[i] = "zitadel_es_pusher_" + instanceIDs[i]
} }
clauses += awaitOpenTransactions(useV1) clauses += awaitOpenTransactions(useV1)

View File

@ -4,9 +4,11 @@ import (
"context" "context"
"database/sql" "database/sql"
_ "embed" _ "embed"
"fmt"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
@ -55,6 +57,11 @@ func (es *Eventstore) writeCommands(ctx context.Context, client database.Context
}() }()
} }
_, err = tx.ExecContext(ctx, "SET LOCAL application_name = $1", fmt.Sprintf("zitadel_es_pusher_%s", authz.GetInstance(ctx).InstanceID()))
if err != nil {
return nil, err
}
events, err := writeEvents(ctx, tx, commands) events, err := writeEvents(ctx, tx, commands)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -0,0 +1,51 @@
package instance
import (
"context"
"time"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/pkg/grpc/object"
system_pb "github.com/zitadel/zitadel/pkg/grpc/system"
)
type AddInstanceRequest struct {
system_pb.AddInstanceRequest
ID string
CreatedAt time.Time
}
type AddInstanceResponse struct {
system_pb.AddInstanceResponse
}
func (bl *BusinessLogic) AddInstance(ctx context.Context, request *AddInstanceRequest) (_ *AddInstanceResponse, err error) {
tx, err := bl.client.Begin(ctx)
if err != nil {
return nil, err
}
defer func() {
err = tx.End(ctx, err)
}()
request.ID = time.Since(time.Time{}).String()
err = bl.storage.WriteInstanceAdded(ctx, tx, request)
if err != nil {
return nil, err
}
return &AddInstanceResponse{
AddInstanceResponse: system_pb.AddInstanceResponse{
InstanceId: request.ID,
Details: &object.ObjectDetails{
Sequence: 1,
CreationDate: timestamppb.New(request.CreatedAt),
ChangeDate: timestamppb.New(request.CreatedAt),
ResourceOwner: request.ID,
},
},
}, nil
}

View File

@ -0,0 +1,62 @@
package instance_test
import (
"context"
"log/slog"
"testing"
"github.com/zitadel/zitadel/internal/v3/instance"
repo_log "github.com/zitadel/zitadel/internal/v3/repository/log"
repo_mem "github.com/zitadel/zitadel/internal/v3/repository/memory"
"github.com/zitadel/zitadel/internal/v3/storage"
"github.com/zitadel/zitadel/internal/v3/storage/memory"
system_pb "github.com/zitadel/zitadel/pkg/grpc/system"
)
func TestBusinessLogic_AddInstance(t *testing.T) {
type fields struct {
client storage.Client
stores []instance.InstanceStorage
}
type args struct {
ctx context.Context
request *instance.AddInstanceRequest
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
name: "ok",
fields: fields{
client: &memory.Client{},
stores: []instance.InstanceStorage{
repo_mem.NewInstanceMemory(),
repo_log.NewInstanceLogger(slog.Default()),
},
},
args: args{
ctx: context.Background(),
request: &instance.AddInstanceRequest{
AddInstanceRequest: system_pb.AddInstanceRequest{
InstanceName: "test",
CustomDomain: "test",
DefaultLanguage: "en",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bl := instance.NewBusinessLogic(tt.fields.client, tt.fields.stores...)
_, err := bl.AddInstance(tt.args.ctx, tt.args.request)
if (err != nil) != tt.wantErr {
t.Errorf("BusinessLogic.AddInstance() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View File

@ -0,0 +1,34 @@
package instance
import (
"context"
"github.com/zitadel/zitadel/internal/v3/storage"
)
type InstanceStorage interface {
WriteInstanceAdded(ctx context.Context, tx storage.Transaction, instance *AddInstanceRequest) error
}
type BusinessLogic struct {
client storage.Client
storage InstanceStorage
}
func NewBusinessLogic(client storage.Client, stores ...InstanceStorage) *BusinessLogic {
return &BusinessLogic{
client: client,
storage: chainedStorage[InstanceStorage](stores),
}
}
type chainedStorage[S InstanceStorage] []S
func (cs chainedStorage[S]) WriteInstanceAdded(ctx context.Context, tx storage.Transaction, instance *AddInstanceRequest) error {
for _, store := range cs {
if err := store.WriteInstanceAdded(ctx, tx, instance); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,29 @@
// Logs the operation of an instance
package log
import (
"context"
"log/slog"
"github.com/zitadel/zitadel/internal/v3/instance"
"github.com/zitadel/zitadel/internal/v3/storage"
)
var _ instance.InstanceStorage = (*InstanceLogger)(nil)
type InstanceLogger struct {
*Logger
}
func NewInstanceLogger(logger *slog.Logger) *InstanceLogger {
return &InstanceLogger{Logger: NewLogger(logger)}
}
// WriteInstanceAdded implements instance.InstanceStorage.
func (l *InstanceLogger) WriteInstanceAdded(ctx context.Context, tx storage.Transaction, instance *instance.AddInstanceRequest) error {
tx.OnCommit(func(ctx context.Context) error {
l.InfoContext(ctx, "Instance added", slog.Any("instance", instance))
return nil
})
return nil
}

View File

@ -0,0 +1,14 @@
// Logs the operation of an instance
package log
import (
"log/slog"
)
type Logger struct {
*slog.Logger
}
func NewLogger(logger *slog.Logger) *Logger {
return &Logger{logger}
}

View File

@ -0,0 +1,71 @@
package memory
import (
"context"
"errors"
"sync"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/v3/instance"
"github.com/zitadel/zitadel/internal/v3/storage"
)
var _ instance.InstanceStorage = (*InstanceMemory)(nil)
type InstanceMemory struct {
instances map[string]*memoryInstance
mu sync.RWMutex
}
func NewInstanceMemory() *InstanceMemory {
return &InstanceMemory{
instances: make(map[string]*memoryInstance),
}
}
// WriteInstanceAdded implements instance.InstanceStorage.
func (i *InstanceMemory) WriteInstanceAdded(ctx context.Context, tx storage.Transaction, instance *instance.AddInstanceRequest) error {
defaultLanguage, err := language.Parse(instance.DefaultLanguage)
if err != nil {
return err
}
if instance.CreatedAt.IsZero() {
instance.CreatedAt = time.Now()
}
i.mu.Lock()
if i.instances[instance.ID] != nil {
return errors.New("instance already exists")
}
i.instances[instance.ID] = &memoryInstance{
id: instance.ID,
name: instance.InstanceName,
customDomain: instance.CustomDomain,
defaultLanguage: defaultLanguage,
}
tx.OnCommit(func(ctx context.Context) error {
i.mu.Unlock()
return nil
})
tx.OnRollback(func(ctx context.Context) error {
delete(i.instances, instance.ID)
i.mu.Unlock()
return nil
})
return nil
}
type memoryInstance struct {
id string
name string
customDomain string
defaultLanguage language.Tag
}

View File

@ -0,0 +1,64 @@
package memory
import (
"context"
"log/slog"
"github.com/zitadel/zitadel/internal/v3/storage"
)
var _ storage.Client = (*Client)(nil)
type Client struct{}
func (c *Client) Begin(ctx context.Context) (storage.Transaction, error) {
return new(Transaction), nil
}
var _ storage.Transaction = (*Transaction)(nil)
type Transaction struct {
commitHooks []func(ctx context.Context) error
rollbackHooks []func(ctx context.Context) error
}
// Commit implements storage.Transaction.
func (t *Transaction) Commit(ctx context.Context) error {
for _, hook := range t.commitHooks {
if err := hook(ctx); err != nil {
return err
}
}
return nil
}
// End implements storage.Transaction.
func (t *Transaction) End(ctx context.Context, err error) error {
if err != nil {
rollbackErr := t.Rollback(ctx)
slog.WarnContext(ctx, "Rollback failed", slog.Any("cause", rollbackErr))
return err
}
return t.Commit(ctx)
}
// OnCommit implements storage.Transaction.
func (t *Transaction) OnCommit(hook func(ctx context.Context) error) {
t.commitHooks = append(t.commitHooks, hook)
}
// OnRollback implements storage.Transaction.
func (t *Transaction) OnRollback(hook func(ctx context.Context) error) {
t.rollbackHooks = append(t.rollbackHooks, hook)
}
// Rollback implements storage.Transaction.
func (t *Transaction) Rollback(ctx context.Context) error {
for _, hook := range t.rollbackHooks {
if err := hook(ctx); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,31 @@
package storage
import "context"
type Client interface {
Begin(ctx context.Context) (Transaction, error)
}
// type Command interface {
// }
// type Query[R any] interface {
// Result() R
// }
// type Executor interface {
// Execute(ctx context.Context, command Command) error
// }
// type Querier interface {
// Query[R](ctx context.Context, query Query[R]) error
// }
type Transaction interface {
Commit(ctx context.Context) error
Rollback(ctx context.Context) error
// End the transaction based on err. If err is nil the transaction is committed, otherwise it is rolled back.
End(ctx context.Context, err error) error
OnCommit(hook func(ctx context.Context) error)
OnRollback(hook func(ctx context.Context) error)
}