diff --git a/cmd/initialise/init.go b/cmd/initialise/init.go index 0af04925c9..28f123f851 100644 --- a/cmd/initialise/init.go +++ b/cmd/initialise/init.go @@ -1,6 +1,7 @@ package initialise import ( + "context" "embed" "github.com/spf13/cobra" @@ -48,7 +49,7 @@ The user provided by flags needs privileges to Run: func(cmd *cobra.Command, args []string) { config := MustNewConfig(viper.GetViper()) - InitAll(config) + InitAll(cmd.Context(), config) }, } @@ -56,7 +57,7 @@ The user provided by flags needs privileges to return cmd } -func InitAll(config *Config) { +func InitAll(ctx context.Context, config *Config) { err := initialise(config.Database, VerifyUser(config.Database.Username(), config.Database.Password()), VerifyDatabase(config.Database.DatabaseName()), @@ -64,7 +65,7 @@ func InitAll(config *Config) { ) logging.OnError(err).Fatal("unable to initialize the database") - err = verifyZitadel(config.Database) + err = verifyZitadel(ctx, config.Database) logging.OnError(err).Fatal("unable to initialize ZITADEL") } diff --git a/cmd/initialise/verify_zitadel.go b/cmd/initialise/verify_zitadel.go index 7307ad94d9..c0034141db 100644 --- a/cmd/initialise/verify_zitadel.go +++ b/cmd/initialise/verify_zitadel.go @@ -1,6 +1,7 @@ package initialise import ( + "context" _ "embed" "fmt" @@ -23,13 +24,13 @@ Prereqesits: `, Run: func(cmd *cobra.Command, args []string) { config := MustNewConfig(viper.GetViper()) - err := verifyZitadel(config.Database) + err := verifyZitadel(cmd.Context(), config.Database) logging.OnError(err).Fatal("unable to init zitadel") }, } } -func VerifyZitadel(db *database.DB, config database.Config) error { +func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config) error { err := ReadStmts(config.Type()) if err != nil { return err @@ -41,7 +42,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error { } logging.WithFields().Info("verify encryption keys") - if err := createEncryptionKeys(db); err != nil { + if err := createEncryptionKeys(ctx, db); err != nil { return err } @@ -56,7 +57,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error { } logging.WithFields().Info("verify events tables") - if err := createEvents(db); err != nil { + if err := createEvents(ctx, db); err != nil { return err } @@ -73,7 +74,7 @@ func VerifyZitadel(db *database.DB, config database.Config) error { return nil } -func verifyZitadel(config database.Config) error { +func verifyZitadel(ctx context.Context, config database.Config) error { logging.WithFields("database", config.DatabaseName()).Info("verify zitadel") db, err := database.Connect(config, false, dialect.DBPurposeQuery) @@ -81,15 +82,15 @@ func verifyZitadel(config database.Config) error { return err } - if err := VerifyZitadel(db, config); err != nil { + if err := VerifyZitadel(ctx, db, config); err != nil { return err } return db.Close() } -func createEncryptionKeys(db *database.DB) error { - tx, err := db.Begin() +func createEncryptionKeys(ctx context.Context, db *database.DB) error { + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -101,8 +102,8 @@ func createEncryptionKeys(db *database.DB) error { return tx.Commit() } -func createEvents(db *database.DB) (err error) { - tx, err := db.Begin() +func createEvents(ctx context.Context, db *database.DB) (err error) { + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } diff --git a/cmd/initialise/verify_zitadel_test.go b/cmd/initialise/verify_zitadel_test.go index 4631396e29..64df01bdb1 100644 --- a/cmd/initialise/verify_zitadel_test.go +++ b/cmd/initialise/verify_zitadel_test.go @@ -1,6 +1,7 @@ package initialise import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -107,7 +108,7 @@ func Test_verifyEvents(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := createEvents(tt.args.db.db); !errors.Is(err, tt.targetErr) { + if err := createEvents(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) { t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr) } if err := tt.args.db.mock.ExpectationsWereMet(); err != nil { @@ -160,7 +161,7 @@ func Test_verifyEncryptionKeys(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := createEncryptionKeys(tt.args.db.db); !errors.Is(err, tt.targetErr) { + if err := createEncryptionKeys(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) { t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr) } if err := tt.args.db.mock.ExpectationsWereMet(); err != nil { diff --git a/cmd/key/key.go b/cmd/key/key.go index b18bb867e8..2691932784 100644 --- a/cmd/key/key.go +++ b/cmd/key/key.go @@ -74,7 +74,7 @@ new -f keys.yaml key2=anotherkey`, if err != nil { return err } - return storage.CreateKeys(keys...) + return storage.CreateKeys(cmd.Context(), keys...) }, } cmd.PersistentFlags().StringP(flagKeyFile, "f", "", "path to keys file") diff --git a/cmd/setup/03.go b/cmd/setup/03.go index 4a93888f6b..89978497ec 100644 --- a/cmd/setup/03.go +++ b/cmd/setup/03.go @@ -46,7 +46,7 @@ func (mig *FirstInstance) Execute(ctx context.Context) error { if err != nil { return fmt.Errorf("cannot start key storage: %w", err) } - if err = verifyKey(mig.userEncryptionKey, keyStorage); err != nil { + if err = verifyKey(ctx, mig.userEncryptionKey, keyStorage); err != nil { return err } userAlg, err := crypto.NewAESCrypto(mig.userEncryptionKey, keyStorage) @@ -54,7 +54,7 @@ func (mig *FirstInstance) Execute(ctx context.Context) error { return err } - if err = verifyKey(mig.smtpEncryptionKey, keyStorage); err != nil { + if err = verifyKey(ctx, mig.smtpEncryptionKey, keyStorage); err != nil { return err } smtpEncryption, err := crypto.NewAESCrypto(mig.smtpEncryptionKey, keyStorage) @@ -62,7 +62,7 @@ func (mig *FirstInstance) Execute(ctx context.Context) error { return err } - if err = verifyKey(mig.oidcEncryptionKey, keyStorage); err != nil { + if err = verifyKey(ctx, mig.oidcEncryptionKey, keyStorage); err != nil { return err } oidcEncryption, err := crypto.NewAESCrypto(mig.oidcEncryptionKey, keyStorage) @@ -167,7 +167,7 @@ func (mig *FirstInstance) String() string { return "03_default_instance" } -func verifyKey(key *crypto.KeyConfig, storage crypto.KeyStorage) (err error) { +func verifyKey(ctx context.Context, key *crypto.KeyConfig, storage crypto.KeyStorage) (err error) { _, err = crypto.LoadKey(key.EncryptionKeyID, storage) if err == nil { return nil @@ -176,5 +176,5 @@ func verifyKey(key *crypto.KeyConfig, storage crypto.KeyStorage) (err error) { if err != nil { return err } - return storage.CreateKeys(k) + return storage.CreateKeys(ctx, k) } diff --git a/cmd/start/encryption_keys.go b/cmd/start/encryption_keys.go index b5943bf40b..42727b095d 100644 --- a/cmd/start/encryption_keys.go +++ b/cmd/start/encryption_keys.go @@ -1,6 +1,8 @@ package start import ( + "context" + "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -34,8 +36,8 @@ type encryptionKeys struct { OIDCKey []byte } -func ensureEncryptionKeys(keyConfig *encryptionKeyConfig, keyStorage crypto.KeyStorage) (keys *encryptionKeys, err error) { - if err := verifyDefaultKeys(keyStorage); err != nil { +func ensureEncryptionKeys(ctx context.Context, keyConfig *encryptionKeyConfig, keyStorage crypto.KeyStorage) (keys *encryptionKeys, err error) { + if err := verifyDefaultKeys(ctx, keyStorage); err != nil { return nil, err } keys = new(encryptionKeys) @@ -89,7 +91,7 @@ func ensureEncryptionKeys(keyConfig *encryptionKeyConfig, keyStorage crypto.KeyS return keys, nil } -func verifyDefaultKeys(keyStorage crypto.KeyStorage) (err error) { +func verifyDefaultKeys(ctx context.Context, keyStorage crypto.KeyStorage) (err error) { keys := make([]*crypto.Key, 0, len(defaultKeyIDs)) for _, keyID := range defaultKeyIDs { _, err := crypto.LoadKey(keyID, keyStorage) @@ -105,7 +107,7 @@ func verifyDefaultKeys(keyStorage crypto.KeyStorage) (err error) { if len(keys) == 0 { return nil } - if err := keyStorage.CreateKeys(keys...); err != nil { + if err := keyStorage.CreateKeys(ctx, keys...); err != nil { return zerrors.ThrowInternal(err, "START-aGBq2", "cannot create default keys") } return nil diff --git a/cmd/start/start.go b/cmd/start/start.go index f367d923fd..d75a3bb3d3 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -95,7 +95,7 @@ Requirements: if err != nil { return err } - return startZitadel(config, masterKey, server) + return startZitadel(cmd.Context(), config, masterKey, server) }, } @@ -119,11 +119,9 @@ type Server struct { Shutdown chan<- os.Signal } -func startZitadel(config *Config, masterKey string, server chan<- *Server) error { +func startZitadel(ctx context.Context, config *Config, masterKey string, server chan<- *Server) error { showBasicInformation(config) - ctx := context.Background() - i18n.MustLoadSupportedLanguagesFromDir() queryDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeQuery) @@ -143,7 +141,7 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error if err != nil { return fmt.Errorf("cannot start key storage: %w", err) } - keys, err := ensureEncryptionKeys(config.EncryptionKeys, keyStorage) + keys, err := ensureEncryptionKeys(ctx, config.EncryptionKeys, keyStorage) if err != nil { return err } diff --git a/cmd/start/start_from_init.go b/cmd/start/start_from_init.go index 940efb4e84..d07ff89cf4 100644 --- a/cmd/start/start_from_init.go +++ b/cmd/start/start_from_init.go @@ -29,7 +29,7 @@ Requirements: masterKey, err := key.MasterKey(cmd) logging.OnError(err).Panic("No master key provided") - initialise.InitAll(initialise.MustNewConfig(viper.GetViper())) + initialise.InitAll(cmd.Context(), initialise.MustNewConfig(viper.GetViper())) setupConfig := setup.MustNewConfig(viper.GetViper()) setupSteps := setup.MustNewSteps(viper.New()) @@ -37,7 +37,7 @@ Requirements: startConfig := MustNewConfig(viper.GetViper()) - err = startZitadel(startConfig, masterKey, server) + err = startZitadel(cmd.Context(), startConfig, masterKey, server) logging.OnError(err).Fatal("unable to start zitadel") }, } diff --git a/cmd/start/start_from_setup.go b/cmd/start/start_from_setup.go index 0be315fae9..ea1bd42a05 100644 --- a/cmd/start/start_from_setup.go +++ b/cmd/start/start_from_setup.go @@ -35,7 +35,7 @@ Requirements: startConfig := MustNewConfig(viper.GetViper()) - err = startZitadel(startConfig, masterKey, server) + err = startZitadel(cmd.Context(), startConfig, masterKey, server) logging.OnError(err).Fatal("unable to start zitadel") }, } diff --git a/internal/admin/repository/eventsourcing/view/view.go b/internal/admin/repository/eventsourcing/view/view.go index 9ede972813..2f26b8c81d 100644 --- a/internal/admin/repository/eventsourcing/view/view.go +++ b/internal/admin/repository/eventsourcing/view/view.go @@ -1,17 +1,13 @@ package view import ( - "context" - "github.com/jinzhu/gorm" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" ) type View struct { - Db *gorm.DB - client *database.DB + Db *gorm.DB } func StartView(sqlClient *database.DB) (*View, error) { @@ -20,15 +16,10 @@ func StartView(sqlClient *database.DB) (*View, error) { return nil, err } return &View{ - Db: gorm, - client: sqlClient, + Db: gorm, }, nil } func (v *View) Health() (err error) { return v.Db.DB().Ping() } - -func (v *View) TimeTravel(ctx context.Context, tableName string) string { - return tableName + v.client.Timetravel(call.Took(ctx)) -} diff --git a/internal/crypto/database/database.go b/internal/crypto/database/database.go index 8fac222d9d..2d8666fd07 100644 --- a/internal/crypto/database/database.go +++ b/internal/crypto/database/database.go @@ -1,6 +1,7 @@ package database import ( + "context" "database/sql" sq "github.com/Masterminds/squirrel" @@ -98,7 +99,7 @@ func (d *database) ReadKey(id string) (_ *crypto.Key, err error) { }, nil } -func (d *database) CreateKeys(keys ...*crypto.Key) error { +func (d *database) CreateKeys(ctx context.Context, keys ...*crypto.Key) error { insert := sq.Insert(EncryptionKeysTable). Columns(encryptionKeysIDCol, encryptionKeysKeyCol).PlaceholderFormat(sq.Dollar) for _, key := range keys { @@ -112,7 +113,7 @@ func (d *database) CreateKeys(keys ...*crypto.Key) error { if err != nil { return zerrors.ThrowInternal(err, "", "unable to insert new keys") } - tx, err := d.client.Begin() + tx, err := d.client.BeginTx(ctx, nil) if err != nil { return zerrors.ThrowInternal(err, "", "unable to insert new keys") } diff --git a/internal/crypto/database/database_test.go b/internal/crypto/database/database_test.go index f7c8313355..80d39e0586 100644 --- a/internal/crypto/database/database_test.go +++ b/internal/crypto/database/database_test.go @@ -1,6 +1,7 @@ package database import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -395,7 +396,7 @@ func Test_database_CreateKeys(t *testing.T) { masterKey: tt.fields.masterKey, encrypt: tt.fields.encrypt, } - err := d.CreateKeys(tt.args.keys...) + err := d.CreateKeys(context.Background(), tt.args.keys...) if tt.res.err == nil { assert.NoError(t, err) } else if tt.res.err != nil && !tt.res.err(err) { diff --git a/internal/crypto/key_storage.go b/internal/crypto/key_storage.go index 523818a0c0..df169c5898 100644 --- a/internal/crypto/key_storage.go +++ b/internal/crypto/key_storage.go @@ -1,7 +1,9 @@ package crypto +import "context" + type KeyStorage interface { ReadKeys() (Keys, error) ReadKey(id string) (*Key, error) - CreateKeys(...*Key) error + CreateKeys(context.Context, ...*Key) error } diff --git a/internal/eventstore/handler/v2/handler.go b/internal/eventstore/handler/v2/handler.go index 0bc2020efa..3c457aee28 100644 --- a/internal/eventstore/handler/v2/handler.go +++ b/internal/eventstore/handler/v2/handler.go @@ -316,13 +316,17 @@ func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (add } }() + txCtx := ctx if h.txDuration > 0 { var cancel func() ctx, cancel = context.WithTimeout(ctx, h.txDuration) defer cancel() + // add 100ms to store current state if iteration takes too long + txCtx, cancel = context.WithTimeout(ctx, h.txDuration+100*time.Millisecond) + defer cancel() } - tx, err := h.client.Begin() + tx, err := h.client.BeginTx(txCtx, nil) if err != nil { return false, err } diff --git a/internal/eventstore/handler/v2/state_test.go b/internal/eventstore/handler/v2/state_test.go index 5e20773947..908d8c0dfc 100644 --- a/internal/eventstore/handler/v2/state_test.go +++ b/internal/eventstore/handler/v2/state_test.go @@ -121,7 +121,7 @@ func TestHandler_lockState(t *testing.T) { projection: tt.fields.projection, } - tx, err := tt.fields.mock.DB.Begin() + tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil) if err != nil { t.Fatalf("unable to begin transaction: %v", err) } @@ -244,7 +244,7 @@ func TestHandler_updateLastUpdated(t *testing.T) { } } t.Run(tt.name, func(t *testing.T) { - tx, err := tt.fields.mock.DB.Begin() + tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil) if err != nil { t.Fatalf("unable to begin transaction: %v", err) } @@ -433,7 +433,7 @@ func TestHandler_currentState(t *testing.T) { projection: tt.fields.projection, } - tx, err := tt.fields.mock.DB.Begin() + tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil) if err != nil { t.Fatalf("unable to begin transaction: %v", err) } diff --git a/internal/eventstore/local_crdb_test.go b/internal/eventstore/local_crdb_test.go index 81ca937287..d77f6a3dd7 100644 --- a/internal/eventstore/local_crdb_test.go +++ b/internal/eventstore/local_crdb_test.go @@ -88,7 +88,7 @@ func initDB(db *database.DB) error { if err != nil { return err } - err = initialise.VerifyZitadel(db, *config) + err = initialise.VerifyZitadel(context.Background(), db, *config) if err != nil { return err } diff --git a/internal/eventstore/repository/sql/local_crdb_test.go b/internal/eventstore/repository/sql/local_crdb_test.go index b1aa485887..3eede324a4 100644 --- a/internal/eventstore/repository/sql/local_crdb_test.go +++ b/internal/eventstore/repository/sql/local_crdb_test.go @@ -1,6 +1,7 @@ package sql import ( + "context" "database/sql" "os" "testing" @@ -60,7 +61,7 @@ func initDB(db *database.DB) error { return err } - err = initialise.VerifyZitadel(db, *config) + err = initialise.VerifyZitadel(context.Background(), db, *config) if err != nil { return err } diff --git a/internal/eventstore/v3/push.go b/internal/eventstore/v3/push.go index a86e426132..3c7ae903dc 100644 --- a/internal/eventstore/v3/push.go +++ b/internal/eventstore/v3/push.go @@ -19,7 +19,7 @@ import ( ) func (es *Eventstore) Push(ctx context.Context, commands ...eventstore.Command) (events []eventstore.Event, err error) { - tx, err := es.client.Begin() + tx, err := es.client.BeginTx(ctx, nil) if err != nil { return nil, err } diff --git a/internal/query/current_state.go b/internal/query/current_state.go index 73c8eaeb9a..eff2f05398 100644 --- a/internal/query/current_state.go +++ b/internal/query/current_state.go @@ -109,7 +109,7 @@ func (q *Queries) latestState(ctx context.Context, projections ...table) (state } func (q *Queries) ClearCurrentSequence(ctx context.Context, projectionName string) (err error) { - tx, err := q.client.Begin() + tx, err := q.client.BeginTx(ctx, nil) if err != nil { return zerrors.ThrowInternal(err, "QUERY-9iOpr", "Errors.RemoveFailed") }