mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-13 11:34:26 +00:00
9ab566fdeb
* fix: add keys to projections * change to multiple tables * query keys * query keys * fix race condition * fix timer reset * begin tests * tests * remove migration * only send to keyChannel if not nil
260 lines
6.9 KiB
Go
260 lines
6.9 KiB
Go
package crdb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"github.com/caos/logging"
|
|
|
|
"github.com/caos/zitadel/internal/errors"
|
|
"github.com/caos/zitadel/internal/eventstore"
|
|
"github.com/caos/zitadel/internal/eventstore/handler"
|
|
)
|
|
|
|
var (
|
|
errSeqNotUpdated = errors.ThrowInternal(nil, "CRDB-79GWt", "current sequence not updated")
|
|
)
|
|
|
|
type StatementHandlerConfig struct {
|
|
handler.ProjectionHandlerConfig
|
|
|
|
Client *sql.DB
|
|
SequenceTable string
|
|
LockTable string
|
|
FailedEventsTable string
|
|
MaxFailureCount uint
|
|
BulkLimit uint64
|
|
|
|
Reducers []handler.AggregateReducer
|
|
}
|
|
|
|
type StatementHandler struct {
|
|
*handler.ProjectionHandler
|
|
Locker
|
|
|
|
client *sql.DB
|
|
sequenceTable string
|
|
currentSequenceStmt string
|
|
updateSequencesBaseStmt string
|
|
maxFailureCount uint
|
|
failureCountStmt string
|
|
setFailureCountStmt string
|
|
|
|
aggregates []eventstore.AggregateType
|
|
reduces map[eventstore.EventType]handler.Reduce
|
|
|
|
bulkLimit uint64
|
|
}
|
|
|
|
func NewStatementHandler(
|
|
ctx context.Context,
|
|
config StatementHandlerConfig,
|
|
) StatementHandler {
|
|
aggregateTypes := make([]eventstore.AggregateType, 0, len(config.Reducers))
|
|
reduces := make(map[eventstore.EventType]handler.Reduce, len(config.Reducers))
|
|
for _, aggReducer := range config.Reducers {
|
|
aggregateTypes = append(aggregateTypes, aggReducer.Aggregate)
|
|
for _, eventReducer := range aggReducer.EventRedusers {
|
|
reduces[eventReducer.Event] = eventReducer.Reduce
|
|
}
|
|
}
|
|
|
|
h := StatementHandler{
|
|
ProjectionHandler: handler.NewProjectionHandler(config.ProjectionHandlerConfig),
|
|
client: config.Client,
|
|
sequenceTable: config.SequenceTable,
|
|
maxFailureCount: config.MaxFailureCount,
|
|
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
|
|
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
|
|
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
|
|
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
|
|
aggregates: aggregateTypes,
|
|
reduces: reduces,
|
|
bulkLimit: config.BulkLimit,
|
|
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
|
|
}
|
|
|
|
go h.ProjectionHandler.Process(
|
|
ctx,
|
|
h.reduce,
|
|
h.Update,
|
|
h.Lock,
|
|
h.Unlock,
|
|
h.SearchQuery,
|
|
)
|
|
|
|
h.ProjectionHandler.Handler.Subscribe(h.aggregates...)
|
|
|
|
return h
|
|
}
|
|
|
|
func (h *StatementHandler) SearchQuery() (*eventstore.SearchQueryBuilder, uint64, error) {
|
|
sequences, err := h.currentSequences(h.client.Query)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
|
|
for _, aggregateType := range h.aggregates {
|
|
queryBuilder.
|
|
AddQuery().
|
|
AggregateTypes(aggregateType).
|
|
SequenceGreater(sequences[aggregateType])
|
|
}
|
|
|
|
return queryBuilder, h.bulkLimit, nil
|
|
}
|
|
|
|
//Update implements handler.Update
|
|
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
|
|
tx, err := h.client.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
|
|
}
|
|
|
|
sequences, err := h.currentSequences(tx.Query)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return stmts, err
|
|
}
|
|
|
|
//checks for events between create statement and current sequence
|
|
// because there could be events between current sequence and a creation event
|
|
// and we cannot check via stmt.PreviousSequence
|
|
if stmts[0].PreviousSequence == 0 {
|
|
previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, sequences, reduce)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return stmts, err
|
|
}
|
|
stmts = append(previousStmts, stmts...)
|
|
}
|
|
|
|
lastSuccessfulIdx := h.executeStmts(tx, stmts, sequences)
|
|
|
|
if lastSuccessfulIdx >= 0 {
|
|
err = h.updateCurrentSequences(tx, sequences)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return stmts, err
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return stmts, err
|
|
}
|
|
|
|
if lastSuccessfulIdx == -1 {
|
|
return stmts, handler.ErrSomeStmtsFailed
|
|
}
|
|
|
|
unexecutedStmts = make([]*handler.Statement, len(stmts)-(lastSuccessfulIdx+1))
|
|
copy(unexecutedStmts, stmts[lastSuccessfulIdx+1:])
|
|
stmts = nil
|
|
|
|
if len(unexecutedStmts) > 0 {
|
|
return unexecutedStmts, handler.ErrSomeStmtsFailed
|
|
}
|
|
|
|
return unexecutedStmts, nil
|
|
}
|
|
|
|
func (h *StatementHandler) fetchPreviousStmts(
|
|
ctx context.Context,
|
|
stmtSeq uint64,
|
|
sequences currentSequences,
|
|
reduce handler.Reduce,
|
|
) (previousStmts []*handler.Statement, err error) {
|
|
|
|
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
|
|
queriesAdded := false
|
|
for _, aggregateType := range h.aggregates {
|
|
if stmtSeq <= sequences[aggregateType] {
|
|
continue
|
|
}
|
|
|
|
query.
|
|
AddQuery().
|
|
AggregateTypes(aggregateType).
|
|
SequenceGreater(sequences[aggregateType]).
|
|
SequenceLess(stmtSeq)
|
|
|
|
queriesAdded = true
|
|
}
|
|
|
|
if !queriesAdded {
|
|
return nil, nil
|
|
}
|
|
|
|
events, err := h.Eventstore.Filter(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, event := range events {
|
|
stmt, err := reduce(event)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
previousStmts = append(previousStmts, stmt)
|
|
}
|
|
return previousStmts, nil
|
|
}
|
|
|
|
func (h *StatementHandler) executeStmts(
|
|
tx *sql.Tx,
|
|
stmts []*handler.Statement,
|
|
sequences currentSequences,
|
|
) int {
|
|
|
|
lastSuccessfulIdx := -1
|
|
for i, stmt := range stmts {
|
|
if stmt.Sequence <= sequences[stmt.AggregateType] {
|
|
continue
|
|
}
|
|
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequences[stmt.AggregateType] {
|
|
logging.LogWithFields("CRDB-jJBJn", "projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "seq", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
|
|
break
|
|
}
|
|
err := h.executeStmt(tx, stmt)
|
|
if err == nil {
|
|
sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i
|
|
continue
|
|
}
|
|
|
|
shouldContinue := h.handleFailedStmt(tx, stmt, err)
|
|
if !shouldContinue {
|
|
break
|
|
}
|
|
|
|
sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i
|
|
}
|
|
return lastSuccessfulIdx
|
|
}
|
|
|
|
//executeStmt handles sql statements
|
|
//an error is returned if the statement could not be inserted properly
|
|
func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) error {
|
|
if stmt.IsNoop() {
|
|
return nil
|
|
}
|
|
_, err := tx.Exec("SAVEPOINT push_stmt")
|
|
if err != nil {
|
|
return errors.ThrowInternal(err, "CRDB-i1wp6", "unable to create savepoint")
|
|
}
|
|
err = stmt.Execute(tx, h.ProjectionName)
|
|
if err != nil {
|
|
_, rollbackErr := tx.Exec("ROLLBACK TO SAVEPOINT push_stmt")
|
|
if rollbackErr != nil {
|
|
return errors.ThrowInternal(rollbackErr, "CRDB-zzp3P", "rollback to savepoint failed")
|
|
}
|
|
return errors.ThrowInternal(err, "CRDB-oRkaN", "unable execute stmt")
|
|
}
|
|
_, err = tx.Exec("RELEASE push_stmt")
|
|
if err != nil {
|
|
return errors.ThrowInternal(err, "CRDB-qWgwT", "unable to release savepoint")
|
|
}
|
|
return nil
|
|
}
|