mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-14 11:58:02 +00:00
ddeeeed303
# Which Problems Are Solved There are cases where not all statements of a multiExec succeed. This leads to inconsistent states. One example is [LDAP IDPs](https://github.com/zitadel/zitadel/issues/7959). If statements are executed only partially, this can lead to inconsistent states or even break projections for objects that might not have been correctly created in a sub table. This behaviour is possible because we use [`SAVEPOINTS`](https://www.postgresql.org/docs/current/sql-savepoint.html) during each statement of a multiExec. # How the Problems Are Solved SAVEPOINTS are only created at the beginning of an exec function, not during every execution as before. Additionally, `RELEASE` or `ROLLBACK` of `SAVEPOINTS` is only used when needed. # Additional Changes - refactor some unused parameters # Additional Context - closes https://github.com/zitadel/zitadel/issues/7959
63 lines
1.5 KiB
Go
63 lines
1.5 KiB
Go
package projection
|
|
|
|
import (
|
|
"database/sql"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/zitadel/zitadel/internal/zerrors"
|
|
)
|
|
|
|
// testExecuter is a test double for the statement executer used by
// projections. Exec records each statement and its arguments into
// executions, in call order. Validate then compares the recorded values
// against the expected ones.
type testExecuter struct {
	// execIdx is the index of the next entry in executions that Exec will
	// fill. After a run it equals the number of recorded statements.
	execIdx int
	// executions holds the expected statement/args for each call and, once
	// Exec has run, the statement/args it actually received.
	executions []execution
}
|
|
|
|
// execution pairs one expected statement (and its arguments) with what
// testExecuter.Exec actually received for the corresponding call.
type execution struct {
	// expectedStmt is set by the test; gottenStmt is filled in by Exec.
	expectedStmt string
	gottenStmt   string

	// expectedArgs is set by the test. An anyArg entry matches any value at
	// that position. gottenArgs is filled in by Exec.
	expectedArgs []interface{}
	gottenArgs   []interface{}
}
|
|
|
|
// anyArg is a wildcard placeholder for execution.expectedArgs: Validate
// skips the comparison of any argument whose expected value is an anyArg.
type anyArg struct{}
|
|
|
|
func (e *testExecuter) Exec(stmt string, args ...interface{}) (sql.Result, error) {
|
|
if stmt == "SAVEPOINT exec_stmt" {
|
|
return nil, nil
|
|
}
|
|
|
|
if e.execIdx >= len(e.executions) {
|
|
return nil, zerrors.ThrowInternal(nil, "PROJE-8TNoE", "too many executions")
|
|
}
|
|
e.executions[e.execIdx].gottenArgs = args
|
|
e.executions[e.execIdx].gottenStmt = stmt
|
|
e.execIdx++
|
|
return nil, nil
|
|
}
|
|
|
|
func (e *testExecuter) Validate(t *testing.T) {
|
|
t.Helper()
|
|
if e.execIdx != len(e.executions) {
|
|
t.Errorf("not all expected execs executed. got: %d, want: %d", e.execIdx, len(e.executions))
|
|
return
|
|
}
|
|
for _, execution := range e.executions {
|
|
if len(execution.gottenArgs) != len(execution.expectedArgs) {
|
|
t.Errorf("wrong arg len expected: %d got: %d", len(execution.expectedArgs), len(execution.gottenArgs))
|
|
} else {
|
|
for i := 0; i < len(execution.expectedArgs); i++ {
|
|
if _, ok := execution.expectedArgs[i].(anyArg); ok {
|
|
continue
|
|
}
|
|
assert.Equal(t, execution.expectedArgs[i], execution.gottenArgs[i], "wrong argument at index %d", i)
|
|
}
|
|
}
|
|
if execution.gottenStmt != execution.expectedStmt {
|
|
t.Errorf("wrong stmt want:\n%s\ngot:\n%s", execution.expectedStmt, execution.gottenStmt)
|
|
}
|
|
}
|
|
}
|