feat(eventstore): increase parallel write capabilities (#5940)

This implementation increases parallel write capabilities of the eventstore.
Please have a look at the technical advisories: [05](https://zitadel.com/docs/support/advisory/a10005) and  [06](https://zitadel.com/docs/support/advisory/a10006).
The implementation of eventstore.push is rewritten and stored events are migrated to a new table `eventstore.events2`.
If you are using cockroach: make sure that the database user of ZITADEL has `VIEWACTIVITY` grant. This is used to query events.
This commit is contained in:
Silvan
2023-10-19 12:19:10 +02:00
committed by GitHub
parent 259faba3f0
commit b5564572bc
791 changed files with 30326 additions and 43202 deletions

View File

@@ -1,83 +0,0 @@
package crdb
import (
"context"
"database/sql"
"strconv"
"strings"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
currentSequenceStmtWithoutLockFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2)`
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
)
type currentSequences map[eventstore.AggregateType][]*instanceSequence
type instanceSequence struct {
instanceID string
sequence uint64
}
func (h *StatementHandler) currentSequences(ctx context.Context, isTx bool, query func(context.Context, func(*sql.Rows) error, string, ...interface{}) error, instanceIDs database.StringArray) (currentSequences, error) {
stmt := h.currentSequenceStmt
if !isTx {
stmt = h.currentSequenceWithoutLockStmt
}
sequences := make(currentSequences, len(h.aggregates))
err := query(ctx,
func(rows *sql.Rows) error {
for rows.Next() {
var (
aggregateType eventstore.AggregateType
sequence uint64
instanceID string
)
err := rows.Scan(&sequence, &aggregateType, &instanceID)
if err != nil {
return errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
}
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
sequence: sequence,
instanceID: instanceID,
})
}
return nil
},
stmt, h.ProjectionName, instanceIDs)
if err != nil {
return nil, err
}
return sequences, nil
}
func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentSequences) error {
valueQueries := make([]string, 0, len(sequences))
valueCounter := 0
values := make([]interface{}, 0, len(sequences)*3)
for aggregate, instanceSequence := range sequences {
for _, sequence := range instanceSequence {
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", $"+strconv.Itoa(valueCounter+4)+", NOW())")
valueCounter += 4
values = append(values, h.ProjectionName, aggregate, sequence.sequence, sequence.instanceID)
}
}
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", ")+updateCurrentSequencesConflictStmt, values...)
if err != nil {
return errors.ThrowInternal(err, "CRDB-TrH2Z", "unable to exec update sequence")
}
if rows, _ := res.RowsAffected(); rows < 1 {
return errSeqNotUpdated
}
return nil
}

View File

@@ -1,301 +1,15 @@
package crdb
import (
"database/sql"
"database/sql/driver"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
type mockExpectation func(sqlmock.Sqlmock)
func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2 AND instance_id = \$3\) SELECT COALESCE\(\(SELECT failure_count FROM failures\), 0\) AS failure_count`).
WithArgs(projectionName, failedSeq, instanceID).
WillReturnRows(
sqlmock.NewRows([]string{"failure_count"}).
AddRow(failureCount),
)
}
}
func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id, last_failed\) VALUES \(\$1, \$2, \$3, \$4\, \$5\, \$6\) ON CONFLICT \(projection_name, failed_sequence, instance_id\) DO UPDATE SET failure_count = EXCLUDED\.failure_count, error = EXCLUDED\.error, last_failed = EXCLUDED\.last_failed`).
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID, "NOW()").WillReturnResult(sqlmock.NewResult(1, 1))
}
}
func expectCreate(projectionName string, columnNames, placeholders []string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
args := make([]driver.Value, len(columnNames))
for i := 0; i < len(columnNames); i++ {
args[i] = sqlmock.AnyArg()
placeholders[i] = `\` + placeholders[i]
}
m.ExpectExec("INSERT INTO " + projectionName + ` \(` + strings.Join(columnNames, ", ") + `\) VALUES \(` + strings.Join(placeholders, ", ") + `\)`).
WithArgs(args...).
WillReturnResult(sqlmock.NewResult(1, 1))
}
}
func expectCreateErr(projectionName string, columnNames, placeholders []string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
args := make([]driver.Value, len(columnNames))
for i := 0; i < len(columnNames); i++ {
args[i] = sqlmock.AnyArg()
placeholders[i] = `\` + placeholders[i]
}
m.ExpectExec("INSERT INTO " + projectionName + ` \(` + strings.Join(columnNames, ", ") + `\) VALUES \(` + strings.Join(placeholders, ", ") + `\)`).
WithArgs(args...).
WillReturnError(err)
}
}
func expectBegin() func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectBegin()
}
}
func expectBeginErr(err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectBegin().WillReturnError(err)
}
}
func expectCommit() func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectCommit()
}
}
func expectCommitErr(err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectCommit().WillReturnError(err)
}
}
func expectRollback() func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectRollback()
}
}
func expectSavePoint() func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("SAVEPOINT push_stmt").
WillReturnResult(sqlmock.NewResult(1, 1))
}
}
func expectSavePointErr(err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("SAVEPOINT push_stmt").
WillReturnError(err)
}
}
func expectSavePointRollback() func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("ROLLBACK TO SAVEPOINT push_stmt").
WillReturnResult(sqlmock.NewResult(1, 1))
}
}
func expectSavePointRollbackErr(err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("ROLLBACK TO SAVEPOINT push_stmt").
WillReturnError(err)
}
}
func expectSavePointRelease() func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("RELEASE push_stmt").
WillReturnResult(sqlmock.NewResult(1, 1))
}
}
func expectCurrentSequence(isTx bool, tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"})
for _, instanceID := range instanceIDs {
rows.AddRow(seq, aggregateType, instanceID)
}
return func(m sqlmock.Sqlmock) {
stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
if isTx {
stmt += " FOR UPDATE"
}
m.ExpectQuery(stmt).
WithArgs(
projection,
database.StringArray(instanceIDs),
).
WillReturnRows(
rows,
)
}
}
func expectCurrentSequenceErr(isTx bool, tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
if isTx {
stmt += " FOR UPDATE"
}
m.ExpectQuery(stmt).
WithArgs(
projection,
database.StringArray(instanceIDs),
).
WillReturnError(err)
}
}
func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
database.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
)
}
}
func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
database.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
RowError(0, sql.ErrTxDone).
AddRow(0, "agg", "instanceID"),
)
}
}
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnResult(
sqlmock.NewResult(1, 1),
)
}
}
func expectUpdateThreeCurrentSequence(t *testing.T, tableName, projection string, sequences currentSequences) func(sqlmock.Sqlmock) {
args := make([][]interface{}, 0)
for aggregateType, instanceSequences := range sequences {
for _, sequence := range instanceSequences {
args = append(args, []interface{}{
projection,
aggregateType,
sequence.sequence,
sequence.instanceID,
})
}
}
matcher := &currentSequenceMatcher{t: t, seq: args}
matchers := make([]driver.Value, len(args)*4)
for i := 0; i < len(args)*4; i++ {
matchers[i] = matcher
}
return func(m sqlmock.Sqlmock) {
m.ExpectExec("INSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
matchers...,
).
WillReturnResult(
sqlmock.NewResult(1, 1),
)
}
}
type currentSequenceMatcher struct {
seq [][]interface{}
i int
t *testing.T
}
func (m *currentSequenceMatcher) Match(value driver.Value) bool {
if m.i%4 == 0 {
m.i = 0
}
left := make([]interface{}, 0, len(m.seq))
for _, seq := range m.seq {
found := seq[m.i]
if found == nil {
continue
}
switch v := value.(type) {
case string:
if found == v || found == eventstore.AggregateType(v) {
seq[m.i] = nil
m.i++
return true
}
case int64:
if found == uint64(v) {
seq[m.i] = nil
m.i++
return true
}
}
left = append(left, found)
}
m.t.Errorf("expected: %v, possible left values: %v", value, left)
m.t.FailNow()
return false
}
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnError(err)
}
}
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnResult(
sqlmock.NewResult(0, 0),
)
}
}
func expectLock(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
@@ -308,7 +22,7 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
d,
projectionName,
instanceID,
database.StringArray{instanceID},
database.TextArray[string]{instanceID},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -329,7 +43,7 @@ func expectLockMultipleInstances(lockTable, workerName string, d time.Duration,
projectionName,
instanceID1,
instanceID2,
database.StringArray{instanceID1, instanceID2},
database.TextArray[string]{instanceID1, instanceID2},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -349,7 +63,7 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
d,
projectionName,
instanceID,
database.StringArray{instanceID},
database.TextArray[string]{instanceID},
).
WillReturnResult(driver.ResultNoRows)
}
@@ -367,7 +81,7 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
d,
projectionName,
instanceID,
database.StringArray{instanceID},
database.TextArray[string]{instanceID},
).
WillReturnError(err)
}

View File

@@ -1,51 +0,0 @@
package crdb
import (
"database/sql"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/handler"
)
const (
setFailureCountStmtFormat = "INSERT INTO %s" +
" (projection_name, failed_sequence, failure_count, error, instance_id, last_failed)" +
" VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (projection_name, failed_sequence, instance_id)" +
" DO UPDATE SET failure_count = EXCLUDED.failure_count, error = EXCLUDED.error, last_failed = EXCLUDED.last_failed"
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
" SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count"
)
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {
failureCount, err := h.failureCount(tx, stmt.Sequence, stmt.InstanceID)
if err != nil {
logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).WithError(err).Warn("unable to get failure count")
return false
}
failureCount += 1
err = h.setFailureCount(tx, stmt.Sequence, failureCount, execErr, stmt.InstanceID)
logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).OnError(err).Warn("unable to update failure count")
return failureCount >= h.maxFailureCount
}
func (h *StatementHandler) failureCount(tx *sql.Tx, seq uint64, instanceID string) (count uint, err error) {
row := tx.QueryRow(h.failureCountStmt, h.ProjectionName, seq, instanceID)
if err = row.Err(); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
}
if err = row.Scan(&count); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
}
return count, nil
}
func (h *StatementHandler) setFailureCount(tx *sql.Tx, seq uint64, count uint, err error, instanceID string) error {
_, dbErr := tx.Exec(h.setFailureCountStmt, h.ProjectionName, seq, count, err.Error(), instanceID, "NOW()")
if dbErr != nil {
return errors.ThrowInternal(dbErr, "CRDB-4Ht4x", "set failure count failed")
}
return nil
}

View File

@@ -1,347 +0,0 @@
package crdb
import (
"context"
"database/sql"
"fmt"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/repository/pseudo"
)
var (
errSeqNotUpdated = errors.ThrowInternal(nil, "CRDB-79GWt", "current sequence not updated")
)
type StatementHandlerConfig struct {
handler.ProjectionHandlerConfig
Client *database.DB
SequenceTable string
LockTable string
FailedEventsTable string
MaxFailureCount uint
BulkLimit uint64
Reducers []handler.AggregateReducer
InitCheck *handler.Check
}
type StatementHandler struct {
*handler.ProjectionHandler
Locker
client *database.DB
sequenceTable string
currentSequenceStmt string
currentSequenceWithoutLockStmt string
updateSequencesBaseStmt string
maxFailureCount uint
failureCountStmt string
setFailureCountStmt string
aggregates []eventstore.AggregateType
reduces map[eventstore.EventType]handler.Reduce
initCheck *handler.Check
initialized chan bool
bulkLimit uint64
reduceScheduledPseudoEvent bool
}
func NewStatementHandler(
ctx context.Context,
config StatementHandlerConfig,
) StatementHandler {
aggregateTypes := make([]eventstore.AggregateType, 0, len(config.Reducers))
reduces := make(map[eventstore.EventType]handler.Reduce, len(config.Reducers))
reduceScheduledPseudoEvent := false
for _, aggReducer := range config.Reducers {
aggregateTypes = append(aggregateTypes, aggReducer.Aggregate)
if aggReducer.Aggregate == pseudo.AggregateType {
reduceScheduledPseudoEvent = true
if len(config.Reducers) != 1 ||
len(aggReducer.EventRedusers) != 1 ||
aggReducer.EventRedusers[0].Event != pseudo.ScheduledEventType {
panic("if a pseudo.AggregateType is reduced, exactly one event reducer for pseudo.ScheduledEventType is supported and no other aggregate can be reduced")
}
}
for _, eventReducer := range aggReducer.EventRedusers {
reduces[eventReducer.Event] = eventReducer.Reduce
}
}
h := StatementHandler{
client: config.Client,
sequenceTable: config.SequenceTable,
maxFailureCount: config.MaxFailureCount,
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
currentSequenceWithoutLockStmt: fmt.Sprintf(currentSequenceStmtWithoutLockFormat, config.SequenceTable),
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
aggregates: aggregateTypes,
reduces: reduces,
bulkLimit: config.BulkLimit,
Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName),
initCheck: config.InitCheck,
initialized: make(chan bool),
reduceScheduledPseudoEvent: reduceScheduledPseudoEvent,
}
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.searchQuery, h.Lock, h.Unlock, h.initialized, reduceScheduledPseudoEvent)
return h
}
func (h *StatementHandler) Start() {
h.initialized <- true
close(h.initialized)
if !h.reduceScheduledPseudoEvent {
h.Subscribe(h.aggregates...)
}
}
func (h *StatementHandler) searchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
if h.reduceScheduledPseudoEvent {
return nil, 1, nil
}
return h.dbSearchQuery(ctx, instanceIDs)
}
func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
sequences, err := h.currentSequences(ctx, false, h.client.QueryContext, instanceIDs)
if err != nil {
return nil, 0, err
}
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit).AllowTimeTravel()
for _, aggregateType := range h.aggregates {
for _, instanceID := range instanceIDs {
var seq uint64
for _, sequence := range sequences[aggregateType] {
if sequence.instanceID == instanceID {
seq = sequence.sequence
break
}
}
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(seq).
InstanceID(instanceID)
}
}
return queryBuilder, h.bulkLimit, nil
}
type transaction struct {
*sql.Tx
}
func (t *transaction) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
rows, err := t.Tx.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer func() {
closeErr := rows.Close()
logging.OnError(closeErr).Info("rows.Close failed")
}()
if err = scan(rows); err != nil {
return err
}
return rows.Err()
}
// Update implements handler.Update
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
if len(stmts) == 0 {
return -1, nil
}
instanceIDs := make([]string, 0, len(stmts))
for _, stmt := range stmts {
instanceIDs = appendToInstanceIDs(instanceIDs, stmt.InstanceID)
}
tx, err := h.client.BeginTx(ctx, nil)
if err != nil {
return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
}
sequences, err := h.currentSequences(ctx, true, (&transaction{Tx: tx}).QueryContext, instanceIDs)
if err != nil {
tx.Rollback()
return -1, err
}
//checks for events between create statement and current sequence
// because there could be events between current sequence and a creation event
// and we cannot check via stmt.PreviousSequence
if stmts[0].PreviousSequence == 0 {
previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
if err != nil {
tx.Rollback()
return -1, err
}
stmts = append(previousStmts, stmts...)
}
lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences)
if lastSuccessfulIdx >= 0 {
err = h.updateCurrentSequences(tx, sequences)
if err != nil {
tx.Rollback()
return -1, err
}
}
if err = tx.Commit(); err != nil {
return -1, err
}
if lastSuccessfulIdx < len(stmts)-1 {
return lastSuccessfulIdx, handler.ErrSomeStmtsFailed
}
return lastSuccessfulIdx, nil
}
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).SetTx(tx)
queriesAdded := false
for _, aggregateType := range h.aggregates {
for _, sequence := range sequences[aggregateType] {
if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID {
continue
}
query.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequence.sequence).
SequenceLess(stmtSeq).
InstanceID(sequence.instanceID)
queriesAdded = true
}
}
if !queriesAdded {
return nil, nil
}
events, err := h.Eventstore.Filter(ctx, query)
if err != nil {
return nil, err
}
for _, event := range events {
stmt, err := reduce(event)
if err != nil {
return nil, err
}
previousStmts = append(previousStmts, stmt)
}
return previousStmts, nil
}
func (h *StatementHandler) executeStmts(
tx *sql.Tx,
stmts *[]*handler.Statement,
sequences currentSequences,
) int {
lastSuccessfulIdx := -1
stmts:
for i := 0; i < len(*stmts); i++ {
stmt := (*stmts)[i]
for _, sequence := range sequences[stmt.AggregateType] {
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped")
if i < len(*stmts)-1 {
copy((*stmts)[i:], (*stmts)[i+1:])
}
*stmts = (*stmts)[:len(*stmts)-1]
i--
continue stmts
}
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match")
break stmts
}
}
err := h.executeStmt(tx, stmt)
if err == nil {
updateSequences(sequences, stmt)
lastSuccessfulIdx = i
continue
}
shouldContinue := h.handleFailedStmt(tx, stmt, err)
if !shouldContinue {
break
}
updateSequences(sequences, stmt)
lastSuccessfulIdx = i
continue
}
return lastSuccessfulIdx
}
// executeStmt handles sql statements
// an error is returned if the statement could not be inserted properly
func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) error {
if stmt.IsNoop() {
return nil
}
_, err := tx.Exec("SAVEPOINT push_stmt")
if err != nil {
return errors.ThrowInternal(err, "CRDB-i1wp6", "unable to create savepoint")
}
err = stmt.Execute(tx, h.ProjectionName)
if err != nil {
logging.WithError(err).Error()
_, rollbackErr := tx.Exec("ROLLBACK TO SAVEPOINT push_stmt")
if rollbackErr != nil {
return errors.ThrowInternal(rollbackErr, "CRDB-zzp3P", "rollback to savepoint failed")
}
return errors.ThrowInternal(err, "CRDB-oRkaN", "unable execute stmt")
}
_, err = tx.Exec("RELEASE push_stmt")
if err != nil {
return errors.ThrowInternal(err, "CRDB-qWgwT", "unable to release savepoint")
}
return nil
}
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
for _, sequence := range sequences[stmt.AggregateType] {
if sequence.instanceID == stmt.InstanceID {
sequence.sequence = stmt.Sequence
return
}
}
sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{
instanceID: stmt.InstanceID,
sequence: stmt.Sequence,
})
}
func appendToInstanceIDs(instances []string, id string) []string {
for _, instance := range instances {
if instance == id {
return instances
}
}
return append(instances, id)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,49 +0,0 @@
package crdb
import "testing"
func Test_defaultValue(t *testing.T) {
type args struct {
value interface{}
}
tests := []struct {
name string
args args
want string
}{
{
name: "string",
args: args{
value: "asdf",
},
want: "'asdf'",
},
{
name: "primitive non string",
args: args{
value: 1,
},
want: "1",
},
{
name: "stringer",
args: args{
value: testStringer(0),
},
want: "0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := defaultValue(tt.args.value); got != tt.want {
t.Errorf("defaultValue() = %v, want %v", got, tt.want)
}
})
}
}
type testStringer int
func (t testStringer) String() string {
return "0529958243"
}

View File

@@ -91,7 +91,7 @@ func (h *locker) Unlock(instanceIDs ...string) error {
return nil
}
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.StringArray) (string, []interface{}) {
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.TextArray[string]) (string, []interface{}) {
valueQueries := make([]string, len(instanceIDs))
values := make([]interface{}, len(instanceIDs)+4)
values[0] = h.workerName

View File

@@ -158,7 +158,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 1 * time.Second,
instanceIDs: database.StringArray{"instanceID"},
instanceIDs: database.TextArray[string]{"instanceID"},
},
},
{
@@ -173,7 +173,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 2 * time.Second,
instanceIDs: database.StringArray{"instanceID"},
instanceIDs: database.TextArray[string]{"instanceID"},
},
},
{
@@ -188,7 +188,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 3 * time.Second,
instanceIDs: database.StringArray{"instanceID"},
instanceIDs: database.TextArray[string]{"instanceID"},
},
},
{

View File

@@ -1,16 +0,0 @@
package crdb
import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
)
//reduce implements handler.Reduce function
func (h *StatementHandler) reduce(event eventstore.Event) (*handler.Statement, error) {
reduce, ok := h.reduces[event.Type()]
if !ok {
return NewNoOpStatement(event), nil
}
return reduce(event)
}

View File

@@ -1,36 +0,0 @@
package handler
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
type HandlerConfig struct {
Eventstore *eventstore.Eventstore
}
type Handler struct {
Eventstore *eventstore.Eventstore
Sub *eventstore.Subscription
EventQueue chan eventstore.Event
}
func NewHandler(config HandlerConfig) Handler {
return Handler{
Eventstore: config.Eventstore,
EventQueue: make(chan eventstore.Event, 100),
}
}
func (h *Handler) Subscribe(aggregates ...eventstore.AggregateType) {
h.Sub = eventstore.SubscribeAggregates(h.EventQueue, aggregates...)
}
func (h *Handler) SubscribeEvents(types map[eventstore.AggregateType][]eventstore.EventType) {
h.Sub = eventstore.SubscribeEventTypes(h.EventQueue, types)
}
func (h *Handler) Unsubscribe() {
if h.Sub == nil {
return
}
h.Sub.Unsubscribe()
}

View File

@@ -1,396 +0,0 @@
package handler
import (
"context"
"errors"
"runtime/debug"
"time"
"github.com/sirupsen/logrus"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/pseudo"
)
const (
schedulerSucceeded = eventstore.EventType("system.projections.scheduler.succeeded")
aggregateType = eventstore.AggregateType("system")
aggregateID = "SYSTEM"
)
type ProjectionHandlerConfig struct {
HandlerConfig
ProjectionName string
RequeueEvery time.Duration
RetryFailedAfter time.Duration
Retries uint
ConcurrentInstances uint
HandleActiveInstances time.Duration
}
// Update updates the projection with the given statements
type Update func(context.Context, []*Statement, Reduce) (index int, err error)
// Reduce reduces the given event to a statement
// which is used to update the projection
type Reduce func(eventstore.Event) (*Statement, error)
// SearchQuery generates the search query to lookup for events
type SearchQuery func(ctx context.Context, instanceIDs []string) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
// Lock is used for mutex handling if needed on the projection
type Lock func(context.Context, time.Duration, ...string) <-chan error
// Unlock releases the mutex of the projection
type Unlock func(...string) error
// NowFunc makes time.Now() mockable
type NowFunc func() time.Time
type ProjectionHandler struct {
Handler
ProjectionName string
reduce Reduce
update Update
searchQuery SearchQuery
triggerProjection *time.Timer
lock Lock
unlock Unlock
requeueAfter time.Duration
retryFailedAfter time.Duration
retries int
concurrentInstances int
handleActiveInstances time.Duration
nowFunc NowFunc
reduceScheduledPseudoEvent bool
}
func NewProjectionHandler(
ctx context.Context,
config ProjectionHandlerConfig,
reduce Reduce,
update Update,
query SearchQuery,
lock Lock,
unlock Unlock,
initialized <-chan bool,
reduceScheduledPseudoEvent bool,
) *ProjectionHandler {
concurrentInstances := int(config.ConcurrentInstances)
if concurrentInstances < 1 {
concurrentInstances = 1
}
h := &ProjectionHandler{
Handler: NewHandler(config.HandlerConfig),
ProjectionName: config.ProjectionName,
reduce: reduce,
update: update,
searchQuery: query,
lock: lock,
unlock: unlock,
requeueAfter: config.RequeueEvery,
triggerProjection: time.NewTimer(0), // first trigger is instant on startup
retryFailedAfter: config.RetryFailedAfter,
retries: int(config.Retries),
concurrentInstances: concurrentInstances,
handleActiveInstances: config.HandleActiveInstances,
nowFunc: time.Now,
reduceScheduledPseudoEvent: reduceScheduledPseudoEvent,
}
go func() {
<-initialized
if !h.reduceScheduledPseudoEvent {
go h.subscribe(ctx)
}
go h.schedule(ctx)
}()
return h
}
func triggerInstances(ctx context.Context, instances []string) []string {
if len(instances) == 0 {
instances = append(instances, authz.GetInstance(ctx).InstanceID())
}
return instances
}
// Trigger handles all events for the provided instances (or current instance from context if non specified)
// by calling FetchEvents and Process until the amount of events is smaller than the BulkLimit.
// If a bulk action was executed, the call timestamp in context will be reset for subsequent queries.
// The returned context is never nil. It is either the original context or an updated context.
//
// If Trigger encounters an error, it is only logged. If the error is important for the caller,
// use TriggerErr instead.
func (h *ProjectionHandler) Trigger(ctx context.Context, instances ...string) context.Context {
instances = triggerInstances(ctx, instances)
ctx, err := h.TriggerErr(ctx, instances...)
logging.OnError(err).WithFields(logrus.Fields{
"projection": h.ProjectionName,
"instanceIDs": instances,
}).Error("trigger failed")
return ctx
}
// TriggerErr handles all events for the provided instances (or current instance from context if non specified)
// by calling FetchEvents and Process until the amount of events is smaller than the BulkLimit.
// If a bulk action was executed, the call timestamp in context will be reset for subsequent queries.
// The returned context is never nil. It is either the original context or an updated context.
func (h *ProjectionHandler) TriggerErr(ctx context.Context, instances ...string) (outCtx context.Context, err error) {
instances = triggerInstances(ctx, instances)
defer func() {
outCtx = call.ResetTimestamp(ctx)
}()
for {
events, hasLimitExceeded, err := h.FetchEvents(ctx, instances...)
if err != nil {
return ctx, err
}
if len(events) == 0 {
return ctx, nil
}
_, err = h.Process(ctx, events...)
if err != nil {
return ctx, err
}
if !hasLimitExceeded {
return ctx, nil
}
}
}
// Process handles multiple events by reducing them to statements and updating the projection
func (h *ProjectionHandler) Process(ctx context.Context, events ...eventstore.Event) (index int, err error) {
if len(events) == 0 {
return 0, nil
}
index = -1
statements := make([]*Statement, len(events))
for i, event := range events {
statements[i], err = h.reduce(event)
if err != nil {
return index, err
}
}
for retry := 0; retry <= h.retries; retry++ {
index, err = h.update(ctx, statements[index+1:], h.reduce)
if err != nil && !errors.Is(err, ErrSomeStmtsFailed) {
return index, err
}
if err == nil {
return index, nil
}
time.Sleep(h.retryFailedAfter)
}
return index, err
}
// FetchEvents checks the current sequences and filters for newer events
func (h *ProjectionHandler) FetchEvents(ctx context.Context, instances ...string) ([]eventstore.Event, bool, error) {
if h.reduceScheduledPseudoEvent {
return h.fetchPseudoEvents(ctx, instances...)
}
return h.fetchDBEvents(ctx, instances...)
}
func (h *ProjectionHandler) fetchDBEvents(ctx context.Context, instances ...string) ([]eventstore.Event, bool, error) {
eventQuery, eventsLimit, err := h.searchQuery(ctx, instances)
if err != nil {
return nil, false, err
}
events, err := h.Eventstore.Filter(ctx, eventQuery)
if err != nil {
return nil, false, err
}
return events, int(eventsLimit) == len(events), err
}
func (h *ProjectionHandler) fetchPseudoEvents(ctx context.Context, instances ...string) ([]eventstore.Event, bool, error) {
return []eventstore.Event{pseudo.NewScheduledEvent(ctx, time.Now(), instances...)}, false, nil
}
func (h *ProjectionHandler) subscribe(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
defer func() {
err := recover()
if err != nil {
h.Handler.Unsubscribe()
logging.WithFields("projection", h.ProjectionName).Errorf("subscription panicked: %v", err)
}
cancel()
}()
for firstEvent := range h.EventQueue {
events := checkAdditionalEvents(h.EventQueue, firstEvent)
index, err := h.Process(ctx, events...)
if err != nil || index < len(events)-1 {
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("unable to process all events from subscription")
}
}
}
func (h *ProjectionHandler) schedule(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
defer func() {
err := recover()
if err != nil {
logging.WithFields("projection", h.ProjectionName, "cause", err, "stack", string(debug.Stack())).Error("schedule panicked")
}
cancel()
}()
// flag if projection has been successfully executed at least once since start
var succeededOnce bool
var err error
// get every instance id except empty (system)
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).AllowTimeTravel().AddQuery().ExcludedInstanceID("")
for range h.triggerProjection.C {
if !succeededOnce {
// (re)check if it has succeeded in the meantime
succeededOnce, err = h.hasSucceededOnce(ctx)
if err != nil {
logging.WithFields("projection", h.ProjectionName, "err", err).
Error("schedule could not check if projection has already succeeded once")
h.triggerProjection.Reset(h.requeueAfter)
continue
}
}
lockCtx := ctx
var cancelLock context.CancelFunc
// if it still has not succeeded, lock the projection for the system
// so that only a single scheduler does a first schedule (of every instance)
if !succeededOnce {
lockCtx, cancelLock = context.WithCancel(ctx)
errs := h.lock(lockCtx, h.requeueAfter, "system")
if err, ok := <-errs; err != nil || !ok {
cancelLock()
logging.WithFields("projection", h.ProjectionName).OnError(err).Debug("initial lock failed for first schedule")
h.triggerProjection.Reset(h.requeueAfter)
continue
}
go h.cancelOnErr(lockCtx, errs, cancelLock)
}
if succeededOnce {
// since we have at least one successful run, we can restrict it to events not older than
// h.handleActiveInstances (just to be sure not to miss an event)
// This ensures that only instances with recent events on the handler are projected
query = query.CreationDateAfter(h.nowFunc().Add(-1 * h.handleActiveInstances))
}
ids, err := h.Eventstore.InstanceIDs(ctx, h.requeueAfter, !succeededOnce, query.Builder())
if err != nil {
logging.WithFields("projection", h.ProjectionName).WithError(err).Error("instance ids")
h.triggerProjection.Reset(h.requeueAfter)
continue
}
var failed bool
for i := 0; i < len(ids); i = i + h.concurrentInstances {
max := i + h.concurrentInstances
if max > len(ids) {
max = len(ids)
}
instances := ids[i:max]
lockInstanceCtx, cancelInstanceLock := context.WithCancel(lockCtx)
errs := h.lock(lockInstanceCtx, h.requeueAfter, instances...)
//wait until projection is locked
if err, ok := <-errs; err != nil || !ok {
cancelInstanceLock()
logging.WithFields("projection", h.ProjectionName).OnError(err).Debug("initial lock failed")
failed = true
continue
}
go h.cancelOnErr(lockInstanceCtx, errs, cancelInstanceLock)
_, err = h.TriggerErr(lockInstanceCtx, instances...)
if err != nil {
logging.WithFields("projection", h.ProjectionName, "instanceIDs", instances).WithError(err).Error("trigger failed")
failed = true
}
cancelInstanceLock()
unlockErr := h.unlock(instances...)
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")
}
// if the first schedule did not fail, store that in the eventstore, so we can check on later starts
if !succeededOnce {
if !failed {
err = h.setSucceededOnce(ctx)
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("unable to push first schedule succeeded")
}
cancelLock()
unlockErr := h.unlock("system")
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock first schedule")
}
// it succeeded at least once if it has succeeded before or if it has succeeded now - not failed ;-)
succeededOnce = succeededOnce || !failed
h.triggerProjection.Reset(h.requeueAfter)
}
}
func (h *ProjectionHandler) hasSucceededOnce(ctx context.Context) (bool, error) {
events, err := h.Eventstore.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(aggregateType).
AggregateIDs(aggregateID).
EventTypes(schedulerSucceeded).
EventData(map[string]interface{}{
"name": h.ProjectionName,
}).
Builder(),
)
return len(events) > 0 && err == nil, err
}
func (h *ProjectionHandler) setSucceededOnce(ctx context.Context) error {
_, err := h.Eventstore.Push(ctx, &ProjectionSucceededEvent{
BaseEvent: *eventstore.NewBaseEventForPush(ctx,
eventstore.NewAggregate(ctx, aggregateID, aggregateType, "v1"),
schedulerSucceeded,
),
Name: h.ProjectionName,
})
return err
}
type ProjectionSucceededEvent struct {
eventstore.BaseEvent `json:"-"`
Name string `json:"name"`
}
func (p *ProjectionSucceededEvent) Data() interface{} {
return p
}
func (p *ProjectionSucceededEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func (h *ProjectionHandler) cancelOnErr(ctx context.Context, errs <-chan error, cancel func()) {
for {
select {
case err := <-errs:
if err != nil {
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("bulk canceled")
cancel()
return
}
case <-ctx.Done():
cancel()
return
}
}
}
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
events := make([]eventstore.Event, 1)
events[0] = event
for {
select {
case event := <-eventQueue:
events = append(events, event)
default:
return events
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,7 @@ package handler
import "context"
//Init initializes the projection with the given check
// Init initializes the projection with the given check
type Init func(context.Context, *Check) error
type Check struct {

View File

@@ -1,17 +0,0 @@
package handler
import "github.com/zitadel/zitadel/internal/eventstore"
//EventReducer represents the required data
//to work with events
type EventReducer struct {
Event eventstore.EventType
Reduce Reduce
}
//EventReducer represents the required data
//to work with aggregates
type AggregateReducer struct {
Aggregate eventstore.AggregateType
EventRedusers []EventReducer
}

View File

@@ -2,77 +2,8 @@ package handler
import (
"database/sql"
"encoding/json"
"errors"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
ErrNoProjection = errors.New("no projection")
ErrNoValues = errors.New("no values")
ErrNoCondition = errors.New("no condition")
ErrSomeStmtsFailed = errors.New("some statements failed")
)
type Statements []Statement
func (stmts Statements) Len() int { return len(stmts) }
func (stmts Statements) Swap(i, j int) { stmts[i], stmts[j] = stmts[j], stmts[i] }
func (stmts Statements) Less(i, j int) bool { return stmts[i].Sequence < stmts[j].Sequence }
type Statement struct {
AggregateType eventstore.AggregateType
Sequence uint64
PreviousSequence uint64
InstanceID string
Execute func(ex Executer, projectionName string) error
}
func (s *Statement) IsNoop() bool {
return s.Execute == nil
}
type Executer interface {
Exec(string, ...interface{}) (sql.Result, error)
}
type Column struct {
Name string
Value interface{}
ParameterOpt func(string) string
}
func NewCol(name string, value interface{}) Column {
return Column{
Name: name,
Value: value,
}
}
func NewJSONCol(name string, value interface{}) Column {
marshalled, err := json.Marshal(value)
if err != nil {
logging.WithFields("column", name).WithError(err).Panic("unable to marshal column")
}
return NewCol(name, marshalled)
}
type Condition func(param string) (string, interface{})
type NamespacedCondition func(namespace string) Condition
func NewCond(name string, value interface{}) Condition {
return func(param string) (string, interface{}) {
return name + " = " + param, value
}
}
func NewNamespacedCondition(name string, value interface{}) NamespacedCondition {
return func(namespace string) Condition {
return NewCond(namespace+"."+name, value)
}
}

View File

@@ -0,0 +1,52 @@
package handler
import (
"context"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
schedulerSucceeded = eventstore.EventType("system.projections.scheduler.succeeded")
aggregateType = eventstore.AggregateType("system")
aggregateID = "SYSTEM"
)
func (h *Handler) didProjectionInitialize(ctx context.Context) bool {
events, err := h.es.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
InstanceID("").
AddQuery().
AggregateTypes(aggregateType).
AggregateIDs(aggregateID).
EventTypes(schedulerSucceeded).
EventData(map[string]interface{}{
"name": h.projection.Name(),
}).
Builder(),
)
return len(events) > 0 && err == nil
}
func (h *Handler) setSucceededOnce(ctx context.Context) error {
_, err := h.es.Push(ctx, &ProjectionSucceededEvent{
BaseEvent: *eventstore.NewBaseEventForPush(ctx,
eventstore.NewAggregate(ctx, aggregateID, aggregateType, "v1"),
schedulerSucceeded,
),
Name: h.projection.Name(),
})
return err
}
type ProjectionSucceededEvent struct {
eventstore.BaseEvent `json:"-"`
Name string `json:"name"`
}
func (p *ProjectionSucceededEvent) Payload() interface{} {
return p
}
func (p *ProjectionSucceededEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}

View File

@@ -0,0 +1,95 @@
package handler
import (
"database/sql"
_ "embed"
"time"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed failed_event_set.sql
setFailedEventStmt string
//go:embed failed_event_get_count.sql
failureCountStmt string
)
type failure struct {
sequence uint64
instance string
aggregateID string
aggregateType eventstore.AggregateType
eventDate time.Time
err error
}
func failureFromEvent(event eventstore.Event, err error) *failure {
return &failure{
sequence: event.Sequence(),
instance: event.Aggregate().InstanceID,
aggregateID: event.Aggregate().ID,
aggregateType: event.Aggregate().Type,
eventDate: event.CreatedAt(),
err: err,
}
}
func failureFromStatement(statement *Statement, err error) *failure {
return &failure{
sequence: statement.Sequence,
instance: statement.InstanceID,
aggregateID: statement.AggregateID,
aggregateType: statement.AggregateType,
eventDate: statement.CreationDate,
err: err,
}
}
func (h *Handler) handleFailedStmt(tx *sql.Tx, currentState *state, f *failure) (shouldContinue bool) {
failureCount, err := h.failureCount(tx, f)
if err != nil {
h.logFailure(f).WithError(err).Warn("unable to get failure count")
return false
}
failureCount += 1
err = h.setFailureCount(tx, failureCount, f)
h.logFailure(f).OnError(err).Warn("unable to update failure count")
return failureCount >= h.maxFailureCount
}
func (h *Handler) failureCount(tx *sql.Tx, f *failure) (count uint8, err error) {
row := tx.QueryRow(failureCountStmt,
h.projection.Name(),
f.instance,
f.aggregateType,
f.aggregateID,
f.sequence,
)
if err = row.Err(); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
}
if err = row.Scan(&count); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
}
return count, nil
}
func (h *Handler) setFailureCount(tx *sql.Tx, count uint8, f *failure) error {
_, err := tx.Exec(setFailedEventStmt,
h.projection.Name(),
f.instance,
f.aggregateType,
f.aggregateID,
f.eventDate,
f.sequence,
count,
f.err.Error(),
)
if err != nil {
return errors.ThrowInternal(err, "CRDB-4Ht4x", "set failure count failed")
}
return nil
}

View File

@@ -0,0 +1,12 @@
WITH failures AS (
SELECT
failure_count
FROM
projections.failed_events2
WHERE
projection_name = $1
AND instance_id = $2
AND aggregate_type = $3
AND aggregate_id = $4
AND failed_sequence = $5
) SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count

View File

@@ -0,0 +1,31 @@
INSERT INTO projections.failed_events2 (
projection_name
, instance_id
, aggregate_type
, aggregate_id
, event_creation_date
, failed_sequence
, failure_count
, error
, last_failed
) VALUES (
$1
, $2
, $3
, $4
, $5
, $6
, $7
, $8
, now()
) ON CONFLICT (
projection_name
, aggregate_type
, aggregate_id
, failed_sequence
, instance_id
) DO UPDATE SET
failure_count = EXCLUDED.failure_count
, error = EXCLUDED.error
, last_failed = EXCLUDED.last_failed
;

View File

@@ -0,0 +1,465 @@
package handler
import (
"context"
"database/sql"
"errors"
"math"
"sync"
"time"
"github.com/jackc/pgconn"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/pseudo"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type EventStore interface {
InstanceIDs(ctx context.Context, maxAge time.Duration, forceLoad bool, query *eventstore.SearchQueryBuilder) ([]string, error)
Filter(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error)
Push(ctx context.Context, cmds ...eventstore.Command) ([]eventstore.Event, error)
}
type Config struct {
Client *database.DB
Eventstore EventStore
BulkLimit uint16
RequeueEvery time.Duration
RetryFailedAfter time.Duration
HandleActiveInstances time.Duration
TransactionDuration time.Duration
MaxFailureCount uint8
TriggerWithoutEvents Reduce
}
type Handler struct {
client *database.DB
projection Projection
es EventStore
bulkLimit uint16
eventTypes map[eventstore.AggregateType][]eventstore.EventType
maxFailureCount uint8
retryFailedAfter time.Duration
requeueEvery time.Duration
handleActiveInstances time.Duration
txDuration time.Duration
now nowFunc
triggeredInstancesSync sync.Map
triggerWithoutEvents Reduce
}
// nowFunc makes [time.Now] mockable
type nowFunc func() time.Time
type Projection interface {
Name() string
Reducers() []AggregateReducer
}
func NewHandler(
ctx context.Context,
config *Config,
projection Projection,
) *Handler {
aggregates := make(map[eventstore.AggregateType][]eventstore.EventType, len(projection.Reducers()))
for _, reducer := range projection.Reducers() {
eventTypes := make([]eventstore.EventType, len(reducer.EventReducers))
for i, eventReducer := range reducer.EventReducers {
eventTypes[i] = eventReducer.Event
}
if _, ok := aggregates[reducer.Aggregate]; ok {
aggregates[reducer.Aggregate] = append(aggregates[reducer.Aggregate], eventTypes...)
continue
}
aggregates[reducer.Aggregate] = eventTypes
}
handler := &Handler{
projection: projection,
client: config.Client,
es: config.Eventstore,
bulkLimit: config.BulkLimit,
eventTypes: aggregates,
requeueEvery: config.RequeueEvery,
handleActiveInstances: config.HandleActiveInstances,
now: time.Now,
maxFailureCount: config.MaxFailureCount,
retryFailedAfter: config.RetryFailedAfter,
triggeredInstancesSync: sync.Map{},
triggerWithoutEvents: config.TriggerWithoutEvents,
txDuration: config.TransactionDuration,
}
return handler
}
func (h *Handler) Start(ctx context.Context) {
go h.schedule(ctx)
if h.triggerWithoutEvents != nil {
return
}
go h.subscribe(ctx)
}
func (h *Handler) schedule(ctx context.Context) {
// if there was no run before trigger instantly
t := time.NewTimer(0)
didInitialize := h.didProjectionInitialize(ctx)
if didInitialize {
t.Reset(h.requeueEvery)
}
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
instances, err := h.queryInstances(ctx, didInitialize)
h.log().OnError(err).Debug("unable to query instances")
var instanceFailed bool
scheduledCtx := call.WithTimestamp(ctx)
for _, instance := range instances {
instanceCtx := authz.WithInstanceID(scheduledCtx, instance)
// simple implementation of do while
_, err = h.Trigger(instanceCtx)
instanceFailed = instanceFailed || err != nil
h.log().WithField("instance", instance).OnError(err).Info("scheduled trigger failed")
// retry if trigger failed
for ; err != nil; _, err = h.Trigger(instanceCtx) {
time.Sleep(h.retryFailedAfter)
instanceFailed = instanceFailed || err != nil
h.log().WithField("instance", instance).OnError(err).Info("scheduled trigger failed")
if err == nil {
break
}
}
}
if !didInitialize && !instanceFailed {
err = h.setSucceededOnce(ctx)
h.log().OnError(err).Debug("unable to set succeeded once")
didInitialize = err == nil
}
t.Reset(h.requeueEvery)
}
}
}
func (h *Handler) subscribe(ctx context.Context) {
queue := make(chan eventstore.Event, 100)
subscription := eventstore.SubscribeEventTypes(queue, h.eventTypes)
for {
select {
case <-ctx.Done():
subscription.Unsubscribe()
h.log().Debug("shutdown")
return
case event := <-queue:
events := checkAdditionalEvents(queue, event)
solvedInstances := make([]string, 0, len(events))
queueCtx := call.WithTimestamp(ctx)
for _, e := range events {
if instanceSolved(solvedInstances, e.Aggregate().InstanceID) {
continue
}
queueCtx = authz.WithInstanceID(queueCtx, e.Aggregate().InstanceID)
_, err := h.Trigger(queueCtx)
h.log().OnError(err).Debug("trigger of queued event failed")
if err == nil {
solvedInstances = append(solvedInstances, e.Aggregate().InstanceID)
}
}
}
}
}
func instanceSolved(solvedInstances []string, instanceID string) bool {
for _, solvedInstance := range solvedInstances {
if solvedInstance == instanceID {
return true
}
}
return false
}
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
events := make([]eventstore.Event, 1)
events[0] = event
for {
wait := time.NewTimer(1 * time.Millisecond)
select {
case event := <-eventQueue:
events = append(events, event)
case <-wait.C:
return events
}
}
}
func (h *Handler) queryInstances(ctx context.Context, didInitialize bool) ([]string, error) {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).
AwaitOpenTransactions().
AllowTimeTravel().
ExcludedInstanceID("")
if didInitialize {
query = query.
CreationDateAfter(h.now().Add(-1 * h.handleActiveInstances))
}
return h.es.InstanceIDs(ctx, h.requeueEvery, !didInitialize, query)
}
type triggerConfig struct {
awaitRunning bool
}
type triggerOpt func(conf *triggerConfig)
func WithAwaitRunning() triggerOpt {
return func(conf *triggerConfig) {
conf.awaitRunning = true
}
}
func (h *Handler) Trigger(ctx context.Context, opts ...triggerOpt) (_ context.Context, err error) {
if authz.GetInstance(ctx).InstanceID() != "" {
var span *tracing.Span
ctx, span = tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
}
config := new(triggerConfig)
for _, opt := range opts {
opt(config)
}
cancel := h.lockInstance(ctx, config)
if cancel == nil {
return call.ResetTimestamp(ctx), nil
}
defer cancel()
for i := 0; ; i++ {
additionalIteration, err := h.processEvents(ctx, config)
h.log().OnError(err).Warn("process events failed")
h.log().WithField("iteration", i).Debug("trigger iteration")
if !additionalIteration || err != nil {
return call.ResetTimestamp(ctx), err
}
}
}
// lockInstances tries to lock the instance.
// If the instance is already locked from another process no cancel function is returned
// the instance can be skipped then
// If the instance is locked, an unlock deferable function is returned
func (h *Handler) lockInstance(ctx context.Context, config *triggerConfig) func() {
instanceID := authz.GetInstance(ctx).InstanceID()
// Check that the instance has a mutex to lock
instanceMu, _ := h.triggeredInstancesSync.LoadOrStore(instanceID, new(sync.Mutex))
unlock := func() {
instanceMu.(*sync.Mutex).Unlock()
}
if !instanceMu.(*sync.Mutex).TryLock() {
instanceMu.(*sync.Mutex).Lock()
if config.awaitRunning {
return unlock
}
defer unlock()
return nil
}
return unlock
}
func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (additionalIteration bool, err error) {
defer func() {
pgErr := new(pgconn.PgError)
if errors.As(err, &pgErr) {
// error returned if the row is currently locked by another connection
if pgErr.Code == "55P03" {
h.log().Debug("state already locked")
err = nil
additionalIteration = false
}
}
}()
if h.txDuration > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, h.txDuration)
defer cancel()
}
tx, err := h.client.Begin()
if err != nil {
return false, err
}
defer func() {
if err != nil {
rollbackErr := tx.Rollback()
h.log().OnError(rollbackErr).Debug("unable to rollback tx")
return
}
err = tx.Commit()
}()
currentState, err := h.currentState(ctx, tx, config)
if err != nil {
if errors.Is(err, errJustUpdated) {
return false, nil
}
return additionalIteration, err
}
var statements []*Statement
statements, additionalIteration, err = h.generateStatements(ctx, tx, currentState)
if err != nil || len(statements) == 0 {
return additionalIteration, err
}
lastProcessedIndex, err := h.executeStatements(ctx, tx, currentState, statements)
if lastProcessedIndex < 0 {
return false, err
}
currentState.position = statements[lastProcessedIndex].Position
currentState.aggregateID = statements[lastProcessedIndex].AggregateID
currentState.aggregateType = statements[lastProcessedIndex].AggregateType
currentState.sequence = statements[lastProcessedIndex].Sequence
currentState.eventTimestamp = statements[lastProcessedIndex].CreationDate
err = h.setState(tx, currentState)
return additionalIteration, err
}
func (h *Handler) generateStatements(ctx context.Context, tx *sql.Tx, currentState *state) (_ []*Statement, additionalIteration bool, err error) {
if h.triggerWithoutEvents != nil {
stmt, err := h.triggerWithoutEvents(pseudo.NewScheduledEvent(ctx, time.Now(), currentState.instanceID))
if err != nil {
return nil, false, err
}
return []*Statement{stmt}, false, nil
}
events, err := h.es.Filter(ctx, h.eventQuery(currentState))
if err != nil {
h.log().WithError(err).Debug("filter eventstore failed")
return nil, false, err
}
eventAmount := len(events)
events = skipPreviouslyReduced(events, currentState)
if len(events) == 0 {
h.updateLastUpdated(ctx, tx, currentState)
return nil, false, nil
}
statements, err := h.eventsToStatements(tx, events, currentState)
if len(statements) == 0 {
return nil, false, err
}
additionalIteration = eventAmount == int(h.bulkLimit)
if len(statements) < len(events) {
// retry imediatly if statements failed
additionalIteration = true
}
return statements, additionalIteration, nil
}
func skipPreviouslyReduced(events []eventstore.Event, currentState *state) []eventstore.Event {
for i, event := range events {
if event.Position() == currentState.position &&
event.Aggregate().ID == currentState.aggregateID &&
event.Aggregate().Type == currentState.aggregateType &&
event.Sequence() == currentState.sequence {
return events[i+1:]
}
}
return events
}
func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, currentState *state, statements []*Statement) (lastProcessedIndex int, err error) {
lastProcessedIndex = -1
for i, statement := range statements {
select {
case <-ctx.Done():
break
default:
err := h.executeStatement(ctx, tx, currentState, statement)
if err != nil {
return lastProcessedIndex, err
}
lastProcessedIndex = i
}
}
return lastProcessedIndex, nil
}
func (h *Handler) executeStatement(ctx context.Context, tx *sql.Tx, currentState *state, statement *Statement) (err error) {
if statement.Execute == nil {
return nil
}
_, err = tx.Exec("SAVEPOINT exec")
if err != nil {
h.log().WithError(err).Debug("create savepoint failed")
return err
}
var shouldContinue bool
defer func() {
_, err = tx.Exec("RELEASE SAVEPOINT exec")
}()
if err = statement.Execute(tx, h.projection.Name()); err != nil {
h.log().WithError(err).Error("statement execution failed")
shouldContinue = h.handleFailedStmt(tx, currentState, failureFromStatement(statement, err))
if shouldContinue {
return nil
}
return err
}
return nil
}
func (h *Handler) eventQuery(currentState *state) *eventstore.SearchQueryBuilder {
builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AwaitOpenTransactions().
Limit(uint64(h.bulkLimit)).
AllowTimeTravel().
OrderAsc().
InstanceID(currentState.instanceID)
if currentState.position > 0 {
builder = builder.PositionAfter(math.Float64frombits(math.Float64bits(currentState.position) - 10))
}
for aggregateType, eventTypes := range h.eventTypes {
query := builder.
AddQuery().
AggregateTypes(aggregateType).
EventTypes(eventTypes...)
builder = query.Builder()
}
return builder
}

View File

@@ -1,4 +1,4 @@
package crdb
package handler
import (
"context"
@@ -9,19 +9,19 @@ import (
"github.com/jackc/pgconn"
"github.com/zitadel/logging"
caos_errs "github.com/zitadel/zitadel/internal/errors"
errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/handler"
)
type Table struct {
columns []*Column
columns []*InitColumn
primaryKey PrimaryKey
indices []*Index
constraints []*Constraint
foreignKeys []*ForeignKey
}
func NewTable(columns []*Column, key PrimaryKey, opts ...TableOption) *Table {
func NewTable(columns []*InitColumn, key PrimaryKey, opts ...TableOption) *Table {
t := &Table{
columns: columns,
primaryKey: key,
@@ -37,7 +37,7 @@ type SuffixedTable struct {
suffix string
}
func NewSuffixedTable(columns []*Column, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
func NewSuffixedTable(columns []*InitColumn, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
return &SuffixedTable{
Table: *NewTable(columns, key, opts...),
suffix: suffix,
@@ -64,7 +64,7 @@ func WithForeignKey(key *ForeignKey) TableOption {
}
}
type Column struct {
type InitColumn struct {
Name string
Type ColumnType
nullable bool
@@ -72,10 +72,10 @@ type Column struct {
deleteCascade string
}
type ColumnOption func(*Column)
type ColumnOption func(*InitColumn)
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *Column {
column := &Column{
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *InitColumn {
column := &InitColumn{
Name: name,
Type: columnType,
nullable: false,
@@ -88,19 +88,19 @@ func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *Column
}
func Nullable() ColumnOption {
return func(c *Column) {
return func(c *InitColumn) {
c.nullable = true
}
}
func Default(value interface{}) ColumnOption {
return func(c *Column) {
return func(c *InitColumn) {
c.defaultValue = value
}
}
func DeleteCascade(column string) ColumnOption {
return func(c *Column) {
return func(c *InitColumn) {
c.deleteCascade = column
}
}
@@ -128,9 +128,8 @@ const (
func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
i := &Index{
Name: name,
Columns: columns,
bucketCount: 0,
Name: name,
Columns: columns,
}
for _, opt := range opts {
opt(i)
@@ -139,16 +138,16 @@ func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
}
type Index struct {
Name string
Columns []string
bucketCount uint16
Name string
Columns []string
includes []string
}
type indexOpts func(*Index)
func Hash(bucketsCount uint16) indexOpts {
func WithInclude(columns ...string) indexOpts {
return func(i *Index) {
i.bucketCount = bucketsCount
i.includes = columns
}
}
@@ -186,25 +185,28 @@ type ForeignKey struct {
RefColumns []string
}
// Init implements handler.Init
func (h *StatementHandler) Init(ctx context.Context) error {
check := h.initCheck
if check == nil || check.IsNoop() {
type initializer interface {
Init() *handler.Check
}
func (h *Handler) Init(ctx context.Context) error {
check, ok := h.projection.(initializer)
if !ok || check.Init().IsNoop() {
return nil
}
tx, err := h.client.BeginTx(ctx, nil)
if err != nil {
return caos_errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
return errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
}
for i, execute := range check.Executes {
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("executing check")
next, err := execute(h.client, h.ProjectionName)
for i, execute := range check.Init().Executes {
logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("executing check")
next, err := execute(tx, h.projection.Name())
if err != nil {
tx.Rollback()
logging.OnError(tx.Rollback()).Debug("unable to rollback")
return err
}
if !next {
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("projection set up")
logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("projection set up")
break
}
}
@@ -272,15 +274,15 @@ func execNextIfExists(config execConfig, q query, opts []execOption, executeNext
}
func isErrAlreadyExists(err error) bool {
caosErr := &caos_errs.CaosError{}
caosErr := &errs.CaosError{}
if !errors.As(err, &caosErr) {
return false
}
sqlErr, ok := caosErr.GetParent().(*pgconn.PgError)
if !ok {
return false
pgErr := new(pgconn.PgError)
if errors.As(caosErr.Parent, &pgErr) {
return pgErr.Code == "42P07"
}
return sqlErr.Code == "42P07"
return false
}
func createTableStatement(table *Table, tableName string, suffix string) string {
@@ -330,11 +332,10 @@ func createIndexStatement(index *Index, tableName string) string {
tableName,
strings.Join(index.Columns, ","),
)
if index.bucketCount == 0 {
return stmt + ";"
if len(index.includes) > 0 {
stmt += " INCLUDE (" + strings.Join(index.includes, ", ") + ")"
}
return fmt.Sprintf("SET experimental_enable_hash_sharded_indexes=on; %s USING HASH WITH BUCKET_COUNT = %d;",
stmt, index.bucketCount)
return stmt + ";"
}
func foreignKeyName(name, tableName, suffix string) string {
@@ -355,7 +356,7 @@ func tableNameWithoutSchema(name string) string {
return name[strings.LastIndex(name, ".")+1:]
}
func createColumnsStatement(cols []*Column, tableName string) string {
func createColumnsStatement(cols []*InitColumn, tableName string) string {
columns := make([]string, len(cols))
for i, col := range cols {
column := col.Name + " " + columnType(col.Type)

View File

@@ -0,0 +1,23 @@
package handler
import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore"
)
func (h *Handler) log() *logging.Entry {
return logging.WithFields("projection", h.projection.Name())
}
func (h *Handler) logFailure(fail *failure) *logging.Entry {
return h.log().WithField("sequence", fail.sequence).
WithField("instance", fail.instance).
WithField("aggregate", fail.aggregateID)
}
func (h *Handler) logEvent(event eventstore.Event) *logging.Entry {
return h.log().WithField("sequence", event.Sequence()).
WithField("instance", event.Aggregate().InstanceID).
WithField("aggregate", event.Aggregate().Type)
}

View File

@@ -0,0 +1,18 @@
package handler
var _ Projection = (*projection)(nil)
type projection struct {
name string
reducers []AggregateReducer
}
// Name implements Projection
func (p *projection) Name() string {
return p.name
}
// Reducers implements Projection
func (p *projection) Reducers() []AggregateReducer {
return p.reducers
}

View File

@@ -0,0 +1,21 @@
package handler
import "github.com/zitadel/zitadel/internal/eventstore"
// EventReducer represents the required data
// to work with events
type EventReducer struct {
Event eventstore.EventType
Reduce Reduce
}
// Reduce reduces the given event to a statement
// which is used to update the projection
type Reduce func(eventstore.Event) (*Statement, error)
// EventReducer represents the required data
// to work with aggregates
type AggregateReducer struct {
Aggregate eventstore.AggregateType
EventReducers []EventReducer
}

View File

@@ -0,0 +1,119 @@
package handler
import (
"context"
"database/sql"
_ "embed"
"errors"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
type state struct {
instanceID string
position float64
eventTimestamp time.Time
aggregateType eventstore.AggregateType
aggregateID string
sequence uint64
}
var (
//go:embed state_get.sql
currentStateStmt string
//go:embed state_get_await.sql
currentStateAwaitStmt string
//go:embed state_set.sql
updateStateStmt string
//go:embed state_lock.sql
lockStateStmt string
//go:embed state_set_last_run.sql
updateStateLastRunStmt string
errJustUpdated = errors.New("projection was just updated")
)
func (h *Handler) currentState(ctx context.Context, tx *sql.Tx, config *triggerConfig) (currentState *state, err error) {
currentState = &state{
instanceID: authz.GetInstance(ctx).InstanceID(),
}
var (
aggregateID = new(sql.NullString)
aggregateType = new(sql.NullString)
sequence = new(sql.NullInt64)
timestamp = new(sql.NullTime)
position = new(sql.NullFloat64)
)
stateQuery := currentStateStmt
if config.awaitRunning {
stateQuery = currentStateAwaitStmt
}
row := tx.QueryRow(stateQuery, currentState.instanceID, h.projection.Name())
err = row.Scan(
aggregateID,
aggregateType,
sequence,
timestamp,
position,
)
if errors.Is(err, sql.ErrNoRows) {
err = h.lockState(tx, currentState.instanceID)
}
if err != nil {
h.log().WithError(err).Debug("unable to query current state")
return nil, err
}
currentState.aggregateID = aggregateID.String
currentState.aggregateType = eventstore.AggregateType(aggregateType.String)
currentState.sequence = uint64(sequence.Int64)
currentState.eventTimestamp = timestamp.Time
currentState.position = position.Float64
return currentState, nil
}
func (h *Handler) setState(tx *sql.Tx, updatedState *state) error {
res, err := tx.Exec(updateStateStmt,
h.projection.Name(),
updatedState.instanceID,
updatedState.aggregateID,
updatedState.aggregateType,
updatedState.sequence,
updatedState.eventTimestamp,
updatedState.position,
)
if err != nil {
h.log().WithError(err).Debug("unable to update state")
return err
}
if affected, err := res.RowsAffected(); affected == 0 {
h.log().OnError(err).Error("unable to check if states are updated")
return errs.ThrowInternal(err, "V2-FGEKi", "unable to update state")
}
return nil
}
func (h *Handler) updateLastUpdated(ctx context.Context, tx *sql.Tx, updatedState *state) {
_, err := tx.ExecContext(ctx, updateStateLastRunStmt, h.projection.Name(), updatedState.instanceID)
h.log().OnError(err).Debug("unable to update last updated")
}
func (h *Handler) lockState(tx *sql.Tx, instanceID string) error {
res, err := tx.Exec(lockStateStmt,
h.projection.Name(),
instanceID,
)
if err != nil {
return err
}
if affected, err := res.RowsAffected(); affected == 0 || err != nil {
return errs.ThrowInternal(err, "V2-lpiK0", "projection already locked")
}
return nil
}

View File

@@ -0,0 +1,12 @@
SELECT
aggregate_id
, aggregate_type
, "sequence"
, event_date
, "position"
FROM
projections.current_states
WHERE
instance_id = $1
AND projection_name = $2
FOR UPDATE NOWAIT;

View File

@@ -0,0 +1,12 @@
SELECT
aggregate_id
, aggregate_type
, "sequence"
, event_date
, "position"
FROM
projections.current_states
WHERE
instance_id = $1
AND projection_name = $2
FOR UPDATE;

View File

@@ -0,0 +1,9 @@
INSERT INTO projections.current_states (
projection_name
, instance_id
, last_updated
) VALUES (
$1
, $2
, now()
) ON CONFLICT DO NOTHING;

View File

@@ -0,0 +1,29 @@
INSERT INTO projections.current_states (
projection_name
, instance_id
, aggregate_id
, aggregate_type
, "sequence"
, event_date
, "position"
, last_updated
) VALUES (
$1
, $2
, $3
, $4
, $5
, $6
, $7
, now()
) ON CONFLICT (
projection_name
, instance_id
) DO UPDATE SET
aggregate_id = $3
, aggregate_type = $4
, "sequence" = $5
, event_date = $6
, "position" = $7
, last_updated = statement_timestamp()
;

View File

@@ -0,0 +1,2 @@
UPDATE projections.current_states SET last_updated = now() WHERE projection_name = $1 AND instance_id = $2;

View File

@@ -0,0 +1,447 @@
package handler
import (
"context"
"database/sql"
"database/sql/driver"
_ "embed"
"errors"
"reflect"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database/mock"
errs "github.com/zitadel/zitadel/internal/errors"
)
func TestHandler_lockState(t *testing.T) {
type fields struct {
projection Projection
mock *mock.SQLMock
}
type args struct {
instanceID string
}
tests := []struct {
name string
fields fields
args args
isErr func(t *testing.T, err error)
}{
{
name: "tx closed",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(
lockStateStmt,
mock.WithExecArgs(
"projection",
"instance",
),
mock.WithExecErr(sql.ErrTxDone),
),
),
},
args: args{
instanceID: "instance",
},
isErr: func(t *testing.T, err error) {
if !errors.Is(err, sql.ErrTxDone) {
t.Errorf("unexpected error, want: %v got: %v", sql.ErrTxDone, err)
}
},
},
{
name: "no rows affeced",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(
lockStateStmt,
mock.WithExecArgs(
"projection",
"instance",
),
mock.WithExecNoRowsAffected(),
),
),
},
args: args{
instanceID: "instance",
},
isErr: func(t *testing.T, err error) {
if !errors.Is(err, errs.ThrowInternal(nil, "V2-lpiK0", "")) {
t.Errorf("unexpected error: want internal (V2lpiK0), got: %v", err)
}
},
},
{
name: "rows affected",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(
lockStateStmt,
mock.WithExecArgs(
"projection",
"instance",
),
mock.WithExecRowsAffected(1),
),
),
},
args: args{
instanceID: "instance",
},
},
}
for _, tt := range tests {
if tt.isErr == nil {
tt.isErr = func(t *testing.T, err error) {
if err != nil {
t.Error("expected no error got:", err)
}
}
}
t.Run(tt.name, func(t *testing.T) {
h := &Handler{
projection: tt.fields.projection,
}
tx, err := tt.fields.mock.DB.Begin()
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}
err = h.lockState(tx, tt.args.instanceID)
tt.isErr(t, err)
tt.fields.mock.Assert(t)
})
}
}
func TestHandler_updateLastUpdated(t *testing.T) {
type fields struct {
projection Projection
mock *mock.SQLMock
}
type args struct {
updatedState *state
}
tests := []struct {
name string
fields fields
args args
isErr func(t *testing.T, err error)
}{
{
name: "update fails",
fields: fields{
projection: &projection{
name: "instance",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(updateStateStmt,
mock.WithExecErr(sql.ErrTxDone),
),
),
},
args: args{
updatedState: &state{
instanceID: "instance",
eventTimestamp: time.Now(),
position: 42,
},
},
isErr: func(t *testing.T, err error) {
if !errors.Is(err, sql.ErrTxDone) {
t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
}
},
},
{
name: "no rows affected",
fields: fields{
projection: &projection{
name: "instance",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(updateStateStmt,
mock.WithExecNoRowsAffected(),
),
),
},
args: args{
updatedState: &state{
instanceID: "instance",
eventTimestamp: time.Now(),
position: 42,
},
},
isErr: func(t *testing.T, err error) {
if !errors.Is(err, errs.ThrowInternal(nil, "V2-FGEKi", "")) {
t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
}
},
},
{
name: "success",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(updateStateStmt,
mock.WithExecArgs(
"projection",
"instance",
"aggregate id",
"aggregate type",
uint64(42),
mock.AnyType[time.Time]{},
float64(42),
),
mock.WithExecRowsAffected(1),
),
),
},
args: args{
updatedState: &state{
instanceID: "instance",
eventTimestamp: time.Now(),
position: 42,
aggregateType: "aggregate type",
aggregateID: "aggregate id",
sequence: 42,
},
},
},
}
for _, tt := range tests {
if tt.isErr == nil {
tt.isErr = func(t *testing.T, err error) {
if err != nil {
t.Error("expected no error got:", err)
}
}
}
t.Run(tt.name, func(t *testing.T) {
tx, err := tt.fields.mock.DB.Begin()
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}
h := &Handler{
projection: tt.fields.projection,
}
err = h.setState(tx, tt.args.updatedState)
tt.isErr(t, err)
tt.fields.mock.Assert(t)
})
}
}
func TestHandler_currentState(t *testing.T) {
testTime := time.Now()
type fields struct {
projection Projection
mock *mock.SQLMock
}
type args struct {
ctx context.Context
}
type want struct {
currentState *state
isErr func(t *testing.T, err error)
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
name: "connection done",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExpectQuery(currentStateStmt,
mock.WithQueryArgs(
"instance",
"projection",
),
mock.WithQueryErr(sql.ErrConnDone),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance"),
},
want: want{
isErr: func(t *testing.T, err error) {
if !errors.Is(err, sql.ErrConnDone) {
t.Errorf("unexpected error, want: %v, got: %v", sql.ErrConnDone, err)
}
},
},
},
{
name: "no row but lock err",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExpectQuery(currentStateStmt,
mock.WithQueryArgs(
"instance",
"projection",
),
mock.WithQueryErr(sql.ErrNoRows),
),
mock.ExcpectExec(lockStateStmt,
mock.WithExecArgs(
"projection",
"instance",
),
mock.WithExecErr(sql.ErrTxDone),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance"),
},
want: want{
isErr: func(t *testing.T, err error) {
if !errors.Is(err, sql.ErrTxDone) {
t.Errorf("unexpected error, want: %v, got: %v", sql.ErrTxDone, err)
}
},
},
},
{
name: "state locked",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExpectQuery(currentStateStmt,
mock.WithQueryArgs(
"instance",
"projection",
),
mock.WithQueryErr(&pgconn.PgError{Code: "55P03"}),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance"),
},
want: want{
isErr: func(t *testing.T, err error) {
pgErr := new(pgconn.PgError)
if !errors.As(err, &pgErr) {
t.Errorf("error should be PgErr but was %T", err)
return
}
if pgErr.Code != "55P03" {
t.Errorf("expected code 55P03 got: %s", pgErr.Code)
}
},
},
},
{
name: "success",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExpectQuery(currentStateStmt,
mock.WithQueryArgs(
"instance",
"projection",
),
mock.WithQueryResult(
[]string{"aggregate_id", "aggregate_type", "event_sequence", "event_date", "position"},
[][]driver.Value{
{
"aggregate id",
"aggregate type",
int64(42),
testTime,
float64(42),
},
},
),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance"),
},
want: want{
currentState: &state{
instanceID: "instance",
eventTimestamp: testTime,
position: 42,
aggregateType: "aggregate type",
aggregateID: "aggregate id",
sequence: 42,
},
},
},
}
for _, tt := range tests {
if tt.want.isErr == nil {
tt.want.isErr = func(t *testing.T, err error) {
if err != nil {
t.Error("expected no error got:", err)
}
}
}
t.Run(tt.name, func(t *testing.T) {
h := &Handler{
projection: tt.fields.projection,
}
tx, err := tt.fields.mock.DB.Begin()
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}
gotCurrentState, err := h.currentState(tt.args.ctx, tx, new(triggerConfig))
tt.want.isErr(t, err)
if !reflect.DeepEqual(gotCurrentState, tt.want.currentState) {
t.Errorf("Handler.currentState() gotCurrentState = %v, want %v", gotCurrentState, tt.want.currentState)
}
tt.fields.mock.Assert(t)
})
}
}

View File

@@ -1,41 +1,89 @@
package crdb
package handler
import (
"database/sql"
"errors"
"encoding/json"
errs "errors"
"strconv"
"strings"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
zitadel_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
)
type execOption func(*execConfig)
type execConfig struct {
tableName string
args []interface{}
err error
ignoreNotFound bool
func (h *Handler) eventsToStatements(tx *sql.Tx, events []eventstore.Event, currentState *state) (statements []*Statement, err error) {
statements = make([]*Statement, 0, len(events))
for _, event := range events {
statement, err := h.reduce(event)
if err != nil {
h.logEvent(event).WithError(err).Error("reduce failed")
if shouldContinue := h.handleFailedStmt(tx, currentState, failureFromEvent(event, err)); shouldContinue {
continue
}
return statements, err
}
statements = append(statements, statement)
}
return statements, nil
}
func (h *Handler) reduce(event eventstore.Event) (*Statement, error) {
for _, reducer := range h.projection.Reducers() {
if reducer.Aggregate != event.Aggregate().Type {
continue
}
for _, reduce := range reducer.EventReducers {
if reduce.Event != event.Type() {
continue
}
return reduce.Reduce(event)
}
}
return NewNoOpStatement(event), nil
}
type Statement struct {
AggregateType eventstore.AggregateType
AggregateID string
Sequence uint64
Position float64
CreationDate time.Time
InstanceID string
Execute Exec
}
type Exec func(ex Executer, projectionName string) error
func WithTableSuffix(name string) func(*execConfig) {
return func(o *execConfig) {
o.tableName += "_" + name
}
}
func WithIgnoreNotFound() func(*execConfig) {
return func(o *execConfig) {
o.ignoreNotFound = true
var (
ErrNoProjection = errs.New("no projection")
ErrNoValues = errs.New("no values")
ErrNoCondition = errs.New("no condition")
)
func NewStatement(event eventstore.Event, e Exec) *Statement {
return &Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
Position: event.Position(),
AggregateID: event.Aggregate().ID,
CreationDate: event.CreatedAt(),
InstanceID: event.Aggregate().InstanceID,
Execute: e,
}
}
func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ...execOption) *handler.Statement {
func NewCreateStatement(event eventstore.Event, values []Column, opts ...execOption) *Statement {
cols, params, args := columnsToQuery(values)
columnNames := strings.Join(cols, ", ")
valuesPlaceholder := strings.Join(params, ", ")
@@ -45,23 +93,17 @@ func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ..
}
if len(values) == 0 {
config.err = handler.ErrNoValues
config.err = ErrNoValues
}
q := func(config execConfig) string {
return "INSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
}
return &handler.Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
return NewStatement(event, exec(config, q, opts))
}
func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, values []handler.Column, opts ...execOption) *handler.Statement {
func NewUpsertStatement(event eventstore.Event, conflictCols []Column, values []Column, opts ...execOption) *Statement {
cols, params, args := columnsToQuery(values)
conflictTarget := make([]string, len(conflictCols))
@@ -74,12 +116,12 @@ func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, v
}
if len(values) == 0 {
config.err = handler.ErrNoValues
config.err = ErrNoValues
}
updateCols, updateVals := getUpdateCols(cols, conflictTarget)
if len(updateCols) == 0 || len(updateVals) == 0 {
config.err = handler.ErrNoValues
config.err = ErrNoValues
}
q := func(config execConfig) string {
@@ -96,13 +138,7 @@ func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, v
" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO " + updateStmt
}
return &handler.Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
return NewStatement(event, exec(config, q, opts))
}
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
@@ -132,9 +168,9 @@ func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []stri
return updateCols, updateVals
}
func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditions []handler.Condition, opts ...execOption) *handler.Statement {
func NewUpdateStatement(event eventstore.Event, values []Column, conditions []Condition, opts ...execOption) *Statement {
cols, params, args := columnsToQuery(values)
wheres, whereArgs := conditionsToWhere(conditions, len(args))
wheres, whereArgs := conditionsToWhere(conditions, len(args)+1)
args = append(args, whereArgs...)
config := execConfig{
@@ -142,11 +178,11 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
}
if len(values) == 0 {
config.err = handler.ErrNoValues
config.err = ErrNoValues
}
if len(conditions) == 0 {
config.err = handler.ErrNoCondition
config.err = ErrNoCondition
}
q := func(config execConfig) string {
@@ -159,17 +195,11 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
}
return &handler.Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
return NewStatement(event, exec(config, q, opts))
}
func NewDeleteStatement(event eventstore.Event, conditions []handler.Condition, opts ...execOption) *handler.Statement {
wheres, args := conditionsToWhere(conditions, 0)
func NewDeleteStatement(event eventstore.Event, conditions []Condition, opts ...execOption) *Statement {
wheres, args := conditionsToWhere(conditions, 1)
wheresPlaceholders := strings.Join(wheres, " AND ")
@@ -178,32 +208,21 @@ func NewDeleteStatement(event eventstore.Event, conditions []handler.Condition,
}
if len(conditions) == 0 {
config.err = handler.ErrNoCondition
config.err = ErrNoCondition
}
q := func(config execConfig) string {
return "DELETE FROM " + config.tableName + " WHERE " + wheresPlaceholders
}
return &handler.Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
return NewStatement(event, exec(config, q, opts))
}
func NewNoOpStatement(event eventstore.Event) *handler.Statement {
return &handler.Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
}
func NewNoOpStatement(event eventstore.Event) *Statement {
return NewStatement(event, nil)
}
func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Exec) *handler.Statement {
func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Exec) *Statement {
if len(opts) == 0 {
return NewNoOpStatement(event)
}
@@ -211,43 +230,47 @@ func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Ex
for i, opt := range opts {
execs[i] = opt(event)
}
return &handler.Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: multiExec(execs),
return NewStatement(event, multiExec(execs))
}
func AddNoOpStatement() func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewNoOpStatement(event).Execute
}
}
type Exec func(ex handler.Executer, projectionName string) error
func AddCreateStatement(columns []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
func AddCreateStatement(columns []Column, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewCreateStatement(event, columns, opts...).Execute
}
}
func AddUpsertStatement(indexCols []handler.Column, values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
func AddUpsertStatement(indexCols []Column, values []Column, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewUpsertStatement(event, indexCols, values, opts...).Execute
}
}
func AddUpdateStatement(values []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
func AddUpdateStatement(values []Column, conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewUpdateStatement(event, values, conditions, opts...).Execute
}
}
func AddDeleteStatement(conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
func AddDeleteStatement(conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewDeleteStatement(event, conditions, opts...).Execute
}
}
func NewArrayAppendCol(column string, value interface{}) handler.Column {
return handler.Column{
func AddCopyStatement(conflict, from, to []Column, conditions []NamespacedCondition, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewCopyStatement(event, conflict, from, to, conditions, opts...).Execute
}
}
func NewArrayAppendCol(column string, value interface{}) Column {
return Column{
Name: column,
Value: value,
ParameterOpt: func(placeholder string) string {
@@ -256,8 +279,8 @@ func NewArrayAppendCol(column string, value interface{}) handler.Column {
}
}
func NewArrayRemoveCol(column string, value interface{}) handler.Column {
return handler.Column{
func NewArrayRemoveCol(column string, value interface{}) Column {
return Column{
Name: column,
Value: value,
ParameterOpt: func(placeholder string) string {
@@ -266,15 +289,15 @@ func NewArrayRemoveCol(column string, value interface{}) handler.Column {
}
}
func NewArrayIntersectCol(column string, value interface{}) handler.Column {
func NewArrayIntersectCol(column string, value interface{}) Column {
var arrayType string
switch value.(type) {
case []string, database.StringArray:
case []string, database.TextArray[string]:
arrayType = "TEXT"
//TODO: handle more types if necessary
}
return handler.Column{
return Column{
Name: column,
Value: value,
ParameterOpt: func(placeholder string) string {
@@ -283,38 +306,10 @@ func NewArrayIntersectCol(column string, value interface{}) handler.Column {
}
}
func NewCopyCol(column, from string) handler.Column {
return handler.Column{
func NewCopyCol(column, from string) Column {
return Column{
Name: column,
Value: handler.NewCol(from, nil),
}
}
func NewLessThanCond(column string, value interface{}) handler.Condition {
return func(param string) (string, interface{}) {
return column + " < " + param, value
}
}
func NewIsNullCond(column string) handler.Condition {
return func(param string) (string, interface{}) {
return column + " IS NULL", nil
}
}
// NewTextArrayContainsCond returns a handler.Condition that checks if the column that stores an array of text contains the given value
func NewTextArrayContainsCond(column string, value string) handler.Condition {
return func(param string) (string, interface{}) {
return column + " @> " + param, database.StringArray{value}
}
}
// Not is a function and not a method, so that calling it is well readable
// For example conditions := []handler.Condition{ Not(NewTextArrayContainsCond())}
func Not(condition handler.Condition) handler.Condition {
return func(param string) (string, interface{}) {
cond, value := condition(param)
return "NOT (" + cond + ")", value
Value: NewCol(from, nil),
}
}
@@ -323,7 +318,7 @@ func Not(condition handler.Condition) handler.Condition {
// if the value of a col is empty the data will be copied from the selected row
// if the value of a col is not empty the data will be set by the static value
// conds represent the conditions for the selection subquery
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.Column, nsCond []handler.NamespacedCondition, opts ...execOption) *handler.Statement {
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []Column, nsCond []NamespacedCondition, opts ...execOption) *Statement {
columnNames := make([]string, len(to))
selectColumns := make([]string, len(from))
updateColumns := make([]string, len(columnNames))
@@ -342,11 +337,11 @@ func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.C
}
}
cond := make([]handler.Condition, len(nsCond))
cond := make([]Condition, len(nsCond))
for i := range nsCond {
cond[i] = nsCond[i]("copy_table")
}
wheres, values := conditionsToWhere(cond, len(args))
wheres, values := conditionsToWhere(cond, len(args)+1)
args = append(args, values...)
conflictTargets := make([]string, len(conflictCols))
@@ -359,11 +354,11 @@ func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.C
}
if len(from) == 0 || len(to) == 0 || len(from) != len(to) {
config.err = handler.ErrNoValues
config.err = ErrNoValues
}
if len(cond) == 0 {
config.err = handler.ErrNoCondition
config.err = ErrNoCondition
}
q := func(config execConfig) string {
@@ -385,23 +380,17 @@ func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.C
")"
}
return &handler.Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
PreviousSequence: event.PreviousAggregateTypeSequence(),
InstanceID: event.Aggregate().InstanceID,
Execute: exec(config, q, opts),
}
return NewStatement(event, exec(config, q, opts))
}
func columnsToQuery(cols []handler.Column) (names []string, parameters []string, values []interface{}) {
func columnsToQuery(cols []Column) (names []string, parameters []string, values []interface{}) {
names = make([]string, len(cols))
values = make([]interface{}, len(cols))
parameters = make([]string, len(cols))
var parameterIndex int
for i, col := range cols {
names[i] = col.Name
if c, ok := col.Value.(handler.Column); ok {
if c, ok := col.Value.(Column); ok {
parameters[i] = c.Name
continue
} else {
@@ -416,25 +405,105 @@ func columnsToQuery(cols []handler.Column) (names []string, parameters []string,
return names, parameters, values[:parameterIndex]
}
func conditionsToWhere(conditions []handler.Condition, paramOffset int) (wheres []string, values []interface{}) {
wheres = make([]string, len(conditions))
values = make([]interface{}, 0, len(conditions))
for i, conditionFunc := range conditions {
condition, value := conditionFunc("$" + strconv.Itoa(i+1+paramOffset))
wheres[i] = "(" + condition + ")"
if value != nil {
values = append(values, value)
}
func conditionsToWhere(conds []Condition, paramOffset int) (wheres []string, values []interface{}) {
wheres = make([]string, len(conds))
values = make([]any, 0, len(conds))
for i, cond := range conds {
var args []any
wheres[i], args = cond("$" + strconv.Itoa(paramOffset))
paramOffset += len(args)
values = append(values, args...)
wheres[i] = "(" + wheres[i] + ")"
}
return wheres, values
}
type Column struct {
Name string
Value interface{}
ParameterOpt func(string) string
}
func NewCol(name string, value interface{}) Column {
return Column{
Name: name,
Value: value,
}
}
func NewJSONCol(name string, value interface{}) Column {
marshalled, err := json.Marshal(value)
if err != nil {
logging.WithFields("column", name).WithError(err).Panic("unable to marshal column")
}
return NewCol(name, marshalled)
}
type Condition func(param string) (string, []any)
type NamespacedCondition func(namespace string) Condition
func NewCond(name string, value interface{}) Condition {
return func(param string) (string, []any) {
return name + " = " + param, []any{value}
}
}
func NewNamespacedCondition(name string, value interface{}) NamespacedCondition {
return func(namespace string) Condition {
return NewCond(namespace+"."+name, value)
}
}
func NewLessThanCond(column string, value interface{}) Condition {
return func(param string) (string, []any) {
return column + " < " + param, []any{value}
}
}
func NewIsNullCond(column string) Condition {
return func(string) (string, []any) {
return column + " IS NULL", nil
}
}
// NewTextArrayContainsCond returns a Condition that checks if the column that stores an array of text contains the given value
func NewTextArrayContainsCond(column string, value string) Condition {
return func(param string) (string, []any) {
return column + " @> " + param, []any{database.TextArray[string]{value}}
}
}
// Not is a function and not a method, so that calling it is well readable
// For example conditions := []Condition{ Not(NewTextArrayContainsCond())}
func Not(condition Condition) Condition {
return func(param string) (string, []any) {
cond, value := condition(param)
return "NOT (" + cond + ")", value
}
}
type Executer interface {
Exec(string, ...interface{}) (sql.Result, error)
}
type execOption func(*execConfig)
type execConfig struct {
tableName string
args []interface{}
err error
}
type query func(config execConfig) string
func exec(config execConfig, q query, opts []execOption) Exec {
return func(ex handler.Executer, projectionName string) error {
return func(ex Executer, projectionName string) (err error) {
if projectionName == "" {
return handler.ErrNoProjection
return ErrNoProjection
}
if config.err != nil {
@@ -446,12 +515,21 @@ func exec(config execConfig, q query, opts []execOption) Exec {
opt(&config)
}
if _, err := ex.Exec(q(config), config.args...); err != nil {
if config.ignoreNotFound && errors.Is(err, sql.ErrNoRows) {
logging.WithError(err).Debugf("ignored not found: %v", err)
return nil
_, err = ex.Exec("SAVEPOINT stmt_exec")
if err != nil {
return errors.ThrowInternal(err, "CRDB-YdOXD", "create savepoint failed")
}
defer func() {
if err != nil {
_, rollbackErr := ex.Exec("ROLLBACK TO SAVEPOINT stmt_exec")
logging.OnError(rollbackErr).Debug("rollback failed")
return
}
return zitadel_errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
_, err = ex.Exec("RELEASE SAVEPOINT stmt_exec")
}()
_, err = ex.Exec(q(config), config.args...)
if err != nil {
return errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
}
return nil
@@ -459,8 +537,11 @@ func exec(config execConfig, q query, opts []execOption) Exec {
}
func multiExec(execList []Exec) Exec {
return func(ex handler.Executer, projectionName string) error {
return func(ex Executer, projectionName string) error {
for _, exec := range execList {
if exec == nil {
continue
}
if err := exec(ex, projectionName); err != nil {
return err
}

View File

@@ -1,14 +1,14 @@
package crdb
package handler
import (
"database/sql"
"errors"
"reflect"
"strings"
"testing"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
)
type wantExecuter struct {
@@ -24,21 +24,50 @@ type params struct {
args []interface{}
}
var errTestErr = errors.New("some error")
var errTest = errors.New("some error")
var _ eventstore.Event = &testEvent{}
type testEvent struct {
eventstore.BaseEvent
sequence uint64
previousSequence uint64
aggregateType eventstore.AggregateType
instanceID string
}
func (e *testEvent) Sequence() uint64 {
return e.sequence
}
func (e *testEvent) Aggregate() *eventstore.Aggregate {
return &eventstore.Aggregate{
Type: e.aggregateType,
InstanceID: e.instanceID,
}
}
func (e *testEvent) PreviousAggregateTypeSequence() uint64 {
return e.previousSequence
}
func (ex *wantExecuter) check(t *testing.T) {
t.Helper()
if ex.wasExecuted && !ex.shouldExecute {
switch {
case ex.wasExecuted && !ex.shouldExecute:
t.Error("executer should not be executed")
} else if !ex.wasExecuted && ex.shouldExecute {
case !ex.wasExecuted && ex.shouldExecute:
t.Error("executer should be executed")
} else if ex.wasExecuted != ex.shouldExecute {
case ex.wasExecuted != ex.shouldExecute:
t.Errorf("executed missmatched should be %t, but was %t", ex.shouldExecute, ex.wasExecuted)
}
}
func (ex *wantExecuter) Exec(query string, args ...interface{}) (sql.Result, error) {
ex.t.Helper()
if strings.Contains(query, "SAVEPOINT") {
return nil, nil
}
ex.wasExecuted = true
if ex.i >= len(ex.params) {
ex.t.Errorf("did not expect more exec, but got:\n %q with %q", query, args)
@@ -59,7 +88,7 @@ func TestNewCreateStatement(t *testing.T) {
type args struct {
table string
event *testEvent
values []handler.Column
values []Column
}
type want struct {
aggregateType eventstore.AggregateType
@@ -83,7 +112,7 @@ func TestNewCreateStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
@@ -99,7 +128,7 @@ func TestNewCreateStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoProjection)
return errors.Is(err, ErrNoProjection)
},
},
},
@@ -112,7 +141,7 @@ func TestNewCreateStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{},
values: []Column{},
},
want: want{
table: "my_table",
@@ -123,7 +152,7 @@ func TestNewCreateStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoValues)
return errors.Is(err, ErrNoValues)
},
},
},
@@ -136,7 +165,7 @@ func TestNewCreateStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
@@ -181,8 +210,8 @@ func TestNewUpsertStatement(t *testing.T) {
type args struct {
table string
event *testEvent
conflictCols []handler.Column
values []handler.Column
conflictCols []Column
values []Column
}
type want struct {
aggregateType eventstore.AggregateType
@@ -206,7 +235,7 @@ func TestNewUpsertStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
@@ -222,7 +251,7 @@ func TestNewUpsertStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoProjection)
return errors.Is(err, ErrNoProjection)
},
},
},
@@ -235,7 +264,7 @@ func TestNewUpsertStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{},
values: []Column{},
},
want: want{
table: "my_table",
@@ -246,7 +275,7 @@ func TestNewUpsertStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoValues)
return errors.Is(err, ErrNoValues)
},
},
},
@@ -259,10 +288,10 @@ func TestNewUpsertStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conflictCols: []handler.Column{
handler.NewCol("col1", nil),
conflictCols: []Column{
NewCol("col1", nil),
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
@@ -278,7 +307,7 @@ func TestNewUpsertStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoValues)
return errors.Is(err, ErrNoValues)
},
},
},
@@ -291,10 +320,10 @@ func TestNewUpsertStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conflictCols: []handler.Column{
handler.NewCol("col1", nil),
conflictCols: []Column{
NewCol("col1", nil),
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
@@ -337,10 +366,10 @@ func TestNewUpsertStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conflictCols: []handler.Column{
handler.NewCol("col1", nil),
conflictCols: []Column{
NewCol("col1", nil),
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
@@ -389,8 +418,8 @@ func TestNewUpdateStatement(t *testing.T) {
type args struct {
table string
event *testEvent
values []handler.Column
conditions []handler.Condition
values []Column
conditions []Condition
}
type want struct {
table string
@@ -414,14 +443,14 @@ func TestNewUpdateStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
},
},
conditions: []handler.Condition{
handler.NewCond("col2", 1),
conditions: []Condition{
NewCond("col2", 1),
},
},
want: want{
@@ -433,7 +462,7 @@ func TestNewUpdateStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoProjection)
return errors.Is(err, ErrNoProjection)
},
},
},
@@ -446,9 +475,9 @@ func TestNewUpdateStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{},
conditions: []handler.Condition{
handler.NewCond("col2", 1),
values: []Column{},
conditions: []Condition{
NewCond("col2", 1),
},
},
want: want{
@@ -460,7 +489,7 @@ func TestNewUpdateStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoValues)
return errors.Is(err, ErrNoValues)
},
},
},
@@ -473,13 +502,13 @@ func TestNewUpdateStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
},
},
conditions: []handler.Condition{},
conditions: []Condition{},
},
want: want{
table: "my_table",
@@ -490,7 +519,7 @@ func TestNewUpdateStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoCondition)
return errors.Is(err, ErrNoCondition)
},
},
},
@@ -503,14 +532,14 @@ func TestNewUpdateStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
},
},
conditions: []handler.Condition{
handler.NewCond("col2", 1),
conditions: []Condition{
NewCond("col2", 1),
},
},
want: want{
@@ -541,7 +570,7 @@ func TestNewUpdateStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
values: []handler.Column{
values: []Column{
{
Name: "col1",
Value: "val",
@@ -551,8 +580,8 @@ func TestNewUpdateStatement(t *testing.T) {
Value: "val5",
},
},
conditions: []handler.Condition{
handler.NewCond("col2", 1),
conditions: []Condition{
NewCond("col2", 1),
},
},
want: want{
@@ -593,7 +622,7 @@ func TestNewDeleteStatement(t *testing.T) {
type args struct {
table string
event *testEvent
conditions []handler.Condition
conditions []Condition
}
type want struct {
@@ -618,8 +647,8 @@ func TestNewDeleteStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conditions: []handler.Condition{
handler.NewCond("col2", 1),
conditions: []Condition{
NewCond("col2", 1),
},
},
want: want{
@@ -631,7 +660,7 @@ func TestNewDeleteStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoProjection)
return errors.Is(err, ErrNoProjection)
},
},
},
@@ -644,7 +673,7 @@ func TestNewDeleteStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conditions: []handler.Condition{},
conditions: []Condition{},
},
want: want{
table: "my_table",
@@ -655,7 +684,7 @@ func TestNewDeleteStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoCondition)
return errors.Is(err, ErrNoCondition)
},
},
},
@@ -668,8 +697,8 @@ func TestNewDeleteStatement(t *testing.T) {
previousSequence: 0,
aggregateType: "agg",
},
conditions: []handler.Condition{
handler.NewCond("col1", 1),
conditions: []Condition{
NewCond("col1", 1),
},
},
want: want{
@@ -709,11 +738,16 @@ func TestNewDeleteStatement(t *testing.T) {
func TestNewNoOpStatement(t *testing.T) {
type args struct {
event *testEvent
table string
}
type want struct {
executer *wantExecuter
isErr func(error) bool
}
tests := []struct {
name string
args args
want *handler.Statement
want want
}{
{
name: "generate correctly",
@@ -725,20 +759,29 @@ func TestNewNoOpStatement(t *testing.T) {
instanceID: "instanceID",
},
},
want: &handler.Statement{
AggregateType: "agg",
Execute: nil,
Sequence: 5,
PreviousSequence: 3,
InstanceID: "instanceID",
want: want{
executer: nil,
isErr: func(err error) bool {
return err == nil
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewNoOpStatement(tt.args.event); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewNoOpStatement() = %v, want %v", got, tt.want)
stmt := NewNoOpStatement(tt.args.event)
if tt.want.executer != nil && stmt.Execute == nil {
t.Error("expected executer, but was nil")
}
if stmt.Execute == nil {
return
}
tt.want.executer.t = t
err := stmt.Execute(tt.want.executer, tt.args.table)
if !tt.want.isErr(err) {
t.Errorf("unexpected error: %v", err)
}
tt.want.executer.check(t)
})
}
}
@@ -772,10 +815,17 @@ func TestNewMultiStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
execs: nil,
execs: []func(eventstore.Event) Exec{
AddNoOpStatement(),
},
},
want: want{
executer: nil,
executer: &wantExecuter{
shouldExecute: false,
},
isErr: func(err error) bool {
return err == nil
},
},
},
{
@@ -789,10 +839,10 @@ func TestNewMultiStatement(t *testing.T) {
},
execs: []func(eventstore.Event) Exec{
AddDeleteStatement(
[]handler.Condition{},
[]Condition{},
),
AddCreateStatement(
[]handler.Column{
[]Column{
{
Name: "col1",
Value: 1,
@@ -809,7 +859,7 @@ func TestNewMultiStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoCondition)
return errors.Is(err, ErrNoCondition)
},
},
},
@@ -824,22 +874,21 @@ func TestNewMultiStatement(t *testing.T) {
},
execs: []func(eventstore.Event) Exec{
AddDeleteStatement(
[]handler.Condition{
handler.NewCond("col1", 1),
},
),
[]Condition{
NewCond("col1", 1),
}),
AddCreateStatement(
[]handler.Column{
[]Column{
{
Name: "col1",
Value: 1,
},
}),
AddUpsertStatement(
[]handler.Column{
handler.NewCol("col1", nil),
[]Column{
NewCol("col1", nil),
},
[]handler.Column{
[]Column{
{
Name: "col1",
Value: 1,
@@ -850,16 +899,15 @@ func TestNewMultiStatement(t *testing.T) {
},
}),
AddUpdateStatement(
[]handler.Column{
[]Column{
{
Name: "col1",
Value: 1,
},
},
[]handler.Condition{
handler.NewCond("col1", 1),
},
),
[]Condition{
NewCond("col1", 1),
}),
},
},
want: want{
@@ -918,10 +966,10 @@ func TestNewCopyStatement(t *testing.T) {
type args struct {
table string
event *testEvent
conflictingCols []handler.Column
from []handler.Column
to []handler.Column
conds []handler.NamespacedCondition
conflictingCols []Column
from []Column
to []Column
conds []NamespacedCondition
}
type want struct {
aggregateType eventstore.AggregateType
@@ -945,8 +993,8 @@ func TestNewCopyStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conds: []handler.NamespacedCondition{
handler.NewNamespacedCondition("col2", 1),
conds: []NamespacedCondition{
NewNamespacedCondition("col2", 1),
},
},
want: want{
@@ -958,7 +1006,7 @@ func TestNewCopyStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoProjection)
return errors.Is(err, ErrNoProjection)
},
},
},
@@ -971,13 +1019,13 @@ func TestNewCopyStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conds: []handler.NamespacedCondition{},
from: []handler.Column{
conds: []NamespacedCondition{},
from: []Column{
{
Name: "col",
},
},
to: []handler.Column{
to: []Column{
{
Name: "col",
},
@@ -992,7 +1040,7 @@ func TestNewCopyStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoCondition)
return errors.Is(err, ErrNoCondition)
},
},
},
@@ -1005,13 +1053,13 @@ func TestNewCopyStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conds: []handler.NamespacedCondition{},
from: []handler.Column{
conds: []NamespacedCondition{},
from: []Column{
{
Name: "col",
},
},
to: []handler.Column{
to: []Column{
{
Name: "col",
},
@@ -1029,7 +1077,7 @@ func TestNewCopyStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoCondition)
return errors.Is(err, ErrNoCondition)
},
},
},
@@ -1042,10 +1090,10 @@ func TestNewCopyStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conds: []handler.NamespacedCondition{
handler.NewNamespacedCondition("col2", nil),
conds: []NamespacedCondition{
NewNamespacedCondition("col2", nil),
},
from: []handler.Column{},
from: []Column{},
},
want: want{
table: "my_table",
@@ -1056,7 +1104,7 @@ func TestNewCopyStatement(t *testing.T) {
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoValues)
return errors.Is(err, ErrNoValues)
},
},
},
@@ -1069,7 +1117,7 @@ func TestNewCopyStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
from: []handler.Column{
from: []Column{
{
Name: "state",
Value: 1,
@@ -1084,7 +1132,7 @@ func TestNewCopyStatement(t *testing.T) {
Name: "col_b",
},
},
to: []handler.Column{
to: []Column{
{
Name: "state",
},
@@ -1098,9 +1146,9 @@ func TestNewCopyStatement(t *testing.T) {
Name: "col_b",
},
},
conds: []handler.NamespacedCondition{
handler.NewNamespacedCondition("id", 2),
handler.NewNamespacedCondition("state", 3),
conds: []NamespacedCondition{
NewNamespacedCondition("id", 2),
NewNamespacedCondition("state", 3),
},
},
want: want{
@@ -1131,7 +1179,7 @@ func TestNewCopyStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
from: []handler.Column{
from: []Column{
{
Value: 1,
},
@@ -1145,7 +1193,7 @@ func TestNewCopyStatement(t *testing.T) {
Name: "col_b",
},
},
to: []handler.Column{
to: []Column{
{
Name: "state",
},
@@ -1159,9 +1207,9 @@ func TestNewCopyStatement(t *testing.T) {
Name: "col_d",
},
},
conds: []handler.NamespacedCondition{
handler.NewNamespacedCondition("id", 2),
handler.NewNamespacedCondition("state", 3),
conds: []NamespacedCondition{
NewNamespacedCondition("id", 2),
NewNamespacedCondition("state", 3),
},
},
want: want{
@@ -1200,7 +1248,7 @@ func TestNewCopyStatement(t *testing.T) {
func TestStatement_Execute(t *testing.T) {
type fields struct {
execute func(ex handler.Executer, projectionName string) error
execute func(ex Executer, projectionName string) error
}
type want struct {
isErr func(error) bool
@@ -1217,7 +1265,7 @@ func TestStatement_Execute(t *testing.T) {
{
name: "execute returns no error",
fields: fields{
execute: func(ex handler.Executer, projectionName string) error { return nil },
execute: func(ex Executer, projectionName string) error { return nil },
},
args: args{
projectionName: "my_projection",
@@ -1234,18 +1282,18 @@ func TestStatement_Execute(t *testing.T) {
projectionName: "my_projection",
},
fields: fields{
execute: func(ex handler.Executer, projectionName string) error { return errTestErr },
execute: func(ex Executer, projectionName string) error { return errTest },
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, errTestErr)
return errors.Is(err, errTest)
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stmt := &handler.Statement{
stmt := &Statement{
Execute: tt.fields.execute,
}
if err := stmt.Execute(nil, tt.args.projectionName); !tt.want.isErr(err) {
@@ -1257,7 +1305,7 @@ func TestStatement_Execute(t *testing.T) {
func Test_columnsToQuery(t *testing.T) {
type args struct {
cols []handler.Column
cols []Column
}
type want struct {
names []string
@@ -1281,7 +1329,7 @@ func Test_columnsToQuery(t *testing.T) {
{
name: "one column",
args: args{
cols: []handler.Column{
cols: []Column{
{
Name: "col1",
Value: 1,
@@ -1297,7 +1345,7 @@ func Test_columnsToQuery(t *testing.T) {
{
name: "multiple columns",
args: args{
cols: []handler.Column{
cols: []Column{
{
Name: "col1",
Value: 1,
@@ -1317,14 +1365,14 @@ func Test_columnsToQuery(t *testing.T) {
{
name: "with copy column",
args: args{
cols: []handler.Column{
cols: []Column{
{
Name: "col1",
Value: 1,
},
{
Name: "col2",
Value: handler.Column{
Value: Column{
Name: "col1",
},
},
@@ -1357,9 +1405,9 @@ func Test_columnsToQuery(t *testing.T) {
}
}
func Test_conditionsToWhere(t *testing.T) {
func Test_columnsToWhere(t *testing.T) {
type args struct {
conds []handler.Condition
conds []Condition
paramOffset int
}
type want struct {
@@ -1382,10 +1430,10 @@ func Test_conditionsToWhere(t *testing.T) {
{
name: "no offset",
args: args{
conds: []handler.Condition{
handler.NewCond("col1", "val1"),
conds: []Condition{
NewCond("col1", "val1"),
},
paramOffset: 0,
paramOffset: 1,
},
want: want{
wheres: []string{"(col1 = $1)"},
@@ -1395,11 +1443,11 @@ func Test_conditionsToWhere(t *testing.T) {
{
name: "multiple cols",
args: args{
conds: []handler.Condition{
handler.NewCond("col1", "val1"),
handler.NewCond("col2", "val2"),
conds: []Condition{
NewCond("col1", "val1"),
NewCond("col2", "val2"),
},
paramOffset: 0,
paramOffset: 1,
},
want: want{
wheres: []string{"(col1 = $1)", "(col2 = $2)"},
@@ -1409,10 +1457,10 @@ func Test_conditionsToWhere(t *testing.T) {
{
name: "2 offset",
args: args{
conds: []handler.Condition{
handler.NewCond("col1", "val1"),
conds: []Condition{
NewCond("col1", "val1"),
},
paramOffset: 2,
paramOffset: 3,
},
want: want{
wheres: []string{"(col1 = $3)"},
@@ -1422,9 +1470,10 @@ func Test_conditionsToWhere(t *testing.T) {
{
name: "less than",
args: args{
conds: []handler.Condition{
conds: []Condition{
NewLessThanCond("col1", "val1"),
},
paramOffset: 1,
},
want: want{
wheres: []string{"(col1 < $1)"},
@@ -1434,7 +1483,7 @@ func Test_conditionsToWhere(t *testing.T) {
{
name: "is null",
args: args{
conds: []handler.Condition{
conds: []Condition{
NewIsNullCond("col1"),
},
},
@@ -1446,21 +1495,23 @@ func Test_conditionsToWhere(t *testing.T) {
{
name: "text array contains",
args: args{
conds: []handler.Condition{
conds: []Condition{
NewTextArrayContainsCond("col1", "val1"),
},
paramOffset: 1,
},
want: want{
wheres: []string{"(col1 @> $1)"},
values: []interface{}{database.StringArray{"val1"}},
values: []interface{}{database.TextArray[string]{"val1"}},
},
},
{
name: "not",
args: args{
conds: []handler.Condition{
Not(handler.NewCond("col1", "val1")),
conds: []Condition{
Not(NewCond("col1", "val1")),
},
paramOffset: 1,
},
want: want{
wheres: []string{"(NOT (col1 = $1))"},
@@ -1490,7 +1541,7 @@ func TestParameterOpts(t *testing.T) {
tests := []struct {
name string
args args
constructor func(column string, value interface{}) handler.Column
constructor func(column string, value interface{}) Column
want string
}{
{
@@ -1523,3 +1574,60 @@ func TestParameterOpts(t *testing.T) {
})
}
}
// func TestHandler_reduce(t *testing.T) {
// type fields struct {
// projection Projection
// }
// type args struct {
// event eventstore.Event
// }
// tests := []struct {
// name string
// fields fields
// args args
// isErr func(t *testing.T, err error)
// shouldBeCalled bool
// }{
// {
// name: "",
// fields: fields{
// projection: &projection{
// reducers: []AggregateReducer{
// {
// Aggregate: "aggregate",
// EventRedusers: []EventReducer{
// {
// Event: "event",
// Reduce: (&mockEventReducer{
// statement: new(Statement),
// }).reduce,
// },
// },
// },
// },
// },
// },
// },
// }
// for _, tt := range tests {
// if tt.isErr == nil {
// tt.isErr = func(t *testing.T, err error) {
// if err != nil {
// t.Error("expected no error got:", err)
// }
// }
// }
// t.Run(tt.name, func(t *testing.T) {
// h := &Handler{
// projection: tt.fields.projection,
// }
// got, err := h.reduce(tt.args.event)
// tt.isErr(t, err)
// if tt.shouldBeCalled != tt.
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("Handler.reduce() = %v, want %v", got, tt.want)
// }
// })
// }
// }