mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 17:57:33 +00:00
fix(db): always use begin tx (#7142)
* fix(db): always use begin tx * fix(handler): timeout for begin
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
@@ -48,7 +49,7 @@ The user provided by flags needs privileges to
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
config := MustNewConfig(viper.GetViper())
|
||||
|
||||
InitAll(config)
|
||||
InitAll(cmd.Context(), config)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -56,7 +57,7 @@ The user provided by flags needs privileges to
|
||||
return cmd
|
||||
}
|
||||
|
||||
func InitAll(config *Config) {
|
||||
func InitAll(ctx context.Context, config *Config) {
|
||||
err := initialise(config.Database,
|
||||
VerifyUser(config.Database.Username(), config.Database.Password()),
|
||||
VerifyDatabase(config.Database.DatabaseName()),
|
||||
@@ -64,7 +65,7 @@ func InitAll(config *Config) {
|
||||
)
|
||||
logging.OnError(err).Fatal("unable to initialize the database")
|
||||
|
||||
err = verifyZitadel(config.Database)
|
||||
err = verifyZitadel(ctx, config.Database)
|
||||
logging.OnError(err).Fatal("unable to initialize ZITADEL")
|
||||
}
|
||||
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
@@ -23,13 +24,13 @@ Prereqesits:
|
||||
`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
config := MustNewConfig(viper.GetViper())
|
||||
err := verifyZitadel(config.Database)
|
||||
err := verifyZitadel(cmd.Context(), config.Database)
|
||||
logging.OnError(err).Fatal("unable to init zitadel")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func VerifyZitadel(db *database.DB, config database.Config) error {
|
||||
func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config) error {
|
||||
err := ReadStmts(config.Type())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -41,7 +42,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error {
|
||||
}
|
||||
|
||||
logging.WithFields().Info("verify encryption keys")
|
||||
if err := createEncryptionKeys(db); err != nil {
|
||||
if err := createEncryptionKeys(ctx, db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -56,7 +57,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error {
|
||||
}
|
||||
|
||||
logging.WithFields().Info("verify events tables")
|
||||
if err := createEvents(db); err != nil {
|
||||
if err := createEvents(ctx, db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -73,7 +74,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyZitadel(config database.Config) error {
|
||||
func verifyZitadel(ctx context.Context, config database.Config) error {
|
||||
logging.WithFields("database", config.DatabaseName()).Info("verify zitadel")
|
||||
|
||||
db, err := database.Connect(config, false, dialect.DBPurposeQuery)
|
||||
@@ -81,15 +82,15 @@ func verifyZitadel(config database.Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := VerifyZitadel(db, config); err != nil {
|
||||
if err := VerifyZitadel(ctx, db, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
}
|
||||
|
||||
func createEncryptionKeys(db *database.DB) error {
|
||||
tx, err := db.Begin()
|
||||
func createEncryptionKeys(ctx context.Context, db *database.DB) error {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -101,8 +102,8 @@ func createEncryptionKeys(db *database.DB) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func createEvents(db *database.DB) (err error) {
|
||||
tx, err := db.Begin()
|
||||
func createEvents(ctx context.Context, db *database.DB) (err error) {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
@@ -107,7 +108,7 @@ func Test_verifyEvents(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := createEvents(tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
if err := createEvents(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr)
|
||||
}
|
||||
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {
|
||||
@@ -160,7 +161,7 @@ func Test_verifyEncryptionKeys(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := createEncryptionKeys(tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
if err := createEncryptionKeys(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr)
|
||||
}
|
||||
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {
|
||||
|
Reference in New Issue
Block a user