mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:37:32 +00:00
feat(eventstore): increase parallel write capabilities (#5940)
This implementation increases parallel write capabilities of the eventstore. Please have a look at the technical advisories: [05](https://zitadel.com/docs/support/advisory/a10005) and [06](https://zitadel.com/docs/support/advisory/a10006). The implementation of eventstore.push is rewritten and stored events are migrated to a new table `eventstore.events2`. If you are using cockroach: make sure that the database user of ZITADEL has `VIEWACTIVITY` grant. This is used to query events.
This commit is contained in:
@@ -1,83 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const (
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
|
||||
currentSequenceStmtWithoutLockFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2)`
|
||||
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
|
||||
)
|
||||
|
||||
type currentSequences map[eventstore.AggregateType][]*instanceSequence
|
||||
|
||||
type instanceSequence struct {
|
||||
instanceID string
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, isTx bool, query func(context.Context, func(*sql.Rows) error, string, ...interface{}) error, instanceIDs database.StringArray) (currentSequences, error) {
|
||||
stmt := h.currentSequenceStmt
|
||||
if !isTx {
|
||||
stmt = h.currentSequenceWithoutLockStmt
|
||||
}
|
||||
|
||||
sequences := make(currentSequences, len(h.aggregates))
|
||||
err := query(ctx,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var (
|
||||
aggregateType eventstore.AggregateType
|
||||
sequence uint64
|
||||
instanceID string
|
||||
)
|
||||
|
||||
err := rows.Scan(&sequence, &aggregateType, &instanceID)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
|
||||
}
|
||||
|
||||
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
|
||||
sequence: sequence,
|
||||
instanceID: instanceID,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
stmt, h.ProjectionName, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sequences, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentSequences) error {
|
||||
valueQueries := make([]string, 0, len(sequences))
|
||||
valueCounter := 0
|
||||
values := make([]interface{}, 0, len(sequences)*3)
|
||||
for aggregate, instanceSequence := range sequences {
|
||||
for _, sequence := range instanceSequence {
|
||||
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", $"+strconv.Itoa(valueCounter+4)+", NOW())")
|
||||
valueCounter += 4
|
||||
values = append(values, h.ProjectionName, aggregate, sequence.sequence, sequence.instanceID)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", ")+updateCurrentSequencesConflictStmt, values...)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-TrH2Z", "unable to exec update sequence")
|
||||
}
|
||||
if rows, _ := res.RowsAffected(); rows < 1 {
|
||||
return errSeqNotUpdated
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,301 +1,15 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
type mockExpectation func(sqlmock.Sqlmock)
|
||||
|
||||
func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2 AND instance_id = \$3\) SELECT COALESCE\(\(SELECT failure_count FROM failures\), 0\) AS failure_count`).
|
||||
WithArgs(projectionName, failedSeq, instanceID).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"failure_count"}).
|
||||
AddRow(failureCount),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`INSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id, last_failed\) VALUES \(\$1, \$2, \$3, \$4\, \$5\, \$6\) ON CONFLICT \(projection_name, failed_sequence, instance_id\) DO UPDATE SET failure_count = EXCLUDED\.failure_count, error = EXCLUDED\.error, last_failed = EXCLUDED\.last_failed`).
|
||||
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID, "NOW()").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectCreate(projectionName string, columnNames, placeholders []string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
args := make([]driver.Value, len(columnNames))
|
||||
for i := 0; i < len(columnNames); i++ {
|
||||
args[i] = sqlmock.AnyArg()
|
||||
placeholders[i] = `\` + placeholders[i]
|
||||
}
|
||||
m.ExpectExec("INSERT INTO " + projectionName + ` \(` + strings.Join(columnNames, ", ") + `\) VALUES \(` + strings.Join(placeholders, ", ") + `\)`).
|
||||
WithArgs(args...).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectCreateErr(projectionName string, columnNames, placeholders []string, err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
args := make([]driver.Value, len(columnNames))
|
||||
for i := 0; i < len(columnNames); i++ {
|
||||
args[i] = sqlmock.AnyArg()
|
||||
placeholders[i] = `\` + placeholders[i]
|
||||
}
|
||||
m.ExpectExec("INSERT INTO " + projectionName + ` \(` + strings.Join(columnNames, ", ") + `\) VALUES \(` + strings.Join(placeholders, ", ") + `\)`).
|
||||
WithArgs(args...).
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectBegin() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
}
|
||||
}
|
||||
|
||||
func expectBeginErr(err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin().WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCommit() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectCommit()
|
||||
}
|
||||
}
|
||||
|
||||
func expectCommitErr(err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectCommit().WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectRollback() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectRollback()
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePoint() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("SAVEPOINT push_stmt").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePointErr(err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("SAVEPOINT push_stmt").
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePointRollback() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("ROLLBACK TO SAVEPOINT push_stmt").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePointRollbackErr(err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("ROLLBACK TO SAVEPOINT push_stmt").
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePointRelease() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("RELEASE push_stmt").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequence(isTx bool, tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"})
|
||||
for _, instanceID := range instanceIDs {
|
||||
rows.AddRow(seq, aggregateType, instanceID)
|
||||
}
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
|
||||
if isTx {
|
||||
stmt += " FOR UPDATE"
|
||||
}
|
||||
m.ExpectQuery(stmt).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
rows,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceErr(isTx bool, tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
|
||||
if isTx {
|
||||
stmt += " FOR UPDATE"
|
||||
}
|
||||
m.ExpectQuery(stmt).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
|
||||
RowError(0, sql.ErrTxDone).
|
||||
AddRow(0, "agg", "instanceID"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateThreeCurrentSequence(t *testing.T, tableName, projection string, sequences currentSequences) func(sqlmock.Sqlmock) {
|
||||
args := make([][]interface{}, 0)
|
||||
for aggregateType, instanceSequences := range sequences {
|
||||
for _, sequence := range instanceSequences {
|
||||
args = append(args, []interface{}{
|
||||
projection,
|
||||
aggregateType,
|
||||
sequence.sequence,
|
||||
sequence.instanceID,
|
||||
})
|
||||
}
|
||||
}
|
||||
matcher := ¤tSequenceMatcher{t: t, seq: args}
|
||||
matchers := make([]driver.Value, len(args)*4)
|
||||
for i := 0; i < len(args)*4; i++ {
|
||||
matchers[i] = matcher
|
||||
}
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("INSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
matchers...,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
type currentSequenceMatcher struct {
|
||||
seq [][]interface{}
|
||||
i int
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (m *currentSequenceMatcher) Match(value driver.Value) bool {
|
||||
if m.i%4 == 0 {
|
||||
m.i = 0
|
||||
}
|
||||
left := make([]interface{}, 0, len(m.seq))
|
||||
for _, seq := range m.seq {
|
||||
found := seq[m.i]
|
||||
if found == nil {
|
||||
continue
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if found == v || found == eventstore.AggregateType(v) {
|
||||
seq[m.i] = nil
|
||||
m.i++
|
||||
return true
|
||||
}
|
||||
case int64:
|
||||
if found == uint64(v) {
|
||||
seq[m.i] = nil
|
||||
m.i++
|
||||
return true
|
||||
}
|
||||
}
|
||||
left = append(left, found)
|
||||
}
|
||||
m.t.Errorf("expected: %v, possible left values: %v", value, left)
|
||||
m.t.FailNow()
|
||||
return false
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(0, 0),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectLock(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`INSERT INTO `+lockTable+
|
||||
@@ -308,7 +22,7 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
database.StringArray{instanceID},
|
||||
database.TextArray[string]{instanceID},
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -329,7 +43,7 @@ func expectLockMultipleInstances(lockTable, workerName string, d time.Duration,
|
||||
projectionName,
|
||||
instanceID1,
|
||||
instanceID2,
|
||||
database.StringArray{instanceID1, instanceID2},
|
||||
database.TextArray[string]{instanceID1, instanceID2},
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -349,7 +63,7 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
database.StringArray{instanceID},
|
||||
database.TextArray[string]{instanceID},
|
||||
).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
}
|
||||
@@ -367,7 +81,7 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
database.StringArray{instanceID},
|
||||
database.TextArray[string]{instanceID},
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
|
@@ -1,51 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
const (
|
||||
setFailureCountStmtFormat = "INSERT INTO %s" +
|
||||
" (projection_name, failed_sequence, failure_count, error, instance_id, last_failed)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (projection_name, failed_sequence, instance_id)" +
|
||||
" DO UPDATE SET failure_count = EXCLUDED.failure_count, error = EXCLUDED.error, last_failed = EXCLUDED.last_failed"
|
||||
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
|
||||
" SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count"
|
||||
)
|
||||
|
||||
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {
|
||||
failureCount, err := h.failureCount(tx, stmt.Sequence, stmt.InstanceID)
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).WithError(err).Warn("unable to get failure count")
|
||||
return false
|
||||
}
|
||||
failureCount += 1
|
||||
err = h.setFailureCount(tx, stmt.Sequence, failureCount, execErr, stmt.InstanceID)
|
||||
logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).OnError(err).Warn("unable to update failure count")
|
||||
|
||||
return failureCount >= h.maxFailureCount
|
||||
}
|
||||
|
||||
func (h *StatementHandler) failureCount(tx *sql.Tx, seq uint64, instanceID string) (count uint, err error) {
|
||||
row := tx.QueryRow(h.failureCountStmt, h.ProjectionName, seq, instanceID)
|
||||
if err = row.Err(); err != nil {
|
||||
return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
|
||||
}
|
||||
if err = row.Scan(&count); err != nil {
|
||||
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) setFailureCount(tx *sql.Tx, seq uint64, count uint, err error, instanceID string) error {
|
||||
_, dbErr := tx.Exec(h.setFailureCountStmt, h.ProjectionName, seq, count, err.Error(), instanceID, "NOW()")
|
||||
if dbErr != nil {
|
||||
return errors.ThrowInternal(dbErr, "CRDB-4Ht4x", "set failure count failed")
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,347 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
"github.com/zitadel/zitadel/internal/repository/pseudo"
|
||||
)
|
||||
|
||||
var (
|
||||
errSeqNotUpdated = errors.ThrowInternal(nil, "CRDB-79GWt", "current sequence not updated")
|
||||
)
|
||||
|
||||
type StatementHandlerConfig struct {
|
||||
handler.ProjectionHandlerConfig
|
||||
|
||||
Client *database.DB
|
||||
SequenceTable string
|
||||
LockTable string
|
||||
FailedEventsTable string
|
||||
MaxFailureCount uint
|
||||
BulkLimit uint64
|
||||
|
||||
Reducers []handler.AggregateReducer
|
||||
InitCheck *handler.Check
|
||||
}
|
||||
|
||||
type StatementHandler struct {
|
||||
*handler.ProjectionHandler
|
||||
Locker
|
||||
|
||||
client *database.DB
|
||||
sequenceTable string
|
||||
currentSequenceStmt string
|
||||
currentSequenceWithoutLockStmt string
|
||||
updateSequencesBaseStmt string
|
||||
maxFailureCount uint
|
||||
failureCountStmt string
|
||||
setFailureCountStmt string
|
||||
|
||||
aggregates []eventstore.AggregateType
|
||||
reduces map[eventstore.EventType]handler.Reduce
|
||||
initCheck *handler.Check
|
||||
initialized chan bool
|
||||
|
||||
bulkLimit uint64
|
||||
|
||||
reduceScheduledPseudoEvent bool
|
||||
}
|
||||
|
||||
func NewStatementHandler(
|
||||
ctx context.Context,
|
||||
config StatementHandlerConfig,
|
||||
) StatementHandler {
|
||||
aggregateTypes := make([]eventstore.AggregateType, 0, len(config.Reducers))
|
||||
reduces := make(map[eventstore.EventType]handler.Reduce, len(config.Reducers))
|
||||
reduceScheduledPseudoEvent := false
|
||||
for _, aggReducer := range config.Reducers {
|
||||
aggregateTypes = append(aggregateTypes, aggReducer.Aggregate)
|
||||
if aggReducer.Aggregate == pseudo.AggregateType {
|
||||
reduceScheduledPseudoEvent = true
|
||||
if len(config.Reducers) != 1 ||
|
||||
len(aggReducer.EventRedusers) != 1 ||
|
||||
aggReducer.EventRedusers[0].Event != pseudo.ScheduledEventType {
|
||||
panic("if a pseudo.AggregateType is reduced, exactly one event reducer for pseudo.ScheduledEventType is supported and no other aggregate can be reduced")
|
||||
}
|
||||
}
|
||||
for _, eventReducer := range aggReducer.EventRedusers {
|
||||
reduces[eventReducer.Event] = eventReducer.Reduce
|
||||
}
|
||||
}
|
||||
|
||||
h := StatementHandler{
|
||||
client: config.Client,
|
||||
sequenceTable: config.SequenceTable,
|
||||
maxFailureCount: config.MaxFailureCount,
|
||||
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
|
||||
currentSequenceWithoutLockStmt: fmt.Sprintf(currentSequenceStmtWithoutLockFormat, config.SequenceTable),
|
||||
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
|
||||
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
|
||||
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
|
||||
aggregates: aggregateTypes,
|
||||
reduces: reduces,
|
||||
bulkLimit: config.BulkLimit,
|
||||
Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName),
|
||||
initCheck: config.InitCheck,
|
||||
initialized: make(chan bool),
|
||||
reduceScheduledPseudoEvent: reduceScheduledPseudoEvent,
|
||||
}
|
||||
|
||||
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.searchQuery, h.Lock, h.Unlock, h.initialized, reduceScheduledPseudoEvent)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *StatementHandler) Start() {
|
||||
h.initialized <- true
|
||||
close(h.initialized)
|
||||
if !h.reduceScheduledPseudoEvent {
|
||||
h.Subscribe(h.aggregates...)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *StatementHandler) searchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
if h.reduceScheduledPseudoEvent {
|
||||
return nil, 1, nil
|
||||
}
|
||||
return h.dbSearchQuery(ctx, instanceIDs)
|
||||
}
|
||||
|
||||
func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
sequences, err := h.currentSequences(ctx, false, h.client.QueryContext, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit).AllowTimeTravel()
|
||||
|
||||
for _, aggregateType := range h.aggregates {
|
||||
for _, instanceID := range instanceIDs {
|
||||
var seq uint64
|
||||
for _, sequence := range sequences[aggregateType] {
|
||||
if sequence.instanceID == instanceID {
|
||||
seq = sequence.sequence
|
||||
break
|
||||
}
|
||||
}
|
||||
queryBuilder.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
SequenceGreater(seq).
|
||||
InstanceID(instanceID)
|
||||
}
|
||||
}
|
||||
return queryBuilder, h.bulkLimit, nil
|
||||
}
|
||||
|
||||
type transaction struct {
|
||||
*sql.Tx
|
||||
}
|
||||
|
||||
func (t *transaction) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
|
||||
rows, err := t.Tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
logging.OnError(closeErr).Info("rows.Close failed")
|
||||
}()
|
||||
|
||||
if err = scan(rows); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// Update implements handler.Update
|
||||
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
|
||||
if len(stmts) == 0 {
|
||||
return -1, nil
|
||||
}
|
||||
instanceIDs := make([]string, 0, len(stmts))
|
||||
for _, stmt := range stmts {
|
||||
instanceIDs = appendToInstanceIDs(instanceIDs, stmt.InstanceID)
|
||||
}
|
||||
tx, err := h.client.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
|
||||
}
|
||||
|
||||
sequences, err := h.currentSequences(ctx, true, (&transaction{Tx: tx}).QueryContext, instanceIDs)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return -1, err
|
||||
}
|
||||
|
||||
//checks for events between create statement and current sequence
|
||||
// because there could be events between current sequence and a creation event
|
||||
// and we cannot check via stmt.PreviousSequence
|
||||
if stmts[0].PreviousSequence == 0 {
|
||||
previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return -1, err
|
||||
}
|
||||
stmts = append(previousStmts, stmts...)
|
||||
}
|
||||
|
||||
lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences)
|
||||
|
||||
if lastSuccessfulIdx >= 0 {
|
||||
err = h.updateCurrentSequences(tx, sequences)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if lastSuccessfulIdx < len(stmts)-1 {
|
||||
return lastSuccessfulIdx, handler.ErrSomeStmtsFailed
|
||||
}
|
||||
|
||||
return lastSuccessfulIdx, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).SetTx(tx)
|
||||
queriesAdded := false
|
||||
for _, aggregateType := range h.aggregates {
|
||||
for _, sequence := range sequences[aggregateType] {
|
||||
if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID {
|
||||
continue
|
||||
}
|
||||
|
||||
query.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
SequenceGreater(sequence.sequence).
|
||||
SequenceLess(stmtSeq).
|
||||
InstanceID(sequence.instanceID)
|
||||
|
||||
queriesAdded = true
|
||||
}
|
||||
}
|
||||
|
||||
if !queriesAdded {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
events, err := h.Eventstore.Filter(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
stmt, err := reduce(event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
previousStmts = append(previousStmts, stmt)
|
||||
}
|
||||
return previousStmts, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) executeStmts(
|
||||
tx *sql.Tx,
|
||||
stmts *[]*handler.Statement,
|
||||
sequences currentSequences,
|
||||
) int {
|
||||
|
||||
lastSuccessfulIdx := -1
|
||||
stmts:
|
||||
for i := 0; i < len(*stmts); i++ {
|
||||
stmt := (*stmts)[i]
|
||||
for _, sequence := range sequences[stmt.AggregateType] {
|
||||
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
||||
logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped")
|
||||
if i < len(*stmts)-1 {
|
||||
copy((*stmts)[i:], (*stmts)[i+1:])
|
||||
}
|
||||
*stmts = (*stmts)[:len(*stmts)-1]
|
||||
i--
|
||||
continue stmts
|
||||
}
|
||||
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
||||
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match")
|
||||
break stmts
|
||||
}
|
||||
}
|
||||
err := h.executeStmt(tx, stmt)
|
||||
if err == nil {
|
||||
updateSequences(sequences, stmt)
|
||||
lastSuccessfulIdx = i
|
||||
continue
|
||||
}
|
||||
|
||||
shouldContinue := h.handleFailedStmt(tx, stmt, err)
|
||||
if !shouldContinue {
|
||||
break
|
||||
}
|
||||
|
||||
updateSequences(sequences, stmt)
|
||||
lastSuccessfulIdx = i
|
||||
continue
|
||||
}
|
||||
return lastSuccessfulIdx
|
||||
}
|
||||
|
||||
// executeStmt handles sql statements
|
||||
// an error is returned if the statement could not be inserted properly
|
||||
func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) error {
|
||||
if stmt.IsNoop() {
|
||||
return nil
|
||||
}
|
||||
_, err := tx.Exec("SAVEPOINT push_stmt")
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-i1wp6", "unable to create savepoint")
|
||||
}
|
||||
err = stmt.Execute(tx, h.ProjectionName)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error()
|
||||
_, rollbackErr := tx.Exec("ROLLBACK TO SAVEPOINT push_stmt")
|
||||
if rollbackErr != nil {
|
||||
return errors.ThrowInternal(rollbackErr, "CRDB-zzp3P", "rollback to savepoint failed")
|
||||
}
|
||||
return errors.ThrowInternal(err, "CRDB-oRkaN", "unable execute stmt")
|
||||
}
|
||||
_, err = tx.Exec("RELEASE push_stmt")
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-qWgwT", "unable to release savepoint")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
|
||||
for _, sequence := range sequences[stmt.AggregateType] {
|
||||
if sequence.instanceID == stmt.InstanceID {
|
||||
sequence.sequence = stmt.Sequence
|
||||
return
|
||||
}
|
||||
}
|
||||
sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{
|
||||
instanceID: stmt.InstanceID,
|
||||
sequence: stmt.Sequence,
|
||||
})
|
||||
}
|
||||
|
||||
func appendToInstanceIDs(instances []string, id string) []string {
|
||||
for _, instance := range instances {
|
||||
if instance == id {
|
||||
return instances
|
||||
}
|
||||
}
|
||||
return append(instances, id)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@@ -1,49 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_defaultValue(t *testing.T) {
|
||||
type args struct {
|
||||
value interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
args: args{
|
||||
value: "asdf",
|
||||
},
|
||||
want: "'asdf'",
|
||||
},
|
||||
{
|
||||
name: "primitive non string",
|
||||
args: args{
|
||||
value: 1,
|
||||
},
|
||||
want: "1",
|
||||
},
|
||||
{
|
||||
name: "stringer",
|
||||
args: args{
|
||||
value: testStringer(0),
|
||||
},
|
||||
want: "0",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := defaultValue(tt.args.value); got != tt.want {
|
||||
t.Errorf("defaultValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testStringer int
|
||||
|
||||
func (t testStringer) String() string {
|
||||
return "0529958243"
|
||||
}
|
@@ -91,7 +91,7 @@ func (h *locker) Unlock(instanceIDs ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.StringArray) (string, []interface{}) {
|
||||
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.TextArray[string]) (string, []interface{}) {
|
||||
valueQueries := make([]string, len(instanceIDs))
|
||||
values := make([]interface{}, len(instanceIDs)+4)
|
||||
values[0] = h.workerName
|
||||
|
@@ -158,7 +158,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 1 * time.Second,
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
instanceIDs: database.TextArray[string]{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -173,7 +173,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 2 * time.Second,
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
instanceIDs: database.TextArray[string]{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -188,7 +188,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 3 * time.Second,
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
instanceIDs: database.TextArray[string]{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@@ -1,16 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
//reduce implements handler.Reduce function
|
||||
func (h *StatementHandler) reduce(event eventstore.Event) (*handler.Statement, error) {
|
||||
reduce, ok := h.reduces[event.Type()]
|
||||
if !ok {
|
||||
return NewNoOpStatement(event), nil
|
||||
}
|
||||
|
||||
return reduce(event)
|
||||
}
|
@@ -1,36 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
type HandlerConfig struct {
|
||||
Eventstore *eventstore.Eventstore
|
||||
}
|
||||
type Handler struct {
|
||||
Eventstore *eventstore.Eventstore
|
||||
Sub *eventstore.Subscription
|
||||
EventQueue chan eventstore.Event
|
||||
}
|
||||
|
||||
func NewHandler(config HandlerConfig) Handler {
|
||||
return Handler{
|
||||
Eventstore: config.Eventstore,
|
||||
EventQueue: make(chan eventstore.Event, 100),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Subscribe(aggregates ...eventstore.AggregateType) {
|
||||
h.Sub = eventstore.SubscribeAggregates(h.EventQueue, aggregates...)
|
||||
}
|
||||
|
||||
func (h *Handler) SubscribeEvents(types map[eventstore.AggregateType][]eventstore.EventType) {
|
||||
h.Sub = eventstore.SubscribeEventTypes(h.EventQueue, types)
|
||||
}
|
||||
|
||||
func (h *Handler) Unsubscribe() {
|
||||
if h.Sub == nil {
|
||||
return
|
||||
}
|
||||
h.Sub.Unsubscribe()
|
||||
}
|
@@ -1,396 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/pseudo"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulerSucceeded = eventstore.EventType("system.projections.scheduler.succeeded")
|
||||
aggregateType = eventstore.AggregateType("system")
|
||||
aggregateID = "SYSTEM"
|
||||
)
|
||||
|
||||
type ProjectionHandlerConfig struct {
|
||||
HandlerConfig
|
||||
ProjectionName string
|
||||
RequeueEvery time.Duration
|
||||
RetryFailedAfter time.Duration
|
||||
Retries uint
|
||||
ConcurrentInstances uint
|
||||
HandleActiveInstances time.Duration
|
||||
}
|
||||
|
||||
// Update updates the projection with the given statements
|
||||
type Update func(context.Context, []*Statement, Reduce) (index int, err error)
|
||||
|
||||
// Reduce reduces the given event to a statement
|
||||
// which is used to update the projection
|
||||
type Reduce func(eventstore.Event) (*Statement, error)
|
||||
|
||||
// SearchQuery generates the search query to lookup for events
|
||||
type SearchQuery func(ctx context.Context, instanceIDs []string) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
|
||||
|
||||
// Lock is used for mutex handling if needed on the projection
|
||||
type Lock func(context.Context, time.Duration, ...string) <-chan error
|
||||
|
||||
// Unlock releases the mutex of the projection
|
||||
type Unlock func(...string) error
|
||||
|
||||
// NowFunc makes time.Now() mockable
|
||||
type NowFunc func() time.Time
|
||||
|
||||
type ProjectionHandler struct {
|
||||
Handler
|
||||
ProjectionName string
|
||||
reduce Reduce
|
||||
update Update
|
||||
searchQuery SearchQuery
|
||||
triggerProjection *time.Timer
|
||||
lock Lock
|
||||
unlock Unlock
|
||||
requeueAfter time.Duration
|
||||
retryFailedAfter time.Duration
|
||||
retries int
|
||||
concurrentInstances int
|
||||
handleActiveInstances time.Duration
|
||||
nowFunc NowFunc
|
||||
reduceScheduledPseudoEvent bool
|
||||
}
|
||||
|
||||
func NewProjectionHandler(
|
||||
ctx context.Context,
|
||||
config ProjectionHandlerConfig,
|
||||
reduce Reduce,
|
||||
update Update,
|
||||
query SearchQuery,
|
||||
lock Lock,
|
||||
unlock Unlock,
|
||||
initialized <-chan bool,
|
||||
reduceScheduledPseudoEvent bool,
|
||||
) *ProjectionHandler {
|
||||
concurrentInstances := int(config.ConcurrentInstances)
|
||||
if concurrentInstances < 1 {
|
||||
concurrentInstances = 1
|
||||
}
|
||||
h := &ProjectionHandler{
|
||||
Handler: NewHandler(config.HandlerConfig),
|
||||
ProjectionName: config.ProjectionName,
|
||||
reduce: reduce,
|
||||
update: update,
|
||||
searchQuery: query,
|
||||
lock: lock,
|
||||
unlock: unlock,
|
||||
requeueAfter: config.RequeueEvery,
|
||||
triggerProjection: time.NewTimer(0), // first trigger is instant on startup
|
||||
retryFailedAfter: config.RetryFailedAfter,
|
||||
retries: int(config.Retries),
|
||||
concurrentInstances: concurrentInstances,
|
||||
handleActiveInstances: config.HandleActiveInstances,
|
||||
nowFunc: time.Now,
|
||||
reduceScheduledPseudoEvent: reduceScheduledPseudoEvent,
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-initialized
|
||||
if !h.reduceScheduledPseudoEvent {
|
||||
go h.subscribe(ctx)
|
||||
}
|
||||
go h.schedule(ctx)
|
||||
}()
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func triggerInstances(ctx context.Context, instances []string) []string {
|
||||
if len(instances) == 0 {
|
||||
instances = append(instances, authz.GetInstance(ctx).InstanceID())
|
||||
}
|
||||
return instances
|
||||
}
|
||||
|
||||
// Trigger handles all events for the provided instances (or current instance from context if non specified)
|
||||
// by calling FetchEvents and Process until the amount of events is smaller than the BulkLimit.
|
||||
// If a bulk action was executed, the call timestamp in context will be reset for subsequent queries.
|
||||
// The returned context is never nil. It is either the original context or an updated context.
|
||||
//
|
||||
// If Trigger encounters an error, it is only logged. If the error is important for the caller,
|
||||
// use TriggerErr instead.
|
||||
func (h *ProjectionHandler) Trigger(ctx context.Context, instances ...string) context.Context {
|
||||
instances = triggerInstances(ctx, instances)
|
||||
ctx, err := h.TriggerErr(ctx, instances...)
|
||||
logging.OnError(err).WithFields(logrus.Fields{
|
||||
"projection": h.ProjectionName,
|
||||
"instanceIDs": instances,
|
||||
}).Error("trigger failed")
|
||||
return ctx
|
||||
}
|
||||
|
||||
// TriggerErr handles all events for the provided instances (or current instance from context if non specified)
|
||||
// by calling FetchEvents and Process until the amount of events is smaller than the BulkLimit.
|
||||
// If a bulk action was executed, the call timestamp in context will be reset for subsequent queries.
|
||||
// The returned context is never nil. It is either the original context or an updated context.
|
||||
func (h *ProjectionHandler) TriggerErr(ctx context.Context, instances ...string) (outCtx context.Context, err error) {
|
||||
instances = triggerInstances(ctx, instances)
|
||||
defer func() {
|
||||
outCtx = call.ResetTimestamp(ctx)
|
||||
}()
|
||||
for {
|
||||
events, hasLimitExceeded, err := h.FetchEvents(ctx, instances...)
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
if len(events) == 0 {
|
||||
return ctx, nil
|
||||
}
|
||||
_, err = h.Process(ctx, events...)
|
||||
if err != nil {
|
||||
return ctx, err
|
||||
}
|
||||
if !hasLimitExceeded {
|
||||
return ctx, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process handles multiple events by reducing them to statements and updating the projection
|
||||
func (h *ProjectionHandler) Process(ctx context.Context, events ...eventstore.Event) (index int, err error) {
|
||||
if len(events) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
index = -1
|
||||
statements := make([]*Statement, len(events))
|
||||
for i, event := range events {
|
||||
statements[i], err = h.reduce(event)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
}
|
||||
for retry := 0; retry <= h.retries; retry++ {
|
||||
index, err = h.update(ctx, statements[index+1:], h.reduce)
|
||||
if err != nil && !errors.Is(err, ErrSomeStmtsFailed) {
|
||||
return index, err
|
||||
}
|
||||
if err == nil {
|
||||
return index, nil
|
||||
}
|
||||
time.Sleep(h.retryFailedAfter)
|
||||
}
|
||||
return index, err
|
||||
}
|
||||
|
||||
// FetchEvents checks the current sequences and filters for newer events
|
||||
func (h *ProjectionHandler) FetchEvents(ctx context.Context, instances ...string) ([]eventstore.Event, bool, error) {
|
||||
if h.reduceScheduledPseudoEvent {
|
||||
return h.fetchPseudoEvents(ctx, instances...)
|
||||
}
|
||||
return h.fetchDBEvents(ctx, instances...)
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) fetchDBEvents(ctx context.Context, instances ...string) ([]eventstore.Event, bool, error) {
|
||||
eventQuery, eventsLimit, err := h.searchQuery(ctx, instances)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
events, err := h.Eventstore.Filter(ctx, eventQuery)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return events, int(eventsLimit) == len(events), err
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) fetchPseudoEvents(ctx context.Context, instances ...string) ([]eventstore.Event, bool, error) {
|
||||
return []eventstore.Event{pseudo.NewScheduledEvent(ctx, time.Now(), instances...)}, false, nil
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) subscribe(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
h.Handler.Unsubscribe()
|
||||
logging.WithFields("projection", h.ProjectionName).Errorf("subscription panicked: %v", err)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
for firstEvent := range h.EventQueue {
|
||||
events := checkAdditionalEvents(h.EventQueue, firstEvent)
|
||||
|
||||
index, err := h.Process(ctx, events...)
|
||||
if err != nil || index < len(events)-1 {
|
||||
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("unable to process all events from subscription")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) schedule(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName, "cause", err, "stack", string(debug.Stack())).Error("schedule panicked")
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
// flag if projection has been successfully executed at least once since start
|
||||
var succeededOnce bool
|
||||
var err error
|
||||
// get every instance id except empty (system)
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).AllowTimeTravel().AddQuery().ExcludedInstanceID("")
|
||||
for range h.triggerProjection.C {
|
||||
if !succeededOnce {
|
||||
// (re)check if it has succeeded in the meantime
|
||||
succeededOnce, err = h.hasSucceededOnce(ctx)
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName, "err", err).
|
||||
Error("schedule could not check if projection has already succeeded once")
|
||||
h.triggerProjection.Reset(h.requeueAfter)
|
||||
continue
|
||||
}
|
||||
}
|
||||
lockCtx := ctx
|
||||
var cancelLock context.CancelFunc
|
||||
// if it still has not succeeded, lock the projection for the system
|
||||
// so that only a single scheduler does a first schedule (of every instance)
|
||||
if !succeededOnce {
|
||||
lockCtx, cancelLock = context.WithCancel(ctx)
|
||||
errs := h.lock(lockCtx, h.requeueAfter, "system")
|
||||
if err, ok := <-errs; err != nil || !ok {
|
||||
cancelLock()
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(err).Debug("initial lock failed for first schedule")
|
||||
h.triggerProjection.Reset(h.requeueAfter)
|
||||
continue
|
||||
}
|
||||
go h.cancelOnErr(lockCtx, errs, cancelLock)
|
||||
}
|
||||
if succeededOnce {
|
||||
// since we have at least one successful run, we can restrict it to events not older than
|
||||
// h.handleActiveInstances (just to be sure not to miss an event)
|
||||
// This ensures that only instances with recent events on the handler are projected
|
||||
query = query.CreationDateAfter(h.nowFunc().Add(-1 * h.handleActiveInstances))
|
||||
}
|
||||
ids, err := h.Eventstore.InstanceIDs(ctx, h.requeueAfter, !succeededOnce, query.Builder())
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName).WithError(err).Error("instance ids")
|
||||
h.triggerProjection.Reset(h.requeueAfter)
|
||||
continue
|
||||
}
|
||||
var failed bool
|
||||
for i := 0; i < len(ids); i = i + h.concurrentInstances {
|
||||
max := i + h.concurrentInstances
|
||||
if max > len(ids) {
|
||||
max = len(ids)
|
||||
}
|
||||
instances := ids[i:max]
|
||||
lockInstanceCtx, cancelInstanceLock := context.WithCancel(lockCtx)
|
||||
errs := h.lock(lockInstanceCtx, h.requeueAfter, instances...)
|
||||
//wait until projection is locked
|
||||
if err, ok := <-errs; err != nil || !ok {
|
||||
cancelInstanceLock()
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(err).Debug("initial lock failed")
|
||||
failed = true
|
||||
continue
|
||||
}
|
||||
go h.cancelOnErr(lockInstanceCtx, errs, cancelInstanceLock)
|
||||
_, err = h.TriggerErr(lockInstanceCtx, instances...)
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName, "instanceIDs", instances).WithError(err).Error("trigger failed")
|
||||
failed = true
|
||||
}
|
||||
|
||||
cancelInstanceLock()
|
||||
unlockErr := h.unlock(instances...)
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")
|
||||
}
|
||||
// if the first schedule did not fail, store that in the eventstore, so we can check on later starts
|
||||
if !succeededOnce {
|
||||
if !failed {
|
||||
err = h.setSucceededOnce(ctx)
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("unable to push first schedule succeeded")
|
||||
}
|
||||
cancelLock()
|
||||
unlockErr := h.unlock("system")
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock first schedule")
|
||||
}
|
||||
// it succeeded at least once if it has succeeded before or if it has succeeded now - not failed ;-)
|
||||
succeededOnce = succeededOnce || !failed
|
||||
h.triggerProjection.Reset(h.requeueAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) hasSucceededOnce(ctx context.Context) (bool, error) {
|
||||
events, err := h.Eventstore.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
AggregateIDs(aggregateID).
|
||||
EventTypes(schedulerSucceeded).
|
||||
EventData(map[string]interface{}{
|
||||
"name": h.ProjectionName,
|
||||
}).
|
||||
Builder(),
|
||||
)
|
||||
return len(events) > 0 && err == nil, err
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) setSucceededOnce(ctx context.Context) error {
|
||||
_, err := h.Eventstore.Push(ctx, &ProjectionSucceededEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(ctx,
|
||||
eventstore.NewAggregate(ctx, aggregateID, aggregateType, "v1"),
|
||||
schedulerSucceeded,
|
||||
),
|
||||
Name: h.ProjectionName,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
type ProjectionSucceededEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (p *ProjectionSucceededEvent) Data() interface{} {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ProjectionSucceededEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) cancelOnErr(ctx context.Context, errs <-chan error, cancel func()) {
|
||||
for {
|
||||
select {
|
||||
case err := <-errs:
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("bulk canceled")
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
|
||||
events := make([]eventstore.Event, 1)
|
||||
events[0] = event
|
||||
for {
|
||||
select {
|
||||
case event := <-eventQueue:
|
||||
events = append(events, event)
|
||||
default:
|
||||
return events
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@ package handler
|
||||
|
||||
import "context"
|
||||
|
||||
//Init initializes the projection with the given check
|
||||
// Init initializes the projection with the given check
|
||||
type Init func(context.Context, *Check) error
|
||||
|
||||
type Check struct {
|
||||
|
@@ -1,17 +0,0 @@
|
||||
package handler
|
||||
|
||||
import "github.com/zitadel/zitadel/internal/eventstore"
|
||||
|
||||
//EventReducer represents the required data
|
||||
//to work with events
|
||||
type EventReducer struct {
|
||||
Event eventstore.EventType
|
||||
Reduce Reduce
|
||||
}
|
||||
|
||||
//EventReducer represents the required data
|
||||
//to work with aggregates
|
||||
type AggregateReducer struct {
|
||||
Aggregate eventstore.AggregateType
|
||||
EventRedusers []EventReducer
|
||||
}
|
@@ -2,77 +2,8 @@ package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoProjection = errors.New("no projection")
|
||||
ErrNoValues = errors.New("no values")
|
||||
ErrNoCondition = errors.New("no condition")
|
||||
ErrSomeStmtsFailed = errors.New("some statements failed")
|
||||
)
|
||||
|
||||
type Statements []Statement
|
||||
|
||||
func (stmts Statements) Len() int { return len(stmts) }
|
||||
func (stmts Statements) Swap(i, j int) { stmts[i], stmts[j] = stmts[j], stmts[i] }
|
||||
func (stmts Statements) Less(i, j int) bool { return stmts[i].Sequence < stmts[j].Sequence }
|
||||
|
||||
type Statement struct {
|
||||
AggregateType eventstore.AggregateType
|
||||
Sequence uint64
|
||||
PreviousSequence uint64
|
||||
InstanceID string
|
||||
|
||||
Execute func(ex Executer, projectionName string) error
|
||||
}
|
||||
|
||||
func (s *Statement) IsNoop() bool {
|
||||
return s.Execute == nil
|
||||
}
|
||||
|
||||
type Executer interface {
|
||||
Exec(string, ...interface{}) (sql.Result, error)
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
Name string
|
||||
Value interface{}
|
||||
ParameterOpt func(string) string
|
||||
}
|
||||
|
||||
func NewCol(name string, value interface{}) Column {
|
||||
return Column{
|
||||
Name: name,
|
||||
Value: value,
|
||||
}
|
||||
}
|
||||
|
||||
func NewJSONCol(name string, value interface{}) Column {
|
||||
marshalled, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
logging.WithFields("column", name).WithError(err).Panic("unable to marshal column")
|
||||
}
|
||||
|
||||
return NewCol(name, marshalled)
|
||||
}
|
||||
|
||||
type Condition func(param string) (string, interface{})
|
||||
|
||||
type NamespacedCondition func(namespace string) Condition
|
||||
|
||||
func NewCond(name string, value interface{}) Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
return name + " = " + param, value
|
||||
}
|
||||
}
|
||||
|
||||
func NewNamespacedCondition(name string, value interface{}) NamespacedCondition {
|
||||
return func(namespace string) Condition {
|
||||
return NewCond(namespace+"."+name, value)
|
||||
}
|
||||
}
|
||||
|
52
internal/eventstore/handler/v2/event.go
Normal file
52
internal/eventstore/handler/v2/event.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulerSucceeded = eventstore.EventType("system.projections.scheduler.succeeded")
|
||||
aggregateType = eventstore.AggregateType("system")
|
||||
aggregateID = "SYSTEM"
|
||||
)
|
||||
|
||||
func (h *Handler) didProjectionInitialize(ctx context.Context) bool {
|
||||
events, err := h.es.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
InstanceID("").
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
AggregateIDs(aggregateID).
|
||||
EventTypes(schedulerSucceeded).
|
||||
EventData(map[string]interface{}{
|
||||
"name": h.projection.Name(),
|
||||
}).
|
||||
Builder(),
|
||||
)
|
||||
return len(events) > 0 && err == nil
|
||||
}
|
||||
|
||||
func (h *Handler) setSucceededOnce(ctx context.Context) error {
|
||||
_, err := h.es.Push(ctx, &ProjectionSucceededEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(ctx,
|
||||
eventstore.NewAggregate(ctx, aggregateID, aggregateType, "v1"),
|
||||
schedulerSucceeded,
|
||||
),
|
||||
Name: h.projection.Name(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
type ProjectionSucceededEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (p *ProjectionSucceededEvent) Payload() interface{} {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ProjectionSucceededEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||
return nil
|
||||
}
|
95
internal/eventstore/handler/v2/failed_event.go
Normal file
95
internal/eventstore/handler/v2/failed_event.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed failed_event_set.sql
|
||||
setFailedEventStmt string
|
||||
//go:embed failed_event_get_count.sql
|
||||
failureCountStmt string
|
||||
)
|
||||
|
||||
type failure struct {
|
||||
sequence uint64
|
||||
instance string
|
||||
aggregateID string
|
||||
aggregateType eventstore.AggregateType
|
||||
eventDate time.Time
|
||||
err error
|
||||
}
|
||||
|
||||
func failureFromEvent(event eventstore.Event, err error) *failure {
|
||||
return &failure{
|
||||
sequence: event.Sequence(),
|
||||
instance: event.Aggregate().InstanceID,
|
||||
aggregateID: event.Aggregate().ID,
|
||||
aggregateType: event.Aggregate().Type,
|
||||
eventDate: event.CreatedAt(),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func failureFromStatement(statement *Statement, err error) *failure {
|
||||
return &failure{
|
||||
sequence: statement.Sequence,
|
||||
instance: statement.InstanceID,
|
||||
aggregateID: statement.AggregateID,
|
||||
aggregateType: statement.AggregateType,
|
||||
eventDate: statement.CreationDate,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleFailedStmt(tx *sql.Tx, currentState *state, f *failure) (shouldContinue bool) {
|
||||
failureCount, err := h.failureCount(tx, f)
|
||||
if err != nil {
|
||||
h.logFailure(f).WithError(err).Warn("unable to get failure count")
|
||||
return false
|
||||
}
|
||||
failureCount += 1
|
||||
err = h.setFailureCount(tx, failureCount, f)
|
||||
h.logFailure(f).OnError(err).Warn("unable to update failure count")
|
||||
|
||||
return failureCount >= h.maxFailureCount
|
||||
}
|
||||
|
||||
func (h *Handler) failureCount(tx *sql.Tx, f *failure) (count uint8, err error) {
|
||||
row := tx.QueryRow(failureCountStmt,
|
||||
h.projection.Name(),
|
||||
f.instance,
|
||||
f.aggregateType,
|
||||
f.aggregateID,
|
||||
f.sequence,
|
||||
)
|
||||
if err = row.Err(); err != nil {
|
||||
return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
|
||||
}
|
||||
if err = row.Scan(&count); err != nil {
|
||||
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (h *Handler) setFailureCount(tx *sql.Tx, count uint8, f *failure) error {
|
||||
_, err := tx.Exec(setFailedEventStmt,
|
||||
h.projection.Name(),
|
||||
f.instance,
|
||||
f.aggregateType,
|
||||
f.aggregateID,
|
||||
f.eventDate,
|
||||
f.sequence,
|
||||
count,
|
||||
f.err.Error(),
|
||||
)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-4Ht4x", "set failure count failed")
|
||||
}
|
||||
return nil
|
||||
}
|
12
internal/eventstore/handler/v2/failed_event_get_count.sql
Normal file
12
internal/eventstore/handler/v2/failed_event_get_count.sql
Normal file
@@ -0,0 +1,12 @@
|
||||
WITH failures AS (
|
||||
SELECT
|
||||
failure_count
|
||||
FROM
|
||||
projections.failed_events2
|
||||
WHERE
|
||||
projection_name = $1
|
||||
AND instance_id = $2
|
||||
AND aggregate_type = $3
|
||||
AND aggregate_id = $4
|
||||
AND failed_sequence = $5
|
||||
) SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count
|
31
internal/eventstore/handler/v2/failed_event_set.sql
Normal file
31
internal/eventstore/handler/v2/failed_event_set.sql
Normal file
@@ -0,0 +1,31 @@
|
||||
INSERT INTO projections.failed_events2 (
|
||||
projection_name
|
||||
, instance_id
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
, event_creation_date
|
||||
, failed_sequence
|
||||
, failure_count
|
||||
, error
|
||||
, last_failed
|
||||
) VALUES (
|
||||
$1
|
||||
, $2
|
||||
, $3
|
||||
, $4
|
||||
, $5
|
||||
, $6
|
||||
, $7
|
||||
, $8
|
||||
, now()
|
||||
) ON CONFLICT (
|
||||
projection_name
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
, failed_sequence
|
||||
, instance_id
|
||||
) DO UPDATE SET
|
||||
failure_count = EXCLUDED.failure_count
|
||||
, error = EXCLUDED.error
|
||||
, last_failed = EXCLUDED.last_failed
|
||||
;
|
465
internal/eventstore/handler/v2/handler.go
Normal file
465
internal/eventstore/handler/v2/handler.go
Normal file
@@ -0,0 +1,465 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/pseudo"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type EventStore interface {
|
||||
InstanceIDs(ctx context.Context, maxAge time.Duration, forceLoad bool, query *eventstore.SearchQueryBuilder) ([]string, error)
|
||||
Filter(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error)
|
||||
Push(ctx context.Context, cmds ...eventstore.Command) ([]eventstore.Event, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Client *database.DB
|
||||
Eventstore EventStore
|
||||
|
||||
BulkLimit uint16
|
||||
RequeueEvery time.Duration
|
||||
RetryFailedAfter time.Duration
|
||||
HandleActiveInstances time.Duration
|
||||
TransactionDuration time.Duration
|
||||
MaxFailureCount uint8
|
||||
|
||||
TriggerWithoutEvents Reduce
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
client *database.DB
|
||||
projection Projection
|
||||
|
||||
es EventStore
|
||||
bulkLimit uint16
|
||||
eventTypes map[eventstore.AggregateType][]eventstore.EventType
|
||||
|
||||
maxFailureCount uint8
|
||||
retryFailedAfter time.Duration
|
||||
requeueEvery time.Duration
|
||||
handleActiveInstances time.Duration
|
||||
txDuration time.Duration
|
||||
now nowFunc
|
||||
|
||||
triggeredInstancesSync sync.Map
|
||||
|
||||
triggerWithoutEvents Reduce
|
||||
}
|
||||
|
||||
// nowFunc makes [time.Now] mockable
|
||||
type nowFunc func() time.Time
|
||||
|
||||
type Projection interface {
|
||||
Name() string
|
||||
Reducers() []AggregateReducer
|
||||
}
|
||||
|
||||
func NewHandler(
|
||||
ctx context.Context,
|
||||
config *Config,
|
||||
projection Projection,
|
||||
) *Handler {
|
||||
aggregates := make(map[eventstore.AggregateType][]eventstore.EventType, len(projection.Reducers()))
|
||||
for _, reducer := range projection.Reducers() {
|
||||
eventTypes := make([]eventstore.EventType, len(reducer.EventReducers))
|
||||
for i, eventReducer := range reducer.EventReducers {
|
||||
eventTypes[i] = eventReducer.Event
|
||||
}
|
||||
if _, ok := aggregates[reducer.Aggregate]; ok {
|
||||
aggregates[reducer.Aggregate] = append(aggregates[reducer.Aggregate], eventTypes...)
|
||||
continue
|
||||
}
|
||||
aggregates[reducer.Aggregate] = eventTypes
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
projection: projection,
|
||||
client: config.Client,
|
||||
es: config.Eventstore,
|
||||
bulkLimit: config.BulkLimit,
|
||||
eventTypes: aggregates,
|
||||
requeueEvery: config.RequeueEvery,
|
||||
handleActiveInstances: config.HandleActiveInstances,
|
||||
now: time.Now,
|
||||
maxFailureCount: config.MaxFailureCount,
|
||||
retryFailedAfter: config.RetryFailedAfter,
|
||||
triggeredInstancesSync: sync.Map{},
|
||||
triggerWithoutEvents: config.TriggerWithoutEvents,
|
||||
txDuration: config.TransactionDuration,
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func (h *Handler) Start(ctx context.Context) {
|
||||
go h.schedule(ctx)
|
||||
if h.triggerWithoutEvents != nil {
|
||||
return
|
||||
}
|
||||
go h.subscribe(ctx)
|
||||
}
|
||||
|
||||
func (h *Handler) schedule(ctx context.Context) {
|
||||
// if there was no run before trigger instantly
|
||||
t := time.NewTimer(0)
|
||||
didInitialize := h.didProjectionInitialize(ctx)
|
||||
if didInitialize {
|
||||
t.Reset(h.requeueEvery)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return
|
||||
case <-t.C:
|
||||
instances, err := h.queryInstances(ctx, didInitialize)
|
||||
h.log().OnError(err).Debug("unable to query instances")
|
||||
|
||||
var instanceFailed bool
|
||||
scheduledCtx := call.WithTimestamp(ctx)
|
||||
for _, instance := range instances {
|
||||
instanceCtx := authz.WithInstanceID(scheduledCtx, instance)
|
||||
|
||||
// simple implementation of do while
|
||||
_, err = h.Trigger(instanceCtx)
|
||||
instanceFailed = instanceFailed || err != nil
|
||||
h.log().WithField("instance", instance).OnError(err).Info("scheduled trigger failed")
|
||||
// retry if trigger failed
|
||||
for ; err != nil; _, err = h.Trigger(instanceCtx) {
|
||||
time.Sleep(h.retryFailedAfter)
|
||||
instanceFailed = instanceFailed || err != nil
|
||||
h.log().WithField("instance", instance).OnError(err).Info("scheduled trigger failed")
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !didInitialize && !instanceFailed {
|
||||
err = h.setSucceededOnce(ctx)
|
||||
h.log().OnError(err).Debug("unable to set succeeded once")
|
||||
didInitialize = err == nil
|
||||
}
|
||||
t.Reset(h.requeueEvery)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) subscribe(ctx context.Context) {
|
||||
queue := make(chan eventstore.Event, 100)
|
||||
subscription := eventstore.SubscribeEventTypes(queue, h.eventTypes)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
subscription.Unsubscribe()
|
||||
h.log().Debug("shutdown")
|
||||
return
|
||||
case event := <-queue:
|
||||
events := checkAdditionalEvents(queue, event)
|
||||
solvedInstances := make([]string, 0, len(events))
|
||||
queueCtx := call.WithTimestamp(ctx)
|
||||
for _, e := range events {
|
||||
if instanceSolved(solvedInstances, e.Aggregate().InstanceID) {
|
||||
continue
|
||||
}
|
||||
queueCtx = authz.WithInstanceID(queueCtx, e.Aggregate().InstanceID)
|
||||
_, err := h.Trigger(queueCtx)
|
||||
h.log().OnError(err).Debug("trigger of queued event failed")
|
||||
if err == nil {
|
||||
solvedInstances = append(solvedInstances, e.Aggregate().InstanceID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func instanceSolved(solvedInstances []string, instanceID string) bool {
|
||||
for _, solvedInstance := range solvedInstances {
|
||||
if solvedInstance == instanceID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
|
||||
events := make([]eventstore.Event, 1)
|
||||
events[0] = event
|
||||
for {
|
||||
wait := time.NewTimer(1 * time.Millisecond)
|
||||
select {
|
||||
case event := <-eventQueue:
|
||||
events = append(events, event)
|
||||
case <-wait.C:
|
||||
return events
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) queryInstances(ctx context.Context, didInitialize bool) ([]string, error) {
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).
|
||||
AwaitOpenTransactions().
|
||||
AllowTimeTravel().
|
||||
ExcludedInstanceID("")
|
||||
if didInitialize {
|
||||
query = query.
|
||||
CreationDateAfter(h.now().Add(-1 * h.handleActiveInstances))
|
||||
}
|
||||
return h.es.InstanceIDs(ctx, h.requeueEvery, !didInitialize, query)
|
||||
}
|
||||
|
||||
type triggerConfig struct {
|
||||
awaitRunning bool
|
||||
}
|
||||
|
||||
type triggerOpt func(conf *triggerConfig)
|
||||
|
||||
func WithAwaitRunning() triggerOpt {
|
||||
return func(conf *triggerConfig) {
|
||||
conf.awaitRunning = true
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Trigger(ctx context.Context, opts ...triggerOpt) (_ context.Context, err error) {
|
||||
if authz.GetInstance(ctx).InstanceID() != "" {
|
||||
var span *tracing.Span
|
||||
ctx, span = tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
}
|
||||
|
||||
config := new(triggerConfig)
|
||||
for _, opt := range opts {
|
||||
opt(config)
|
||||
}
|
||||
|
||||
cancel := h.lockInstance(ctx, config)
|
||||
if cancel == nil {
|
||||
return call.ResetTimestamp(ctx), nil
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
for i := 0; ; i++ {
|
||||
additionalIteration, err := h.processEvents(ctx, config)
|
||||
h.log().OnError(err).Warn("process events failed")
|
||||
h.log().WithField("iteration", i).Debug("trigger iteration")
|
||||
if !additionalIteration || err != nil {
|
||||
return call.ResetTimestamp(ctx), err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// lockInstances tries to lock the instance.
|
||||
// If the instance is already locked from another process no cancel function is returned
|
||||
// the instance can be skipped then
|
||||
// If the instance is locked, an unlock deferable function is returned
|
||||
func (h *Handler) lockInstance(ctx context.Context, config *triggerConfig) func() {
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
|
||||
// Check that the instance has a mutex to lock
|
||||
instanceMu, _ := h.triggeredInstancesSync.LoadOrStore(instanceID, new(sync.Mutex))
|
||||
unlock := func() {
|
||||
instanceMu.(*sync.Mutex).Unlock()
|
||||
}
|
||||
if !instanceMu.(*sync.Mutex).TryLock() {
|
||||
instanceMu.(*sync.Mutex).Lock()
|
||||
if config.awaitRunning {
|
||||
return unlock
|
||||
}
|
||||
defer unlock()
|
||||
return nil
|
||||
}
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (additionalIteration bool, err error) {
|
||||
defer func() {
|
||||
pgErr := new(pgconn.PgError)
|
||||
if errors.As(err, &pgErr) {
|
||||
// error returned if the row is currently locked by another connection
|
||||
if pgErr.Code == "55P03" {
|
||||
h.log().Debug("state already locked")
|
||||
err = nil
|
||||
additionalIteration = false
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if h.txDuration > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, h.txDuration)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
tx, err := h.client.Begin()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
h.log().OnError(rollbackErr).Debug("unable to rollback tx")
|
||||
return
|
||||
}
|
||||
err = tx.Commit()
|
||||
}()
|
||||
|
||||
currentState, err := h.currentState(ctx, tx, config)
|
||||
if err != nil {
|
||||
if errors.Is(err, errJustUpdated) {
|
||||
return false, nil
|
||||
}
|
||||
return additionalIteration, err
|
||||
}
|
||||
|
||||
var statements []*Statement
|
||||
statements, additionalIteration, err = h.generateStatements(ctx, tx, currentState)
|
||||
if err != nil || len(statements) == 0 {
|
||||
return additionalIteration, err
|
||||
}
|
||||
|
||||
lastProcessedIndex, err := h.executeStatements(ctx, tx, currentState, statements)
|
||||
if lastProcessedIndex < 0 {
|
||||
return false, err
|
||||
}
|
||||
|
||||
currentState.position = statements[lastProcessedIndex].Position
|
||||
currentState.aggregateID = statements[lastProcessedIndex].AggregateID
|
||||
currentState.aggregateType = statements[lastProcessedIndex].AggregateType
|
||||
currentState.sequence = statements[lastProcessedIndex].Sequence
|
||||
currentState.eventTimestamp = statements[lastProcessedIndex].CreationDate
|
||||
err = h.setState(tx, currentState)
|
||||
|
||||
return additionalIteration, err
|
||||
}
|
||||
|
||||
func (h *Handler) generateStatements(ctx context.Context, tx *sql.Tx, currentState *state) (_ []*Statement, additionalIteration bool, err error) {
|
||||
if h.triggerWithoutEvents != nil {
|
||||
stmt, err := h.triggerWithoutEvents(pseudo.NewScheduledEvent(ctx, time.Now(), currentState.instanceID))
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return []*Statement{stmt}, false, nil
|
||||
}
|
||||
|
||||
events, err := h.es.Filter(ctx, h.eventQuery(currentState))
|
||||
if err != nil {
|
||||
h.log().WithError(err).Debug("filter eventstore failed")
|
||||
return nil, false, err
|
||||
}
|
||||
eventAmount := len(events)
|
||||
events = skipPreviouslyReduced(events, currentState)
|
||||
|
||||
if len(events) == 0 {
|
||||
h.updateLastUpdated(ctx, tx, currentState)
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
statements, err := h.eventsToStatements(tx, events, currentState)
|
||||
if len(statements) == 0 {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
additionalIteration = eventAmount == int(h.bulkLimit)
|
||||
if len(statements) < len(events) {
|
||||
// retry imediatly if statements failed
|
||||
additionalIteration = true
|
||||
}
|
||||
|
||||
return statements, additionalIteration, nil
|
||||
}
|
||||
|
||||
func skipPreviouslyReduced(events []eventstore.Event, currentState *state) []eventstore.Event {
|
||||
for i, event := range events {
|
||||
if event.Position() == currentState.position &&
|
||||
event.Aggregate().ID == currentState.aggregateID &&
|
||||
event.Aggregate().Type == currentState.aggregateType &&
|
||||
event.Sequence() == currentState.sequence {
|
||||
return events[i+1:]
|
||||
}
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, currentState *state, statements []*Statement) (lastProcessedIndex int, err error) {
|
||||
lastProcessedIndex = -1
|
||||
|
||||
for i, statement := range statements {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break
|
||||
default:
|
||||
err := h.executeStatement(ctx, tx, currentState, statement)
|
||||
if err != nil {
|
||||
return lastProcessedIndex, err
|
||||
}
|
||||
lastProcessedIndex = i
|
||||
}
|
||||
}
|
||||
return lastProcessedIndex, nil
|
||||
}
|
||||
|
||||
func (h *Handler) executeStatement(ctx context.Context, tx *sql.Tx, currentState *state, statement *Statement) (err error) {
|
||||
if statement.Execute == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = tx.Exec("SAVEPOINT exec")
|
||||
if err != nil {
|
||||
h.log().WithError(err).Debug("create savepoint failed")
|
||||
return err
|
||||
}
|
||||
var shouldContinue bool
|
||||
defer func() {
|
||||
_, err = tx.Exec("RELEASE SAVEPOINT exec")
|
||||
}()
|
||||
|
||||
if err = statement.Execute(tx, h.projection.Name()); err != nil {
|
||||
h.log().WithError(err).Error("statement execution failed")
|
||||
|
||||
shouldContinue = h.handleFailedStmt(tx, currentState, failureFromStatement(statement, err))
|
||||
if shouldContinue {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) eventQuery(currentState *state) *eventstore.SearchQueryBuilder {
|
||||
builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AwaitOpenTransactions().
|
||||
Limit(uint64(h.bulkLimit)).
|
||||
AllowTimeTravel().
|
||||
OrderAsc().
|
||||
InstanceID(currentState.instanceID)
|
||||
|
||||
if currentState.position > 0 {
|
||||
builder = builder.PositionAfter(math.Float64frombits(math.Float64bits(currentState.position) - 10))
|
||||
}
|
||||
|
||||
for aggregateType, eventTypes := range h.eventTypes {
|
||||
query := builder.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
EventTypes(eventTypes...)
|
||||
|
||||
builder = query.Builder()
|
||||
}
|
||||
|
||||
return builder
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package crdb
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,19 +9,19 @@ import (
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
columns []*Column
|
||||
columns []*InitColumn
|
||||
primaryKey PrimaryKey
|
||||
indices []*Index
|
||||
constraints []*Constraint
|
||||
foreignKeys []*ForeignKey
|
||||
}
|
||||
|
||||
func NewTable(columns []*Column, key PrimaryKey, opts ...TableOption) *Table {
|
||||
func NewTable(columns []*InitColumn, key PrimaryKey, opts ...TableOption) *Table {
|
||||
t := &Table{
|
||||
columns: columns,
|
||||
primaryKey: key,
|
||||
@@ -37,7 +37,7 @@ type SuffixedTable struct {
|
||||
suffix string
|
||||
}
|
||||
|
||||
func NewSuffixedTable(columns []*Column, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
|
||||
func NewSuffixedTable(columns []*InitColumn, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
|
||||
return &SuffixedTable{
|
||||
Table: *NewTable(columns, key, opts...),
|
||||
suffix: suffix,
|
||||
@@ -64,7 +64,7 @@ func WithForeignKey(key *ForeignKey) TableOption {
|
||||
}
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
type InitColumn struct {
|
||||
Name string
|
||||
Type ColumnType
|
||||
nullable bool
|
||||
@@ -72,10 +72,10 @@ type Column struct {
|
||||
deleteCascade string
|
||||
}
|
||||
|
||||
type ColumnOption func(*Column)
|
||||
type ColumnOption func(*InitColumn)
|
||||
|
||||
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *Column {
|
||||
column := &Column{
|
||||
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *InitColumn {
|
||||
column := &InitColumn{
|
||||
Name: name,
|
||||
Type: columnType,
|
||||
nullable: false,
|
||||
@@ -88,19 +88,19 @@ func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *Column
|
||||
}
|
||||
|
||||
func Nullable() ColumnOption {
|
||||
return func(c *Column) {
|
||||
return func(c *InitColumn) {
|
||||
c.nullable = true
|
||||
}
|
||||
}
|
||||
|
||||
func Default(value interface{}) ColumnOption {
|
||||
return func(c *Column) {
|
||||
return func(c *InitColumn) {
|
||||
c.defaultValue = value
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteCascade(column string) ColumnOption {
|
||||
return func(c *Column) {
|
||||
return func(c *InitColumn) {
|
||||
c.deleteCascade = column
|
||||
}
|
||||
}
|
||||
@@ -128,9 +128,8 @@ const (
|
||||
|
||||
func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
|
||||
i := &Index{
|
||||
Name: name,
|
||||
Columns: columns,
|
||||
bucketCount: 0,
|
||||
Name: name,
|
||||
Columns: columns,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(i)
|
||||
@@ -139,16 +138,16 @@ func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
|
||||
}
|
||||
|
||||
type Index struct {
|
||||
Name string
|
||||
Columns []string
|
||||
bucketCount uint16
|
||||
Name string
|
||||
Columns []string
|
||||
includes []string
|
||||
}
|
||||
|
||||
type indexOpts func(*Index)
|
||||
|
||||
func Hash(bucketsCount uint16) indexOpts {
|
||||
func WithInclude(columns ...string) indexOpts {
|
||||
return func(i *Index) {
|
||||
i.bucketCount = bucketsCount
|
||||
i.includes = columns
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,25 +185,28 @@ type ForeignKey struct {
|
||||
RefColumns []string
|
||||
}
|
||||
|
||||
// Init implements handler.Init
|
||||
func (h *StatementHandler) Init(ctx context.Context) error {
|
||||
check := h.initCheck
|
||||
if check == nil || check.IsNoop() {
|
||||
type initializer interface {
|
||||
Init() *handler.Check
|
||||
}
|
||||
|
||||
func (h *Handler) Init(ctx context.Context) error {
|
||||
check, ok := h.projection.(initializer)
|
||||
if !ok || check.Init().IsNoop() {
|
||||
return nil
|
||||
}
|
||||
tx, err := h.client.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
|
||||
return errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
|
||||
}
|
||||
for i, execute := range check.Executes {
|
||||
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("executing check")
|
||||
next, err := execute(h.client, h.ProjectionName)
|
||||
for i, execute := range check.Init().Executes {
|
||||
logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("executing check")
|
||||
next, err := execute(tx, h.projection.Name())
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logging.OnError(tx.Rollback()).Debug("unable to rollback")
|
||||
return err
|
||||
}
|
||||
if !next {
|
||||
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("projection set up")
|
||||
logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("projection set up")
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -272,15 +274,15 @@ func execNextIfExists(config execConfig, q query, opts []execOption, executeNext
|
||||
}
|
||||
|
||||
func isErrAlreadyExists(err error) bool {
|
||||
caosErr := &caos_errs.CaosError{}
|
||||
caosErr := &errs.CaosError{}
|
||||
if !errors.As(err, &caosErr) {
|
||||
return false
|
||||
}
|
||||
sqlErr, ok := caosErr.GetParent().(*pgconn.PgError)
|
||||
if !ok {
|
||||
return false
|
||||
pgErr := new(pgconn.PgError)
|
||||
if errors.As(caosErr.Parent, &pgErr) {
|
||||
return pgErr.Code == "42P07"
|
||||
}
|
||||
return sqlErr.Code == "42P07"
|
||||
return false
|
||||
}
|
||||
|
||||
func createTableStatement(table *Table, tableName string, suffix string) string {
|
||||
@@ -330,11 +332,10 @@ func createIndexStatement(index *Index, tableName string) string {
|
||||
tableName,
|
||||
strings.Join(index.Columns, ","),
|
||||
)
|
||||
if index.bucketCount == 0 {
|
||||
return stmt + ";"
|
||||
if len(index.includes) > 0 {
|
||||
stmt += " INCLUDE (" + strings.Join(index.includes, ", ") + ")"
|
||||
}
|
||||
return fmt.Sprintf("SET experimental_enable_hash_sharded_indexes=on; %s USING HASH WITH BUCKET_COUNT = %d;",
|
||||
stmt, index.bucketCount)
|
||||
return stmt + ";"
|
||||
}
|
||||
|
||||
func foreignKeyName(name, tableName, suffix string) string {
|
||||
@@ -355,7 +356,7 @@ func tableNameWithoutSchema(name string) string {
|
||||
return name[strings.LastIndex(name, ".")+1:]
|
||||
}
|
||||
|
||||
func createColumnsStatement(cols []*Column, tableName string) string {
|
||||
func createColumnsStatement(cols []*InitColumn, tableName string) string {
|
||||
columns := make([]string, len(cols))
|
||||
for i, col := range cols {
|
||||
column := col.Name + " " + columnType(col.Type)
|
23
internal/eventstore/handler/v2/log.go
Normal file
23
internal/eventstore/handler/v2/log.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func (h *Handler) log() *logging.Entry {
|
||||
return logging.WithFields("projection", h.projection.Name())
|
||||
}
|
||||
|
||||
func (h *Handler) logFailure(fail *failure) *logging.Entry {
|
||||
return h.log().WithField("sequence", fail.sequence).
|
||||
WithField("instance", fail.instance).
|
||||
WithField("aggregate", fail.aggregateID)
|
||||
}
|
||||
|
||||
func (h *Handler) logEvent(event eventstore.Event) *logging.Entry {
|
||||
return h.log().WithField("sequence", event.Sequence()).
|
||||
WithField("instance", event.Aggregate().InstanceID).
|
||||
WithField("aggregate", event.Aggregate().Type)
|
||||
}
|
18
internal/eventstore/handler/v2/mock_test.go
Normal file
18
internal/eventstore/handler/v2/mock_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package handler
|
||||
|
||||
var _ Projection = (*projection)(nil)
|
||||
|
||||
type projection struct {
|
||||
name string
|
||||
reducers []AggregateReducer
|
||||
}
|
||||
|
||||
// Name implements Projection
|
||||
func (p *projection) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// Reducers implements Projection
|
||||
func (p *projection) Reducers() []AggregateReducer {
|
||||
return p.reducers
|
||||
}
|
21
internal/eventstore/handler/v2/reduce.go
Normal file
21
internal/eventstore/handler/v2/reduce.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package handler
|
||||
|
||||
import "github.com/zitadel/zitadel/internal/eventstore"
|
||||
|
||||
// EventReducer represents the required data
|
||||
// to work with events
|
||||
type EventReducer struct {
|
||||
Event eventstore.EventType
|
||||
Reduce Reduce
|
||||
}
|
||||
|
||||
// Reduce reduces the given event to a statement
|
||||
// which is used to update the projection
|
||||
type Reduce func(eventstore.Event) (*Statement, error)
|
||||
|
||||
// EventReducer represents the required data
|
||||
// to work with aggregates
|
||||
type AggregateReducer struct {
|
||||
Aggregate eventstore.AggregateType
|
||||
EventReducers []EventReducer
|
||||
}
|
119
internal/eventstore/handler/v2/state.go
Normal file
119
internal/eventstore/handler/v2/state.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
type state struct {
|
||||
instanceID string
|
||||
position float64
|
||||
eventTimestamp time.Time
|
||||
aggregateType eventstore.AggregateType
|
||||
aggregateID string
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed state_get.sql
|
||||
currentStateStmt string
|
||||
//go:embed state_get_await.sql
|
||||
currentStateAwaitStmt string
|
||||
//go:embed state_set.sql
|
||||
updateStateStmt string
|
||||
//go:embed state_lock.sql
|
||||
lockStateStmt string
|
||||
//go:embed state_set_last_run.sql
|
||||
updateStateLastRunStmt string
|
||||
|
||||
errJustUpdated = errors.New("projection was just updated")
|
||||
)
|
||||
|
||||
func (h *Handler) currentState(ctx context.Context, tx *sql.Tx, config *triggerConfig) (currentState *state, err error) {
|
||||
currentState = &state{
|
||||
instanceID: authz.GetInstance(ctx).InstanceID(),
|
||||
}
|
||||
|
||||
var (
|
||||
aggregateID = new(sql.NullString)
|
||||
aggregateType = new(sql.NullString)
|
||||
sequence = new(sql.NullInt64)
|
||||
timestamp = new(sql.NullTime)
|
||||
position = new(sql.NullFloat64)
|
||||
)
|
||||
|
||||
stateQuery := currentStateStmt
|
||||
if config.awaitRunning {
|
||||
stateQuery = currentStateAwaitStmt
|
||||
}
|
||||
|
||||
row := tx.QueryRow(stateQuery, currentState.instanceID, h.projection.Name())
|
||||
err = row.Scan(
|
||||
aggregateID,
|
||||
aggregateType,
|
||||
sequence,
|
||||
timestamp,
|
||||
position,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = h.lockState(tx, currentState.instanceID)
|
||||
}
|
||||
if err != nil {
|
||||
h.log().WithError(err).Debug("unable to query current state")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentState.aggregateID = aggregateID.String
|
||||
currentState.aggregateType = eventstore.AggregateType(aggregateType.String)
|
||||
currentState.sequence = uint64(sequence.Int64)
|
||||
currentState.eventTimestamp = timestamp.Time
|
||||
currentState.position = position.Float64
|
||||
return currentState, nil
|
||||
}
|
||||
|
||||
func (h *Handler) setState(tx *sql.Tx, updatedState *state) error {
|
||||
res, err := tx.Exec(updateStateStmt,
|
||||
h.projection.Name(),
|
||||
updatedState.instanceID,
|
||||
updatedState.aggregateID,
|
||||
updatedState.aggregateType,
|
||||
updatedState.sequence,
|
||||
updatedState.eventTimestamp,
|
||||
updatedState.position,
|
||||
)
|
||||
if err != nil {
|
||||
h.log().WithError(err).Debug("unable to update state")
|
||||
return err
|
||||
}
|
||||
if affected, err := res.RowsAffected(); affected == 0 {
|
||||
h.log().OnError(err).Error("unable to check if states are updated")
|
||||
return errs.ThrowInternal(err, "V2-FGEKi", "unable to update state")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) updateLastUpdated(ctx context.Context, tx *sql.Tx, updatedState *state) {
|
||||
_, err := tx.ExecContext(ctx, updateStateLastRunStmt, h.projection.Name(), updatedState.instanceID)
|
||||
h.log().OnError(err).Debug("unable to update last updated")
|
||||
}
|
||||
|
||||
func (h *Handler) lockState(tx *sql.Tx, instanceID string) error {
|
||||
res, err := tx.Exec(lockStateStmt,
|
||||
h.projection.Name(),
|
||||
instanceID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected, err := res.RowsAffected(); affected == 0 || err != nil {
|
||||
return errs.ThrowInternal(err, "V2-lpiK0", "projection already locked")
|
||||
}
|
||||
return nil
|
||||
}
|
12
internal/eventstore/handler/v2/state_get.sql
Normal file
12
internal/eventstore/handler/v2/state_get.sql
Normal file
@@ -0,0 +1,12 @@
|
||||
SELECT
|
||||
aggregate_id
|
||||
, aggregate_type
|
||||
, "sequence"
|
||||
, event_date
|
||||
, "position"
|
||||
FROM
|
||||
projections.current_states
|
||||
WHERE
|
||||
instance_id = $1
|
||||
AND projection_name = $2
|
||||
FOR UPDATE NOWAIT;
|
12
internal/eventstore/handler/v2/state_get_await.sql
Normal file
12
internal/eventstore/handler/v2/state_get_await.sql
Normal file
@@ -0,0 +1,12 @@
|
||||
SELECT
|
||||
aggregate_id
|
||||
, aggregate_type
|
||||
, "sequence"
|
||||
, event_date
|
||||
, "position"
|
||||
FROM
|
||||
projections.current_states
|
||||
WHERE
|
||||
instance_id = $1
|
||||
AND projection_name = $2
|
||||
FOR UPDATE;
|
9
internal/eventstore/handler/v2/state_lock.sql
Normal file
9
internal/eventstore/handler/v2/state_lock.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
INSERT INTO projections.current_states (
|
||||
projection_name
|
||||
, instance_id
|
||||
, last_updated
|
||||
) VALUES (
|
||||
$1
|
||||
, $2
|
||||
, now()
|
||||
) ON CONFLICT DO NOTHING;
|
29
internal/eventstore/handler/v2/state_set.sql
Normal file
29
internal/eventstore/handler/v2/state_set.sql
Normal file
@@ -0,0 +1,29 @@
|
||||
INSERT INTO projections.current_states (
|
||||
projection_name
|
||||
, instance_id
|
||||
, aggregate_id
|
||||
, aggregate_type
|
||||
, "sequence"
|
||||
, event_date
|
||||
, "position"
|
||||
, last_updated
|
||||
) VALUES (
|
||||
$1
|
||||
, $2
|
||||
, $3
|
||||
, $4
|
||||
, $5
|
||||
, $6
|
||||
, $7
|
||||
, now()
|
||||
) ON CONFLICT (
|
||||
projection_name
|
||||
, instance_id
|
||||
) DO UPDATE SET
|
||||
aggregate_id = $3
|
||||
, aggregate_type = $4
|
||||
, "sequence" = $5
|
||||
, event_date = $6
|
||||
, "position" = $7
|
||||
, last_updated = statement_timestamp()
|
||||
;
|
2
internal/eventstore/handler/v2/state_set_last_run.sql
Normal file
2
internal/eventstore/handler/v2/state_set_last_run.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
UPDATE projections.current_states SET last_updated = now() WHERE projection_name = $1 AND instance_id = $2;
|
||||
|
447
internal/eventstore/handler/v2/state_test.go
Normal file
447
internal/eventstore/handler/v2/state_test.go
Normal file
@@ -0,0 +1,447 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database/mock"
|
||||
errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
func TestHandler_lockState(t *testing.T) {
|
||||
type fields struct {
|
||||
projection Projection
|
||||
mock *mock.SQLMock
|
||||
}
|
||||
type args struct {
|
||||
instanceID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
isErr func(t *testing.T, err error)
|
||||
}{
|
||||
{
|
||||
name: "tx closed",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "projection",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExcpectExec(
|
||||
lockStateStmt,
|
||||
mock.WithExecArgs(
|
||||
"projection",
|
||||
"instance",
|
||||
),
|
||||
mock.WithExecErr(sql.ErrTxDone),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
instanceID: "instance",
|
||||
},
|
||||
isErr: func(t *testing.T, err error) {
|
||||
if !errors.Is(err, sql.ErrTxDone) {
|
||||
t.Errorf("unexpected error, want: %v got: %v", sql.ErrTxDone, err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no rows affeced",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "projection",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExcpectExec(
|
||||
lockStateStmt,
|
||||
mock.WithExecArgs(
|
||||
"projection",
|
||||
"instance",
|
||||
),
|
||||
mock.WithExecNoRowsAffected(),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
instanceID: "instance",
|
||||
},
|
||||
isErr: func(t *testing.T, err error) {
|
||||
if !errors.Is(err, errs.ThrowInternal(nil, "V2-lpiK0", "")) {
|
||||
t.Errorf("unexpected error: want internal (V2lpiK0), got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rows affected",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "projection",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExcpectExec(
|
||||
lockStateStmt,
|
||||
mock.WithExecArgs(
|
||||
"projection",
|
||||
"instance",
|
||||
),
|
||||
mock.WithExecRowsAffected(1),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
instanceID: "instance",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.isErr == nil {
|
||||
tt.isErr = func(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Error("expected no error got:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
projection: tt.fields.projection,
|
||||
}
|
||||
|
||||
tx, err := tt.fields.mock.DB.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to begin transaction: %v", err)
|
||||
}
|
||||
|
||||
err = h.lockState(tx, tt.args.instanceID)
|
||||
tt.isErr(t, err)
|
||||
|
||||
tt.fields.mock.Assert(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_updateLastUpdated(t *testing.T) {
|
||||
type fields struct {
|
||||
projection Projection
|
||||
mock *mock.SQLMock
|
||||
}
|
||||
type args struct {
|
||||
updatedState *state
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
isErr func(t *testing.T, err error)
|
||||
}{
|
||||
{
|
||||
name: "update fails",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "instance",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExcpectExec(updateStateStmt,
|
||||
mock.WithExecErr(sql.ErrTxDone),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
updatedState: &state{
|
||||
instanceID: "instance",
|
||||
eventTimestamp: time.Now(),
|
||||
position: 42,
|
||||
},
|
||||
},
|
||||
isErr: func(t *testing.T, err error) {
|
||||
if !errors.Is(err, sql.ErrTxDone) {
|
||||
t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no rows affected",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "instance",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExcpectExec(updateStateStmt,
|
||||
mock.WithExecNoRowsAffected(),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
updatedState: &state{
|
||||
instanceID: "instance",
|
||||
eventTimestamp: time.Now(),
|
||||
position: 42,
|
||||
},
|
||||
},
|
||||
isErr: func(t *testing.T, err error) {
|
||||
if !errors.Is(err, errs.ThrowInternal(nil, "V2-FGEKi", "")) {
|
||||
t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "projection",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExcpectExec(updateStateStmt,
|
||||
mock.WithExecArgs(
|
||||
"projection",
|
||||
"instance",
|
||||
"aggregate id",
|
||||
"aggregate type",
|
||||
uint64(42),
|
||||
mock.AnyType[time.Time]{},
|
||||
float64(42),
|
||||
),
|
||||
mock.WithExecRowsAffected(1),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
updatedState: &state{
|
||||
instanceID: "instance",
|
||||
eventTimestamp: time.Now(),
|
||||
position: 42,
|
||||
aggregateType: "aggregate type",
|
||||
aggregateID: "aggregate id",
|
||||
sequence: 42,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.isErr == nil {
|
||||
tt.isErr = func(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Error("expected no error got:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tx, err := tt.fields.mock.DB.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to begin transaction: %v", err)
|
||||
}
|
||||
|
||||
h := &Handler{
|
||||
projection: tt.fields.projection,
|
||||
}
|
||||
err = h.setState(tx, tt.args.updatedState)
|
||||
|
||||
tt.isErr(t, err)
|
||||
tt.fields.mock.Assert(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_currentState(t *testing.T) {
|
||||
testTime := time.Now()
|
||||
type fields struct {
|
||||
projection Projection
|
||||
mock *mock.SQLMock
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
}
|
||||
type want struct {
|
||||
currentState *state
|
||||
isErr func(t *testing.T, err error)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "connection done",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "projection",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExpectQuery(currentStateStmt,
|
||||
mock.WithQueryArgs(
|
||||
"instance",
|
||||
"projection",
|
||||
),
|
||||
mock.WithQueryErr(sql.ErrConnDone),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance"),
|
||||
},
|
||||
want: want{
|
||||
isErr: func(t *testing.T, err error) {
|
||||
if !errors.Is(err, sql.ErrConnDone) {
|
||||
t.Errorf("unexpected error, want: %v, got: %v", sql.ErrConnDone, err)
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no row but lock err",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "projection",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExpectQuery(currentStateStmt,
|
||||
mock.WithQueryArgs(
|
||||
"instance",
|
||||
"projection",
|
||||
),
|
||||
mock.WithQueryErr(sql.ErrNoRows),
|
||||
),
|
||||
mock.ExcpectExec(lockStateStmt,
|
||||
mock.WithExecArgs(
|
||||
"projection",
|
||||
"instance",
|
||||
),
|
||||
mock.WithExecErr(sql.ErrTxDone),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance"),
|
||||
},
|
||||
want: want{
|
||||
isErr: func(t *testing.T, err error) {
|
||||
if !errors.Is(err, sql.ErrTxDone) {
|
||||
t.Errorf("unexpected error, want: %v, got: %v", sql.ErrTxDone, err)
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "state locked",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "projection",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExpectQuery(currentStateStmt,
|
||||
mock.WithQueryArgs(
|
||||
"instance",
|
||||
"projection",
|
||||
),
|
||||
mock.WithQueryErr(&pgconn.PgError{Code: "55P03"}),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance"),
|
||||
},
|
||||
want: want{
|
||||
isErr: func(t *testing.T, err error) {
|
||||
pgErr := new(pgconn.PgError)
|
||||
if !errors.As(err, &pgErr) {
|
||||
t.Errorf("error should be PgErr but was %T", err)
|
||||
return
|
||||
}
|
||||
if pgErr.Code != "55P03" {
|
||||
t.Errorf("expected code 55P03 got: %s", pgErr.Code)
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
fields: fields{
|
||||
projection: &projection{
|
||||
name: "projection",
|
||||
},
|
||||
mock: mock.NewSQLMock(t,
|
||||
mock.ExpectBegin(nil),
|
||||
mock.ExpectQuery(currentStateStmt,
|
||||
mock.WithQueryArgs(
|
||||
"instance",
|
||||
"projection",
|
||||
),
|
||||
mock.WithQueryResult(
|
||||
[]string{"aggregate_id", "aggregate_type", "event_sequence", "event_date", "position"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"aggregate id",
|
||||
"aggregate type",
|
||||
int64(42),
|
||||
testTime,
|
||||
float64(42),
|
||||
},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance"),
|
||||
},
|
||||
want: want{
|
||||
currentState: &state{
|
||||
instanceID: "instance",
|
||||
eventTimestamp: testTime,
|
||||
position: 42,
|
||||
aggregateType: "aggregate type",
|
||||
aggregateID: "aggregate id",
|
||||
sequence: 42,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.want.isErr == nil {
|
||||
tt.want.isErr = func(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Error("expected no error got:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
projection: tt.fields.projection,
|
||||
}
|
||||
|
||||
tx, err := tt.fields.mock.DB.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to begin transaction: %v", err)
|
||||
}
|
||||
|
||||
gotCurrentState, err := h.currentState(tt.args.ctx, tx, new(triggerConfig))
|
||||
|
||||
tt.want.isErr(t, err)
|
||||
if !reflect.DeepEqual(gotCurrentState, tt.want.currentState) {
|
||||
t.Errorf("Handler.currentState() gotCurrentState = %v, want %v", gotCurrentState, tt.want.currentState)
|
||||
}
|
||||
tt.fields.mock.Assert(t)
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,41 +1,89 @@
|
||||
package crdb
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"encoding/json"
|
||||
errs "errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
zitadel_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
type execOption func(*execConfig)
|
||||
type execConfig struct {
|
||||
tableName string
|
||||
|
||||
args []interface{}
|
||||
err error
|
||||
ignoreNotFound bool
|
||||
func (h *Handler) eventsToStatements(tx *sql.Tx, events []eventstore.Event, currentState *state) (statements []*Statement, err error) {
|
||||
statements = make([]*Statement, 0, len(events))
|
||||
for _, event := range events {
|
||||
statement, err := h.reduce(event)
|
||||
if err != nil {
|
||||
h.logEvent(event).WithError(err).Error("reduce failed")
|
||||
if shouldContinue := h.handleFailedStmt(tx, currentState, failureFromEvent(event, err)); shouldContinue {
|
||||
continue
|
||||
}
|
||||
return statements, err
|
||||
}
|
||||
statements = append(statements, statement)
|
||||
}
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
func (h *Handler) reduce(event eventstore.Event) (*Statement, error) {
|
||||
for _, reducer := range h.projection.Reducers() {
|
||||
if reducer.Aggregate != event.Aggregate().Type {
|
||||
continue
|
||||
}
|
||||
for _, reduce := range reducer.EventReducers {
|
||||
if reduce.Event != event.Type() {
|
||||
continue
|
||||
}
|
||||
return reduce.Reduce(event)
|
||||
}
|
||||
}
|
||||
return NewNoOpStatement(event), nil
|
||||
}
|
||||
|
||||
type Statement struct {
|
||||
AggregateType eventstore.AggregateType
|
||||
AggregateID string
|
||||
Sequence uint64
|
||||
Position float64
|
||||
CreationDate time.Time
|
||||
InstanceID string
|
||||
|
||||
Execute Exec
|
||||
}
|
||||
|
||||
type Exec func(ex Executer, projectionName string) error
|
||||
|
||||
func WithTableSuffix(name string) func(*execConfig) {
|
||||
return func(o *execConfig) {
|
||||
o.tableName += "_" + name
|
||||
}
|
||||
}
|
||||
|
||||
func WithIgnoreNotFound() func(*execConfig) {
|
||||
return func(o *execConfig) {
|
||||
o.ignoreNotFound = true
|
||||
var (
|
||||
ErrNoProjection = errs.New("no projection")
|
||||
ErrNoValues = errs.New("no values")
|
||||
ErrNoCondition = errs.New("no condition")
|
||||
)
|
||||
|
||||
func NewStatement(event eventstore.Event, e Exec) *Statement {
|
||||
return &Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
Position: event.Position(),
|
||||
AggregateID: event.Aggregate().ID,
|
||||
CreationDate: event.CreatedAt(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: e,
|
||||
}
|
||||
}
|
||||
|
||||
func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ...execOption) *handler.Statement {
|
||||
func NewCreateStatement(event eventstore.Event, values []Column, opts ...execOption) *Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
columnNames := strings.Join(cols, ", ")
|
||||
valuesPlaceholder := strings.Join(params, ", ")
|
||||
@@ -45,23 +93,17 @@ func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ..
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
config.err = ErrNoValues
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "INSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
return NewStatement(event, exec(config, q, opts))
|
||||
}
|
||||
|
||||
func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, values []handler.Column, opts ...execOption) *handler.Statement {
|
||||
func NewUpsertStatement(event eventstore.Event, conflictCols []Column, values []Column, opts ...execOption) *Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
|
||||
conflictTarget := make([]string, len(conflictCols))
|
||||
@@ -74,12 +116,12 @@ func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, v
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
config.err = ErrNoValues
|
||||
}
|
||||
|
||||
updateCols, updateVals := getUpdateCols(cols, conflictTarget)
|
||||
if len(updateCols) == 0 || len(updateVals) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
config.err = ErrNoValues
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
@@ -96,13 +138,7 @@ func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, v
|
||||
" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO " + updateStmt
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
return NewStatement(event, exec(config, q, opts))
|
||||
}
|
||||
|
||||
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
|
||||
@@ -132,9 +168,9 @@ func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []stri
|
||||
return updateCols, updateVals
|
||||
}
|
||||
|
||||
func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditions []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
func NewUpdateStatement(event eventstore.Event, values []Column, conditions []Condition, opts ...execOption) *Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
wheres, whereArgs := conditionsToWhere(conditions, len(args))
|
||||
wheres, whereArgs := conditionsToWhere(conditions, len(args)+1)
|
||||
args = append(args, whereArgs...)
|
||||
|
||||
config := execConfig{
|
||||
@@ -142,11 +178,11 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
config.err = ErrNoValues
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
config.err = handler.ErrNoCondition
|
||||
config.err = ErrNoCondition
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
@@ -159,17 +195,11 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
|
||||
return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
return NewStatement(event, exec(config, q, opts))
|
||||
}
|
||||
|
||||
func NewDeleteStatement(event eventstore.Event, conditions []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
wheres, args := conditionsToWhere(conditions, 0)
|
||||
func NewDeleteStatement(event eventstore.Event, conditions []Condition, opts ...execOption) *Statement {
|
||||
wheres, args := conditionsToWhere(conditions, 1)
|
||||
|
||||
wheresPlaceholders := strings.Join(wheres, " AND ")
|
||||
|
||||
@@ -178,32 +208,21 @@ func NewDeleteStatement(event eventstore.Event, conditions []handler.Condition,
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
config.err = handler.ErrNoCondition
|
||||
config.err = ErrNoCondition
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "DELETE FROM " + config.tableName + " WHERE " + wheresPlaceholders
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
return NewStatement(event, exec(config, q, opts))
|
||||
}
|
||||
|
||||
func NewNoOpStatement(event eventstore.Event) *handler.Statement {
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
}
|
||||
func NewNoOpStatement(event eventstore.Event) *Statement {
|
||||
return NewStatement(event, nil)
|
||||
}
|
||||
|
||||
func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Exec) *handler.Statement {
|
||||
func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Exec) *Statement {
|
||||
if len(opts) == 0 {
|
||||
return NewNoOpStatement(event)
|
||||
}
|
||||
@@ -211,43 +230,47 @@ func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Ex
|
||||
for i, opt := range opts {
|
||||
execs[i] = opt(event)
|
||||
}
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: multiExec(execs),
|
||||
return NewStatement(event, multiExec(execs))
|
||||
}
|
||||
|
||||
func AddNoOpStatement() func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewNoOpStatement(event).Execute
|
||||
}
|
||||
}
|
||||
|
||||
type Exec func(ex handler.Executer, projectionName string) error
|
||||
|
||||
func AddCreateStatement(columns []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
func AddCreateStatement(columns []Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewCreateStatement(event, columns, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func AddUpsertStatement(indexCols []handler.Column, values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
func AddUpsertStatement(indexCols []Column, values []Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewUpsertStatement(event, indexCols, values, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func AddUpdateStatement(values []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
func AddUpdateStatement(values []Column, conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewUpdateStatement(event, values, conditions, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func AddDeleteStatement(conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
func AddDeleteStatement(conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewDeleteStatement(event, conditions, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func NewArrayAppendCol(column string, value interface{}) handler.Column {
|
||||
return handler.Column{
|
||||
func AddCopyStatement(conflict, from, to []Column, conditions []NamespacedCondition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewCopyStatement(event, conflict, from, to, conditions, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func NewArrayAppendCol(column string, value interface{}) Column {
|
||||
return Column{
|
||||
Name: column,
|
||||
Value: value,
|
||||
ParameterOpt: func(placeholder string) string {
|
||||
@@ -256,8 +279,8 @@ func NewArrayAppendCol(column string, value interface{}) handler.Column {
|
||||
}
|
||||
}
|
||||
|
||||
func NewArrayRemoveCol(column string, value interface{}) handler.Column {
|
||||
return handler.Column{
|
||||
func NewArrayRemoveCol(column string, value interface{}) Column {
|
||||
return Column{
|
||||
Name: column,
|
||||
Value: value,
|
||||
ParameterOpt: func(placeholder string) string {
|
||||
@@ -266,15 +289,15 @@ func NewArrayRemoveCol(column string, value interface{}) handler.Column {
|
||||
}
|
||||
}
|
||||
|
||||
func NewArrayIntersectCol(column string, value interface{}) handler.Column {
|
||||
func NewArrayIntersectCol(column string, value interface{}) Column {
|
||||
var arrayType string
|
||||
switch value.(type) {
|
||||
|
||||
case []string, database.StringArray:
|
||||
case []string, database.TextArray[string]:
|
||||
arrayType = "TEXT"
|
||||
//TODO: handle more types if necessary
|
||||
}
|
||||
return handler.Column{
|
||||
return Column{
|
||||
Name: column,
|
||||
Value: value,
|
||||
ParameterOpt: func(placeholder string) string {
|
||||
@@ -283,38 +306,10 @@ func NewArrayIntersectCol(column string, value interface{}) handler.Column {
|
||||
}
|
||||
}
|
||||
|
||||
func NewCopyCol(column, from string) handler.Column {
|
||||
return handler.Column{
|
||||
func NewCopyCol(column, from string) Column {
|
||||
return Column{
|
||||
Name: column,
|
||||
Value: handler.NewCol(from, nil),
|
||||
}
|
||||
}
|
||||
|
||||
func NewLessThanCond(column string, value interface{}) handler.Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
return column + " < " + param, value
|
||||
}
|
||||
}
|
||||
|
||||
func NewIsNullCond(column string) handler.Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
return column + " IS NULL", nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewTextArrayContainsCond returns a handler.Condition that checks if the column that stores an array of text contains the given value
|
||||
func NewTextArrayContainsCond(column string, value string) handler.Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
return column + " @> " + param, database.StringArray{value}
|
||||
}
|
||||
}
|
||||
|
||||
// Not is a function and not a method, so that calling it is well readable
|
||||
// For example conditions := []handler.Condition{ Not(NewTextArrayContainsCond())}
|
||||
func Not(condition handler.Condition) handler.Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
cond, value := condition(param)
|
||||
return "NOT (" + cond + ")", value
|
||||
Value: NewCol(from, nil),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,7 +318,7 @@ func Not(condition handler.Condition) handler.Condition {
|
||||
// if the value of a col is empty the data will be copied from the selected row
|
||||
// if the value of a col is not empty the data will be set by the static value
|
||||
// conds represent the conditions for the selection subquery
|
||||
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.Column, nsCond []handler.NamespacedCondition, opts ...execOption) *handler.Statement {
|
||||
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []Column, nsCond []NamespacedCondition, opts ...execOption) *Statement {
|
||||
columnNames := make([]string, len(to))
|
||||
selectColumns := make([]string, len(from))
|
||||
updateColumns := make([]string, len(columnNames))
|
||||
@@ -342,11 +337,11 @@ func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.C
|
||||
}
|
||||
|
||||
}
|
||||
cond := make([]handler.Condition, len(nsCond))
|
||||
cond := make([]Condition, len(nsCond))
|
||||
for i := range nsCond {
|
||||
cond[i] = nsCond[i]("copy_table")
|
||||
}
|
||||
wheres, values := conditionsToWhere(cond, len(args))
|
||||
wheres, values := conditionsToWhere(cond, len(args)+1)
|
||||
args = append(args, values...)
|
||||
|
||||
conflictTargets := make([]string, len(conflictCols))
|
||||
@@ -359,11 +354,11 @@ func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.C
|
||||
}
|
||||
|
||||
if len(from) == 0 || len(to) == 0 || len(from) != len(to) {
|
||||
config.err = handler.ErrNoValues
|
||||
config.err = ErrNoValues
|
||||
}
|
||||
|
||||
if len(cond) == 0 {
|
||||
config.err = handler.ErrNoCondition
|
||||
config.err = ErrNoCondition
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
@@ -385,23 +380,17 @@ func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.C
|
||||
")"
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
return NewStatement(event, exec(config, q, opts))
|
||||
}
|
||||
|
||||
func columnsToQuery(cols []handler.Column) (names []string, parameters []string, values []interface{}) {
|
||||
func columnsToQuery(cols []Column) (names []string, parameters []string, values []interface{}) {
|
||||
names = make([]string, len(cols))
|
||||
values = make([]interface{}, len(cols))
|
||||
parameters = make([]string, len(cols))
|
||||
var parameterIndex int
|
||||
for i, col := range cols {
|
||||
names[i] = col.Name
|
||||
if c, ok := col.Value.(handler.Column); ok {
|
||||
if c, ok := col.Value.(Column); ok {
|
||||
parameters[i] = c.Name
|
||||
continue
|
||||
} else {
|
||||
@@ -416,25 +405,105 @@ func columnsToQuery(cols []handler.Column) (names []string, parameters []string,
|
||||
return names, parameters, values[:parameterIndex]
|
||||
}
|
||||
|
||||
func conditionsToWhere(conditions []handler.Condition, paramOffset int) (wheres []string, values []interface{}) {
|
||||
wheres = make([]string, len(conditions))
|
||||
values = make([]interface{}, 0, len(conditions))
|
||||
for i, conditionFunc := range conditions {
|
||||
condition, value := conditionFunc("$" + strconv.Itoa(i+1+paramOffset))
|
||||
wheres[i] = "(" + condition + ")"
|
||||
if value != nil {
|
||||
values = append(values, value)
|
||||
}
|
||||
func conditionsToWhere(conds []Condition, paramOffset int) (wheres []string, values []interface{}) {
|
||||
wheres = make([]string, len(conds))
|
||||
values = make([]any, 0, len(conds))
|
||||
|
||||
for i, cond := range conds {
|
||||
var args []any
|
||||
wheres[i], args = cond("$" + strconv.Itoa(paramOffset))
|
||||
paramOffset += len(args)
|
||||
values = append(values, args...)
|
||||
wheres[i] = "(" + wheres[i] + ")"
|
||||
}
|
||||
|
||||
return wheres, values
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
Name string
|
||||
Value interface{}
|
||||
ParameterOpt func(string) string
|
||||
}
|
||||
|
||||
func NewCol(name string, value interface{}) Column {
|
||||
return Column{
|
||||
Name: name,
|
||||
Value: value,
|
||||
}
|
||||
}
|
||||
|
||||
func NewJSONCol(name string, value interface{}) Column {
|
||||
marshalled, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
logging.WithFields("column", name).WithError(err).Panic("unable to marshal column")
|
||||
}
|
||||
|
||||
return NewCol(name, marshalled)
|
||||
}
|
||||
|
||||
type Condition func(param string) (string, []any)
|
||||
|
||||
type NamespacedCondition func(namespace string) Condition
|
||||
|
||||
func NewCond(name string, value interface{}) Condition {
|
||||
return func(param string) (string, []any) {
|
||||
return name + " = " + param, []any{value}
|
||||
}
|
||||
}
|
||||
|
||||
func NewNamespacedCondition(name string, value interface{}) NamespacedCondition {
|
||||
return func(namespace string) Condition {
|
||||
return NewCond(namespace+"."+name, value)
|
||||
}
|
||||
}
|
||||
|
||||
func NewLessThanCond(column string, value interface{}) Condition {
|
||||
return func(param string) (string, []any) {
|
||||
return column + " < " + param, []any{value}
|
||||
}
|
||||
}
|
||||
|
||||
func NewIsNullCond(column string) Condition {
|
||||
return func(string) (string, []any) {
|
||||
return column + " IS NULL", nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewTextArrayContainsCond returns a Condition that checks if the column that stores an array of text contains the given value
|
||||
func NewTextArrayContainsCond(column string, value string) Condition {
|
||||
return func(param string) (string, []any) {
|
||||
return column + " @> " + param, []any{database.TextArray[string]{value}}
|
||||
}
|
||||
}
|
||||
|
||||
// Not is a function and not a method, so that calling it is well readable
|
||||
// For example conditions := []Condition{ Not(NewTextArrayContainsCond())}
|
||||
func Not(condition Condition) Condition {
|
||||
return func(param string) (string, []any) {
|
||||
cond, value := condition(param)
|
||||
return "NOT (" + cond + ")", value
|
||||
}
|
||||
}
|
||||
|
||||
type Executer interface {
|
||||
Exec(string, ...interface{}) (sql.Result, error)
|
||||
}
|
||||
|
||||
type execOption func(*execConfig)
|
||||
type execConfig struct {
|
||||
tableName string
|
||||
|
||||
args []interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
type query func(config execConfig) string
|
||||
|
||||
func exec(config execConfig, q query, opts []execOption) Exec {
|
||||
return func(ex handler.Executer, projectionName string) error {
|
||||
return func(ex Executer, projectionName string) (err error) {
|
||||
if projectionName == "" {
|
||||
return handler.ErrNoProjection
|
||||
return ErrNoProjection
|
||||
}
|
||||
|
||||
if config.err != nil {
|
||||
@@ -446,12 +515,21 @@ func exec(config execConfig, q query, opts []execOption) Exec {
|
||||
opt(&config)
|
||||
}
|
||||
|
||||
if _, err := ex.Exec(q(config), config.args...); err != nil {
|
||||
if config.ignoreNotFound && errors.Is(err, sql.ErrNoRows) {
|
||||
logging.WithError(err).Debugf("ignored not found: %v", err)
|
||||
return nil
|
||||
_, err = ex.Exec("SAVEPOINT stmt_exec")
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-YdOXD", "create savepoint failed")
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_, rollbackErr := ex.Exec("ROLLBACK TO SAVEPOINT stmt_exec")
|
||||
logging.OnError(rollbackErr).Debug("rollback failed")
|
||||
return
|
||||
}
|
||||
return zitadel_errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
|
||||
_, err = ex.Exec("RELEASE SAVEPOINT stmt_exec")
|
||||
}()
|
||||
_, err = ex.Exec(q(config), config.args...)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -459,8 +537,11 @@ func exec(config execConfig, q query, opts []execOption) Exec {
|
||||
}
|
||||
|
||||
func multiExec(execList []Exec) Exec {
|
||||
return func(ex handler.Executer, projectionName string) error {
|
||||
return func(ex Executer, projectionName string) error {
|
||||
for _, exec := range execList {
|
||||
if exec == nil {
|
||||
continue
|
||||
}
|
||||
if err := exec(ex, projectionName); err != nil {
|
||||
return err
|
||||
}
|
@@ -1,14 +1,14 @@
|
||||
package crdb
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
type wantExecuter struct {
|
||||
@@ -24,21 +24,50 @@ type params struct {
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
var errTestErr = errors.New("some error")
|
||||
var errTest = errors.New("some error")
|
||||
|
||||
var _ eventstore.Event = &testEvent{}
|
||||
|
||||
type testEvent struct {
|
||||
eventstore.BaseEvent
|
||||
sequence uint64
|
||||
previousSequence uint64
|
||||
aggregateType eventstore.AggregateType
|
||||
instanceID string
|
||||
}
|
||||
|
||||
func (e *testEvent) Sequence() uint64 {
|
||||
return e.sequence
|
||||
}
|
||||
|
||||
func (e *testEvent) Aggregate() *eventstore.Aggregate {
|
||||
return &eventstore.Aggregate{
|
||||
Type: e.aggregateType,
|
||||
InstanceID: e.instanceID,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *testEvent) PreviousAggregateTypeSequence() uint64 {
|
||||
return e.previousSequence
|
||||
}
|
||||
|
||||
func (ex *wantExecuter) check(t *testing.T) {
|
||||
t.Helper()
|
||||
if ex.wasExecuted && !ex.shouldExecute {
|
||||
switch {
|
||||
case ex.wasExecuted && !ex.shouldExecute:
|
||||
t.Error("executer should not be executed")
|
||||
} else if !ex.wasExecuted && ex.shouldExecute {
|
||||
case !ex.wasExecuted && ex.shouldExecute:
|
||||
t.Error("executer should be executed")
|
||||
} else if ex.wasExecuted != ex.shouldExecute {
|
||||
case ex.wasExecuted != ex.shouldExecute:
|
||||
t.Errorf("executed missmatched should be %t, but was %t", ex.shouldExecute, ex.wasExecuted)
|
||||
}
|
||||
}
|
||||
|
||||
func (ex *wantExecuter) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
ex.t.Helper()
|
||||
if strings.Contains(query, "SAVEPOINT") {
|
||||
return nil, nil
|
||||
}
|
||||
ex.wasExecuted = true
|
||||
if ex.i >= len(ex.params) {
|
||||
ex.t.Errorf("did not expect more exec, but got:\n %q with %q", query, args)
|
||||
@@ -59,7 +88,7 @@ func TestNewCreateStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
values []handler.Column
|
||||
values []Column
|
||||
}
|
||||
type want struct {
|
||||
aggregateType eventstore.AggregateType
|
||||
@@ -83,7 +112,7 @@ func TestNewCreateStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
@@ -99,7 +128,7 @@ func TestNewCreateStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoProjection)
|
||||
return errors.Is(err, ErrNoProjection)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -112,7 +141,7 @@ func TestNewCreateStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{},
|
||||
values: []Column{},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
@@ -123,7 +152,7 @@ func TestNewCreateStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoValues)
|
||||
return errors.Is(err, ErrNoValues)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -136,7 +165,7 @@ func TestNewCreateStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
@@ -181,8 +210,8 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
conflictCols []handler.Column
|
||||
values []handler.Column
|
||||
conflictCols []Column
|
||||
values []Column
|
||||
}
|
||||
type want struct {
|
||||
aggregateType eventstore.AggregateType
|
||||
@@ -206,7 +235,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
@@ -222,7 +251,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoProjection)
|
||||
return errors.Is(err, ErrNoProjection)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -235,7 +264,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{},
|
||||
values: []Column{},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
@@ -246,7 +275,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoValues)
|
||||
return errors.Is(err, ErrNoValues)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -259,10 +288,10 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conflictCols: []handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
conflictCols: []Column{
|
||||
NewCol("col1", nil),
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
@@ -278,7 +307,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoValues)
|
||||
return errors.Is(err, ErrNoValues)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -291,10 +320,10 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conflictCols: []handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
conflictCols: []Column{
|
||||
NewCol("col1", nil),
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
@@ -337,10 +366,10 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conflictCols: []handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
conflictCols: []Column{
|
||||
NewCol("col1", nil),
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
@@ -389,8 +418,8 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
values []handler.Column
|
||||
conditions []handler.Condition
|
||||
values []Column
|
||||
conditions []Condition
|
||||
}
|
||||
type want struct {
|
||||
table string
|
||||
@@ -414,14 +443,14 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
},
|
||||
},
|
||||
conditions: []handler.Condition{
|
||||
handler.NewCond("col2", 1),
|
||||
conditions: []Condition{
|
||||
NewCond("col2", 1),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -433,7 +462,7 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoProjection)
|
||||
return errors.Is(err, ErrNoProjection)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -446,9 +475,9 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{},
|
||||
conditions: []handler.Condition{
|
||||
handler.NewCond("col2", 1),
|
||||
values: []Column{},
|
||||
conditions: []Condition{
|
||||
NewCond("col2", 1),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -460,7 +489,7 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoValues)
|
||||
return errors.Is(err, ErrNoValues)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -473,13 +502,13 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
},
|
||||
},
|
||||
conditions: []handler.Condition{},
|
||||
conditions: []Condition{},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
@@ -490,7 +519,7 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoCondition)
|
||||
return errors.Is(err, ErrNoCondition)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -503,14 +532,14 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
},
|
||||
},
|
||||
conditions: []handler.Condition{
|
||||
handler.NewCond("col2", 1),
|
||||
conditions: []Condition{
|
||||
NewCond("col2", 1),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -541,7 +570,7 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
values: []handler.Column{
|
||||
values: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
@@ -551,8 +580,8 @@ func TestNewUpdateStatement(t *testing.T) {
|
||||
Value: "val5",
|
||||
},
|
||||
},
|
||||
conditions: []handler.Condition{
|
||||
handler.NewCond("col2", 1),
|
||||
conditions: []Condition{
|
||||
NewCond("col2", 1),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -593,7 +622,7 @@ func TestNewDeleteStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
conditions []handler.Condition
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
type want struct {
|
||||
@@ -618,8 +647,8 @@ func TestNewDeleteStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conditions: []handler.Condition{
|
||||
handler.NewCond("col2", 1),
|
||||
conditions: []Condition{
|
||||
NewCond("col2", 1),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -631,7 +660,7 @@ func TestNewDeleteStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoProjection)
|
||||
return errors.Is(err, ErrNoProjection)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -644,7 +673,7 @@ func TestNewDeleteStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conditions: []handler.Condition{},
|
||||
conditions: []Condition{},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
@@ -655,7 +684,7 @@ func TestNewDeleteStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoCondition)
|
||||
return errors.Is(err, ErrNoCondition)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -668,8 +697,8 @@ func TestNewDeleteStatement(t *testing.T) {
|
||||
previousSequence: 0,
|
||||
aggregateType: "agg",
|
||||
},
|
||||
conditions: []handler.Condition{
|
||||
handler.NewCond("col1", 1),
|
||||
conditions: []Condition{
|
||||
NewCond("col1", 1),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -709,11 +738,16 @@ func TestNewDeleteStatement(t *testing.T) {
|
||||
func TestNewNoOpStatement(t *testing.T) {
|
||||
type args struct {
|
||||
event *testEvent
|
||||
table string
|
||||
}
|
||||
type want struct {
|
||||
executer *wantExecuter
|
||||
isErr func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *handler.Statement
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "generate correctly",
|
||||
@@ -725,20 +759,29 @@ func TestNewNoOpStatement(t *testing.T) {
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
want: &handler.Statement{
|
||||
AggregateType: "agg",
|
||||
Execute: nil,
|
||||
Sequence: 5,
|
||||
PreviousSequence: 3,
|
||||
InstanceID: "instanceID",
|
||||
want: want{
|
||||
executer: nil,
|
||||
isErr: func(err error) bool {
|
||||
return err == nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := NewNoOpStatement(tt.args.event); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewNoOpStatement() = %v, want %v", got, tt.want)
|
||||
stmt := NewNoOpStatement(tt.args.event)
|
||||
if tt.want.executer != nil && stmt.Execute == nil {
|
||||
t.Error("expected executer, but was nil")
|
||||
}
|
||||
if stmt.Execute == nil {
|
||||
return
|
||||
}
|
||||
tt.want.executer.t = t
|
||||
err := stmt.Execute(tt.want.executer, tt.args.table)
|
||||
if !tt.want.isErr(err) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
tt.want.executer.check(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -772,10 +815,17 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
execs: nil,
|
||||
execs: []func(eventstore.Event) Exec{
|
||||
AddNoOpStatement(),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
executer: nil,
|
||||
executer: &wantExecuter{
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return err == nil
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -789,10 +839,10 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
},
|
||||
execs: []func(eventstore.Event) Exec{
|
||||
AddDeleteStatement(
|
||||
[]handler.Condition{},
|
||||
[]Condition{},
|
||||
),
|
||||
AddCreateStatement(
|
||||
[]handler.Column{
|
||||
[]Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
@@ -809,7 +859,7 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoCondition)
|
||||
return errors.Is(err, ErrNoCondition)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -824,22 +874,21 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
},
|
||||
execs: []func(eventstore.Event) Exec{
|
||||
AddDeleteStatement(
|
||||
[]handler.Condition{
|
||||
handler.NewCond("col1", 1),
|
||||
},
|
||||
),
|
||||
[]Condition{
|
||||
NewCond("col1", 1),
|
||||
}),
|
||||
AddCreateStatement(
|
||||
[]handler.Column{
|
||||
[]Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
},
|
||||
}),
|
||||
AddUpsertStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
[]Column{
|
||||
NewCol("col1", nil),
|
||||
},
|
||||
[]handler.Column{
|
||||
[]Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
@@ -850,16 +899,15 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
},
|
||||
}),
|
||||
AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
[]Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
},
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond("col1", 1),
|
||||
},
|
||||
),
|
||||
[]Condition{
|
||||
NewCond("col1", 1),
|
||||
}),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -918,10 +966,10 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
conflictingCols []handler.Column
|
||||
from []handler.Column
|
||||
to []handler.Column
|
||||
conds []handler.NamespacedCondition
|
||||
conflictingCols []Column
|
||||
from []Column
|
||||
to []Column
|
||||
conds []NamespacedCondition
|
||||
}
|
||||
type want struct {
|
||||
aggregateType eventstore.AggregateType
|
||||
@@ -945,8 +993,8 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conds: []handler.NamespacedCondition{
|
||||
handler.NewNamespacedCondition("col2", 1),
|
||||
conds: []NamespacedCondition{
|
||||
NewNamespacedCondition("col2", 1),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -958,7 +1006,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoProjection)
|
||||
return errors.Is(err, ErrNoProjection)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -971,13 +1019,13 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conds: []handler.NamespacedCondition{},
|
||||
from: []handler.Column{
|
||||
conds: []NamespacedCondition{},
|
||||
from: []Column{
|
||||
{
|
||||
Name: "col",
|
||||
},
|
||||
},
|
||||
to: []handler.Column{
|
||||
to: []Column{
|
||||
{
|
||||
Name: "col",
|
||||
},
|
||||
@@ -992,7 +1040,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoCondition)
|
||||
return errors.Is(err, ErrNoCondition)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1005,13 +1053,13 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conds: []handler.NamespacedCondition{},
|
||||
from: []handler.Column{
|
||||
conds: []NamespacedCondition{},
|
||||
from: []Column{
|
||||
{
|
||||
Name: "col",
|
||||
},
|
||||
},
|
||||
to: []handler.Column{
|
||||
to: []Column{
|
||||
{
|
||||
Name: "col",
|
||||
},
|
||||
@@ -1029,7 +1077,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoCondition)
|
||||
return errors.Is(err, ErrNoCondition)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1042,10 +1090,10 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conds: []handler.NamespacedCondition{
|
||||
handler.NewNamespacedCondition("col2", nil),
|
||||
conds: []NamespacedCondition{
|
||||
NewNamespacedCondition("col2", nil),
|
||||
},
|
||||
from: []handler.Column{},
|
||||
from: []Column{},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
@@ -1056,7 +1104,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoValues)
|
||||
return errors.Is(err, ErrNoValues)
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1069,7 +1117,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
from: []handler.Column{
|
||||
from: []Column{
|
||||
{
|
||||
Name: "state",
|
||||
Value: 1,
|
||||
@@ -1084,7 +1132,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
Name: "col_b",
|
||||
},
|
||||
},
|
||||
to: []handler.Column{
|
||||
to: []Column{
|
||||
{
|
||||
Name: "state",
|
||||
},
|
||||
@@ -1098,9 +1146,9 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
Name: "col_b",
|
||||
},
|
||||
},
|
||||
conds: []handler.NamespacedCondition{
|
||||
handler.NewNamespacedCondition("id", 2),
|
||||
handler.NewNamespacedCondition("state", 3),
|
||||
conds: []NamespacedCondition{
|
||||
NewNamespacedCondition("id", 2),
|
||||
NewNamespacedCondition("state", 3),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -1131,7 +1179,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
from: []handler.Column{
|
||||
from: []Column{
|
||||
{
|
||||
Value: 1,
|
||||
},
|
||||
@@ -1145,7 +1193,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
Name: "col_b",
|
||||
},
|
||||
},
|
||||
to: []handler.Column{
|
||||
to: []Column{
|
||||
{
|
||||
Name: "state",
|
||||
},
|
||||
@@ -1159,9 +1207,9 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
Name: "col_d",
|
||||
},
|
||||
},
|
||||
conds: []handler.NamespacedCondition{
|
||||
handler.NewNamespacedCondition("id", 2),
|
||||
handler.NewNamespacedCondition("state", 3),
|
||||
conds: []NamespacedCondition{
|
||||
NewNamespacedCondition("id", 2),
|
||||
NewNamespacedCondition("state", 3),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -1200,7 +1248,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
|
||||
func TestStatement_Execute(t *testing.T) {
|
||||
type fields struct {
|
||||
execute func(ex handler.Executer, projectionName string) error
|
||||
execute func(ex Executer, projectionName string) error
|
||||
}
|
||||
type want struct {
|
||||
isErr func(error) bool
|
||||
@@ -1217,7 +1265,7 @@ func TestStatement_Execute(t *testing.T) {
|
||||
{
|
||||
name: "execute returns no error",
|
||||
fields: fields{
|
||||
execute: func(ex handler.Executer, projectionName string) error { return nil },
|
||||
execute: func(ex Executer, projectionName string) error { return nil },
|
||||
},
|
||||
args: args{
|
||||
projectionName: "my_projection",
|
||||
@@ -1234,18 +1282,18 @@ func TestStatement_Execute(t *testing.T) {
|
||||
projectionName: "my_projection",
|
||||
},
|
||||
fields: fields{
|
||||
execute: func(ex handler.Executer, projectionName string) error { return errTestErr },
|
||||
execute: func(ex Executer, projectionName string) error { return errTest },
|
||||
},
|
||||
want: want{
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, errTestErr)
|
||||
return errors.Is(err, errTest)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
stmt := &handler.Statement{
|
||||
stmt := &Statement{
|
||||
Execute: tt.fields.execute,
|
||||
}
|
||||
if err := stmt.Execute(nil, tt.args.projectionName); !tt.want.isErr(err) {
|
||||
@@ -1257,7 +1305,7 @@ func TestStatement_Execute(t *testing.T) {
|
||||
|
||||
func Test_columnsToQuery(t *testing.T) {
|
||||
type args struct {
|
||||
cols []handler.Column
|
||||
cols []Column
|
||||
}
|
||||
type want struct {
|
||||
names []string
|
||||
@@ -1281,7 +1329,7 @@ func Test_columnsToQuery(t *testing.T) {
|
||||
{
|
||||
name: "one column",
|
||||
args: args{
|
||||
cols: []handler.Column{
|
||||
cols: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
@@ -1297,7 +1345,7 @@ func Test_columnsToQuery(t *testing.T) {
|
||||
{
|
||||
name: "multiple columns",
|
||||
args: args{
|
||||
cols: []handler.Column{
|
||||
cols: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
@@ -1317,14 +1365,14 @@ func Test_columnsToQuery(t *testing.T) {
|
||||
{
|
||||
name: "with copy column",
|
||||
args: args{
|
||||
cols: []handler.Column{
|
||||
cols: []Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
},
|
||||
{
|
||||
Name: "col2",
|
||||
Value: handler.Column{
|
||||
Value: Column{
|
||||
Name: "col1",
|
||||
},
|
||||
},
|
||||
@@ -1357,9 +1405,9 @@ func Test_columnsToQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_conditionsToWhere(t *testing.T) {
|
||||
func Test_columnsToWhere(t *testing.T) {
|
||||
type args struct {
|
||||
conds []handler.Condition
|
||||
conds []Condition
|
||||
paramOffset int
|
||||
}
|
||||
type want struct {
|
||||
@@ -1382,10 +1430,10 @@ func Test_conditionsToWhere(t *testing.T) {
|
||||
{
|
||||
name: "no offset",
|
||||
args: args{
|
||||
conds: []handler.Condition{
|
||||
handler.NewCond("col1", "val1"),
|
||||
conds: []Condition{
|
||||
NewCond("col1", "val1"),
|
||||
},
|
||||
paramOffset: 0,
|
||||
paramOffset: 1,
|
||||
},
|
||||
want: want{
|
||||
wheres: []string{"(col1 = $1)"},
|
||||
@@ -1395,11 +1443,11 @@ func Test_conditionsToWhere(t *testing.T) {
|
||||
{
|
||||
name: "multiple cols",
|
||||
args: args{
|
||||
conds: []handler.Condition{
|
||||
handler.NewCond("col1", "val1"),
|
||||
handler.NewCond("col2", "val2"),
|
||||
conds: []Condition{
|
||||
NewCond("col1", "val1"),
|
||||
NewCond("col2", "val2"),
|
||||
},
|
||||
paramOffset: 0,
|
||||
paramOffset: 1,
|
||||
},
|
||||
want: want{
|
||||
wheres: []string{"(col1 = $1)", "(col2 = $2)"},
|
||||
@@ -1409,10 +1457,10 @@ func Test_conditionsToWhere(t *testing.T) {
|
||||
{
|
||||
name: "2 offset",
|
||||
args: args{
|
||||
conds: []handler.Condition{
|
||||
handler.NewCond("col1", "val1"),
|
||||
conds: []Condition{
|
||||
NewCond("col1", "val1"),
|
||||
},
|
||||
paramOffset: 2,
|
||||
paramOffset: 3,
|
||||
},
|
||||
want: want{
|
||||
wheres: []string{"(col1 = $3)"},
|
||||
@@ -1422,9 +1470,10 @@ func Test_conditionsToWhere(t *testing.T) {
|
||||
{
|
||||
name: "less than",
|
||||
args: args{
|
||||
conds: []handler.Condition{
|
||||
conds: []Condition{
|
||||
NewLessThanCond("col1", "val1"),
|
||||
},
|
||||
paramOffset: 1,
|
||||
},
|
||||
want: want{
|
||||
wheres: []string{"(col1 < $1)"},
|
||||
@@ -1434,7 +1483,7 @@ func Test_conditionsToWhere(t *testing.T) {
|
||||
{
|
||||
name: "is null",
|
||||
args: args{
|
||||
conds: []handler.Condition{
|
||||
conds: []Condition{
|
||||
NewIsNullCond("col1"),
|
||||
},
|
||||
},
|
||||
@@ -1446,21 +1495,23 @@ func Test_conditionsToWhere(t *testing.T) {
|
||||
{
|
||||
name: "text array contains",
|
||||
args: args{
|
||||
conds: []handler.Condition{
|
||||
conds: []Condition{
|
||||
NewTextArrayContainsCond("col1", "val1"),
|
||||
},
|
||||
paramOffset: 1,
|
||||
},
|
||||
want: want{
|
||||
wheres: []string{"(col1 @> $1)"},
|
||||
values: []interface{}{database.StringArray{"val1"}},
|
||||
values: []interface{}{database.TextArray[string]{"val1"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not",
|
||||
args: args{
|
||||
conds: []handler.Condition{
|
||||
Not(handler.NewCond("col1", "val1")),
|
||||
conds: []Condition{
|
||||
Not(NewCond("col1", "val1")),
|
||||
},
|
||||
paramOffset: 1,
|
||||
},
|
||||
want: want{
|
||||
wheres: []string{"(NOT (col1 = $1))"},
|
||||
@@ -1490,7 +1541,7 @@ func TestParameterOpts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
constructor func(column string, value interface{}) handler.Column
|
||||
constructor func(column string, value interface{}) Column
|
||||
want string
|
||||
}{
|
||||
{
|
||||
@@ -1523,3 +1574,60 @@ func TestParameterOpts(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// func TestHandler_reduce(t *testing.T) {
|
||||
// type fields struct {
|
||||
// projection Projection
|
||||
// }
|
||||
// type args struct {
|
||||
// event eventstore.Event
|
||||
// }
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// fields fields
|
||||
// args args
|
||||
// isErr func(t *testing.T, err error)
|
||||
// shouldBeCalled bool
|
||||
// }{
|
||||
// {
|
||||
// name: "",
|
||||
// fields: fields{
|
||||
// projection: &projection{
|
||||
// reducers: []AggregateReducer{
|
||||
// {
|
||||
// Aggregate: "aggregate",
|
||||
// EventRedusers: []EventReducer{
|
||||
// {
|
||||
// Event: "event",
|
||||
// Reduce: (&mockEventReducer{
|
||||
// statement: new(Statement),
|
||||
// }).reduce,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// if tt.isErr == nil {
|
||||
// tt.isErr = func(t *testing.T, err error) {
|
||||
// if err != nil {
|
||||
// t.Error("expected no error got:", err)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// h := &Handler{
|
||||
// projection: tt.fields.projection,
|
||||
// }
|
||||
// got, err := h.reduce(tt.args.event)
|
||||
// tt.isErr(t, err)
|
||||
// if tt.shouldBeCalled != tt.
|
||||
// if !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("Handler.reduce() = %v, want %v", got, tt.want)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
Reference in New Issue
Block a user