fix(db): always use begin tx (#7142)

* fix(db): always use begin tx

* fix(handler): timeout for begin
This commit is contained in:
Silvan
2024-01-04 17:12:20 +01:00
committed by GitHub
parent c0cef4983a
commit b7d027e2fd
19 changed files with 59 additions and 56 deletions

View File

@@ -1,17 +1,13 @@
package view
import (
"context"
"github.com/jinzhu/gorm"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
)
type View struct {
Db *gorm.DB
client *database.DB
Db *gorm.DB
}
func StartView(sqlClient *database.DB) (*View, error) {
@@ -20,15 +16,10 @@ func StartView(sqlClient *database.DB) (*View, error) {
return nil, err
}
return &View{
Db: gorm,
client: sqlClient,
Db: gorm,
}, nil
}
func (v *View) Health() (err error) {
return v.Db.DB().Ping()
}
func (v *View) TimeTravel(ctx context.Context, tableName string) string {
return tableName + v.client.Timetravel(call.Took(ctx))
}

View File

@@ -1,6 +1,7 @@
package database
import (
"context"
"database/sql"
sq "github.com/Masterminds/squirrel"
@@ -98,7 +99,7 @@ func (d *database) ReadKey(id string) (_ *crypto.Key, err error) {
}, nil
}
func (d *database) CreateKeys(keys ...*crypto.Key) error {
func (d *database) CreateKeys(ctx context.Context, keys ...*crypto.Key) error {
insert := sq.Insert(EncryptionKeysTable).
Columns(encryptionKeysIDCol, encryptionKeysKeyCol).PlaceholderFormat(sq.Dollar)
for _, key := range keys {
@@ -112,7 +113,7 @@ func (d *database) CreateKeys(keys ...*crypto.Key) error {
if err != nil {
return zerrors.ThrowInternal(err, "", "unable to insert new keys")
}
tx, err := d.client.Begin()
tx, err := d.client.BeginTx(ctx, nil)
if err != nil {
return zerrors.ThrowInternal(err, "", "unable to insert new keys")
}

View File

@@ -1,6 +1,7 @@
package database
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
@@ -395,7 +396,7 @@ func Test_database_CreateKeys(t *testing.T) {
masterKey: tt.fields.masterKey,
encrypt: tt.fields.encrypt,
}
err := d.CreateKeys(tt.args.keys...)
err := d.CreateKeys(context.Background(), tt.args.keys...)
if tt.res.err == nil {
assert.NoError(t, err)
} else if tt.res.err != nil && !tt.res.err(err) {

View File

@@ -1,7 +1,9 @@
package crypto
import "context"
type KeyStorage interface {
ReadKeys() (Keys, error)
ReadKey(id string) (*Key, error)
CreateKeys(...*Key) error
CreateKeys(context.Context, ...*Key) error
}

View File

@@ -316,13 +316,17 @@ func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (add
}
}()
txCtx := ctx
if h.txDuration > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, h.txDuration)
defer cancel()
// add 100ms to store current state if iteration takes too long
txCtx, cancel = context.WithTimeout(ctx, h.txDuration+100*time.Millisecond)
defer cancel()
}
tx, err := h.client.Begin()
tx, err := h.client.BeginTx(txCtx, nil)
if err != nil {
return false, err
}

View File

@@ -121,7 +121,7 @@ func TestHandler_lockState(t *testing.T) {
projection: tt.fields.projection,
}
tx, err := tt.fields.mock.DB.Begin()
tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}
@@ -244,7 +244,7 @@ func TestHandler_updateLastUpdated(t *testing.T) {
}
}
t.Run(tt.name, func(t *testing.T) {
tx, err := tt.fields.mock.DB.Begin()
tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}
@@ -433,7 +433,7 @@ func TestHandler_currentState(t *testing.T) {
projection: tt.fields.projection,
}
tx, err := tt.fields.mock.DB.Begin()
tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}

View File

@@ -88,7 +88,7 @@ func initDB(db *database.DB) error {
if err != nil {
return err
}
err = initialise.VerifyZitadel(db, *config)
err = initialise.VerifyZitadel(context.Background(), db, *config)
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package sql
import (
"context"
"database/sql"
"os"
"testing"
@@ -60,7 +61,7 @@ func initDB(db *database.DB) error {
return err
}
err = initialise.VerifyZitadel(db, *config)
err = initialise.VerifyZitadel(context.Background(), db, *config)
if err != nil {
return err
}

View File

@@ -19,7 +19,7 @@ import (
)
func (es *Eventstore) Push(ctx context.Context, commands ...eventstore.Command) (events []eventstore.Event, err error) {
tx, err := es.client.Begin()
tx, err := es.client.BeginTx(ctx, nil)
if err != nil {
return nil, err
}

View File

@@ -109,7 +109,7 @@ func (q *Queries) latestState(ctx context.Context, projections ...table) (state
}
func (q *Queries) ClearCurrentSequence(ctx context.Context, projectionName string) (err error) {
tx, err := q.client.Begin()
tx, err := q.client.BeginTx(ctx, nil)
if err != nil {
return zerrors.ThrowInternal(err, "QUERY-9iOpr", "Errors.RemoveFailed")
}