fix(db): always use begin tx (#7142)

* fix(db): always use begin tx

* fix(handler): timeout for begin
This commit is contained in:
Silvan
2024-01-04 17:12:20 +01:00
committed by GitHub
parent c0cef4983a
commit b7d027e2fd
19 changed files with 59 additions and 56 deletions

View File

@@ -1,6 +1,7 @@
package initialise
import (
"context"
"embed"
"github.com/spf13/cobra"
@@ -48,7 +49,7 @@ The user provided by flags needs privileges to
Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.GetViper())
InitAll(config)
InitAll(cmd.Context(), config)
},
}
@@ -56,7 +57,7 @@ The user provided by flags needs privileges to
return cmd
}
func InitAll(config *Config) {
func InitAll(ctx context.Context, config *Config) {
err := initialise(config.Database,
VerifyUser(config.Database.Username(), config.Database.Password()),
VerifyDatabase(config.Database.DatabaseName()),
@@ -64,7 +65,7 @@ func InitAll(config *Config) {
)
logging.OnError(err).Fatal("unable to initialize the database")
err = verifyZitadel(config.Database)
err = verifyZitadel(ctx, config.Database)
logging.OnError(err).Fatal("unable to initialize ZITADEL")
}

View File

@@ -1,6 +1,7 @@
package initialise
import (
"context"
_ "embed"
"fmt"
@@ -23,13 +24,13 @@ Prereqesits:
`,
Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.GetViper())
err := verifyZitadel(config.Database)
err := verifyZitadel(cmd.Context(), config.Database)
logging.OnError(err).Fatal("unable to init zitadel")
},
}
}
func VerifyZitadel(db *database.DB, config database.Config) error {
func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config) error {
err := ReadStmts(config.Type())
if err != nil {
return err
@@ -41,7 +42,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error {
}
logging.WithFields().Info("verify encryption keys")
if err := createEncryptionKeys(db); err != nil {
if err := createEncryptionKeys(ctx, db); err != nil {
return err
}
@@ -56,7 +57,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error {
}
logging.WithFields().Info("verify events tables")
if err := createEvents(db); err != nil {
if err := createEvents(ctx, db); err != nil {
return err
}
@@ -73,7 +74,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error {
return nil
}
func verifyZitadel(config database.Config) error {
func verifyZitadel(ctx context.Context, config database.Config) error {
logging.WithFields("database", config.DatabaseName()).Info("verify zitadel")
db, err := database.Connect(config, false, dialect.DBPurposeQuery)
@@ -81,15 +82,15 @@ func verifyZitadel(config database.Config) error {
return err
}
if err := VerifyZitadel(db, config); err != nil {
if err := VerifyZitadel(ctx, db, config); err != nil {
return err
}
return db.Close()
}
func createEncryptionKeys(db *database.DB) error {
tx, err := db.Begin()
func createEncryptionKeys(ctx context.Context, db *database.DB) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
@@ -101,8 +102,8 @@ func createEncryptionKeys(db *database.DB) error {
return tx.Commit()
}
func createEvents(db *database.DB) (err error) {
tx, err := db.Begin()
func createEvents(ctx context.Context, db *database.DB) (err error) {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package initialise
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
@@ -107,7 +108,7 @@ func Test_verifyEvents(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := createEvents(tt.args.db.db); !errors.Is(err, tt.targetErr) {
if err := createEvents(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr)
}
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {
@@ -160,7 +161,7 @@ func Test_verifyEncryptionKeys(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := createEncryptionKeys(tt.args.db.db); !errors.Is(err, tt.targetErr) {
if err := createEncryptionKeys(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr)
}
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {

View File

@@ -74,7 +74,7 @@ new -f keys.yaml key2=anotherkey`,
if err != nil {
return err
}
return storage.CreateKeys(keys...)
return storage.CreateKeys(cmd.Context(), keys...)
},
}
cmd.PersistentFlags().StringP(flagKeyFile, "f", "", "path to keys file")

View File

@@ -46,7 +46,7 @@ func (mig *FirstInstance) Execute(ctx context.Context) error {
if err != nil {
return fmt.Errorf("cannot start key storage: %w", err)
}
if err = verifyKey(mig.userEncryptionKey, keyStorage); err != nil {
if err = verifyKey(ctx, mig.userEncryptionKey, keyStorage); err != nil {
return err
}
userAlg, err := crypto.NewAESCrypto(mig.userEncryptionKey, keyStorage)
@@ -54,7 +54,7 @@ func (mig *FirstInstance) Execute(ctx context.Context) error {
return err
}
if err = verifyKey(mig.smtpEncryptionKey, keyStorage); err != nil {
if err = verifyKey(ctx, mig.smtpEncryptionKey, keyStorage); err != nil {
return err
}
smtpEncryption, err := crypto.NewAESCrypto(mig.smtpEncryptionKey, keyStorage)
@@ -62,7 +62,7 @@ func (mig *FirstInstance) Execute(ctx context.Context) error {
return err
}
if err = verifyKey(mig.oidcEncryptionKey, keyStorage); err != nil {
if err = verifyKey(ctx, mig.oidcEncryptionKey, keyStorage); err != nil {
return err
}
oidcEncryption, err := crypto.NewAESCrypto(mig.oidcEncryptionKey, keyStorage)
@@ -167,7 +167,7 @@ func (mig *FirstInstance) String() string {
return "03_default_instance"
}
func verifyKey(key *crypto.KeyConfig, storage crypto.KeyStorage) (err error) {
func verifyKey(ctx context.Context, key *crypto.KeyConfig, storage crypto.KeyStorage) (err error) {
_, err = crypto.LoadKey(key.EncryptionKeyID, storage)
if err == nil {
return nil
@@ -176,5 +176,5 @@ func verifyKey(key *crypto.KeyConfig, storage crypto.KeyStorage) (err error) {
if err != nil {
return err
}
return storage.CreateKeys(k)
return storage.CreateKeys(ctx, k)
}

View File

@@ -1,6 +1,8 @@
package start
import (
"context"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -34,8 +36,8 @@ type encryptionKeys struct {
OIDCKey []byte
}
func ensureEncryptionKeys(keyConfig *encryptionKeyConfig, keyStorage crypto.KeyStorage) (keys *encryptionKeys, err error) {
if err := verifyDefaultKeys(keyStorage); err != nil {
func ensureEncryptionKeys(ctx context.Context, keyConfig *encryptionKeyConfig, keyStorage crypto.KeyStorage) (keys *encryptionKeys, err error) {
if err := verifyDefaultKeys(ctx, keyStorage); err != nil {
return nil, err
}
keys = new(encryptionKeys)
@@ -89,7 +91,7 @@ func ensureEncryptionKeys(keyConfig *encryptionKeyConfig, keyStorage crypto.KeyS
return keys, nil
}
func verifyDefaultKeys(keyStorage crypto.KeyStorage) (err error) {
func verifyDefaultKeys(ctx context.Context, keyStorage crypto.KeyStorage) (err error) {
keys := make([]*crypto.Key, 0, len(defaultKeyIDs))
for _, keyID := range defaultKeyIDs {
_, err := crypto.LoadKey(keyID, keyStorage)
@@ -105,7 +107,7 @@ func verifyDefaultKeys(keyStorage crypto.KeyStorage) (err error) {
if len(keys) == 0 {
return nil
}
if err := keyStorage.CreateKeys(keys...); err != nil {
if err := keyStorage.CreateKeys(ctx, keys...); err != nil {
return zerrors.ThrowInternal(err, "START-aGBq2", "cannot create default keys")
}
return nil

View File

@@ -95,7 +95,7 @@ Requirements:
if err != nil {
return err
}
return startZitadel(config, masterKey, server)
return startZitadel(cmd.Context(), config, masterKey, server)
},
}
@@ -119,11 +119,9 @@ type Server struct {
Shutdown chan<- os.Signal
}
func startZitadel(config *Config, masterKey string, server chan<- *Server) error {
func startZitadel(ctx context.Context, config *Config, masterKey string, server chan<- *Server) error {
showBasicInformation(config)
ctx := context.Background()
i18n.MustLoadSupportedLanguagesFromDir()
queryDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeQuery)
@@ -143,7 +141,7 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
if err != nil {
return fmt.Errorf("cannot start key storage: %w", err)
}
keys, err := ensureEncryptionKeys(config.EncryptionKeys, keyStorage)
keys, err := ensureEncryptionKeys(ctx, config.EncryptionKeys, keyStorage)
if err != nil {
return err
}

View File

@@ -29,7 +29,7 @@ Requirements:
masterKey, err := key.MasterKey(cmd)
logging.OnError(err).Panic("No master key provided")
initialise.InitAll(initialise.MustNewConfig(viper.GetViper()))
initialise.InitAll(cmd.Context(), initialise.MustNewConfig(viper.GetViper()))
setupConfig := setup.MustNewConfig(viper.GetViper())
setupSteps := setup.MustNewSteps(viper.New())
@@ -37,7 +37,7 @@ Requirements:
startConfig := MustNewConfig(viper.GetViper())
err = startZitadel(startConfig, masterKey, server)
err = startZitadel(cmd.Context(), startConfig, masterKey, server)
logging.OnError(err).Fatal("unable to start zitadel")
},
}

View File

@@ -35,7 +35,7 @@ Requirements:
startConfig := MustNewConfig(viper.GetViper())
err = startZitadel(startConfig, masterKey, server)
err = startZitadel(cmd.Context(), startConfig, masterKey, server)
logging.OnError(err).Fatal("unable to start zitadel")
},
}