fix(handler): optimise snapshot hanlding (#8652)

# Which Problems Are Solved

There are cases where not all statements of multiExec are succeed. This
leads to inconsistent states. One example is [LDAP
IDPs](https://github.com/zitadel/zitadel/issues/7959).

If statements get executed only partially this can lead to inconsistent
states or even break projections for objects which might not were
correctly created in a sub table.

This behaviour is possible because we use
[`SAVEPOINTS`](https://www.postgresql.org/docs/current/sql-savepoint.html)
during each statement of a multiExec.

# How the Problems Are Solved

SAVEPOINTS are only created at the beginning of an exec function not
during every execution like before. Additionally `RELEASE` or `ROLLBACK`
of `SAVEPOINTS` are only used when needed.

# Additional Changes

- refactor some unused parameters

# Additional Context

- closes https://github.com/zitadel/zitadel/issues/7959

(cherry picked from commit ddeeeed30375a888b314c5a5bc9c2182d33916c9)
This commit is contained in:
Silvan 2024-10-02 17:34:19 +02:00 committed by Livio Spring
parent c05bc0f6cf
commit 250c61529b
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
4 changed files with 26 additions and 30 deletions

View File

@ -513,7 +513,7 @@ func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (add
return additionalIteration, err return additionalIteration, err
} }
lastProcessedIndex, err := h.executeStatements(ctx, tx, currentState, statements) lastProcessedIndex, err := h.executeStatements(ctx, tx, statements)
h.log().OnError(err).WithField("lastProcessedIndex", lastProcessedIndex).Debug("execution of statements failed") h.log().OnError(err).WithField("lastProcessedIndex", lastProcessedIndex).Debug("execution of statements failed")
if lastProcessedIndex < 0 { if lastProcessedIndex < 0 {
return false, err return false, err
@ -585,7 +585,7 @@ func skipPreviouslyReduced(statements []*Statement, currentState *state) int {
return -1 return -1
} }
func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, currentState *state, statements []*Statement) (lastProcessedIndex int, err error) { func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, statements []*Statement) (lastProcessedIndex int, err error) {
lastProcessedIndex = -1 lastProcessedIndex = -1
for i, statement := range statements { for i, statement := range statements {
@ -593,7 +593,7 @@ func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, currentStat
case <-ctx.Done(): case <-ctx.Done():
break break
default: default:
err := h.executeStatement(ctx, tx, currentState, statement) err := h.executeStatement(ctx, tx, statement)
if err != nil { if err != nil {
return lastProcessedIndex, err return lastProcessedIndex, err
} }
@ -603,28 +603,24 @@ func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, currentStat
return lastProcessedIndex, nil return lastProcessedIndex, nil
} }
func (h *Handler) executeStatement(ctx context.Context, tx *sql.Tx, currentState *state, statement *Statement) (err error) { func (h *Handler) executeStatement(ctx context.Context, tx *sql.Tx, statement *Statement) (err error) {
if statement.Execute == nil { if statement.Execute == nil {
return nil return nil
} }
_, err = tx.Exec("SAVEPOINT exec") _, err = tx.ExecContext(ctx, "SAVEPOINT exec_stmt")
if err != nil { if err != nil {
h.log().WithError(err).Debug("create savepoint failed") h.log().WithError(err).Debug("create savepoint failed")
return err return err
} }
var shouldContinue bool
defer func() {
_, errSave := tx.Exec("RELEASE SAVEPOINT exec")
if err == nil {
err = errSave
}
}()
if err = statement.Execute(tx, h.projection.Name()); err != nil { if err = statement.Execute(tx, h.projection.Name()); err != nil {
h.log().WithError(err).Error("statement execution failed") h.log().WithError(err).Error("statement execution failed")
shouldContinue = h.handleFailedStmt(tx, failureFromStatement(statement, err)) _, rollbackErr := tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT exec_stmt")
h.log().OnError(rollbackErr).Error("rollback to savepoint failed")
shouldContinue := h.handleFailedStmt(tx, failureFromStatement(statement, err))
if shouldContinue { if shouldContinue {
return nil return nil
} }

View File

@ -264,11 +264,23 @@ func NewViewCheck(selectStmt string, secondaryTables ...*SuffixedTable) *handler
} }
func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(handler.Executer, string) (bool, error) { func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(handler.Executer, string) (bool, error) {
return func(handler handler.Executer, name string) (bool, error) { return func(handler handler.Executer, name string) (shouldExecuteNext bool, err error) {
err := exec(config, q, opts)(handler, name) _, err = handler.Exec("SAVEPOINT exec_stmt")
if isErrAlreadyExists(err) { if err != nil {
return executeNext, nil return false, zerrors.ThrowInternal(err, "V2-U1wlz", "create savepoint failed")
} }
defer func() {
if err == nil {
return
}
if isErrAlreadyExists(err) {
_, err = handler.Exec("ROLLBACK TO SAVEPOINT exec_stmt")
shouldExecuteNext = executeNext
return
}
}()
err = exec(config, q, opts)(handler, name)
return false, err return false, err
} }
} }

View File

@ -601,18 +601,6 @@ func exec(config execConfig, q query, opts []execOption) Exec {
opt(&config) opt(&config)
} }
_, err = ex.Exec("SAVEPOINT stmt_exec")
if err != nil {
return zerrors.ThrowInternal(err, "CRDB-YdOXD", "create savepoint failed")
}
defer func() {
if err != nil {
_, rollbackErr := ex.Exec("ROLLBACK TO SAVEPOINT stmt_exec")
logging.OnError(rollbackErr).Debug("rollback failed")
return
}
_, err = ex.Exec("RELEASE SAVEPOINT stmt_exec")
}()
_, err = ex.Exec(q(config), config.args...) _, err = ex.Exec(q(config), config.args...)
if err != nil { if err != nil {
return zerrors.ThrowInternal(err, "CRDB-pKtsr", "exec failed") return zerrors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")

View File

@ -25,7 +25,7 @@ type execution struct {
type anyArg struct{} type anyArg struct{}
func (e *testExecuter) Exec(stmt string, args ...interface{}) (sql.Result, error) { func (e *testExecuter) Exec(stmt string, args ...interface{}) (sql.Result, error) {
if stmt == "SAVEPOINT stmt_exec" || stmt == "RELEASE SAVEPOINT stmt_exec" { if stmt == "SAVEPOINT exec_stmt" {
return nil, nil return nil, nil
} }