mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-12 19:14:23 +00:00
99e1c654a3
* fix: tests * bastle wie en grosse * fix(database): scan as callback * fix tests * fix merge failures * remove as of system time * refactor: remove unused test * refacotr: remove unused lines
348 lines
10 KiB
Go
348 lines
10 KiB
Go
package crdb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"github.com/zitadel/logging"
|
|
|
|
"github.com/zitadel/zitadel/internal/database"
|
|
"github.com/zitadel/zitadel/internal/errors"
|
|
"github.com/zitadel/zitadel/internal/eventstore"
|
|
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
|
"github.com/zitadel/zitadel/internal/repository/pseudo"
|
|
)
|
|
|
|
var (
|
|
errSeqNotUpdated = errors.ThrowInternal(nil, "CRDB-79GWt", "current sequence not updated")
|
|
)
|
|
|
|
type StatementHandlerConfig struct {
|
|
handler.ProjectionHandlerConfig
|
|
|
|
Client *database.DB
|
|
SequenceTable string
|
|
LockTable string
|
|
FailedEventsTable string
|
|
MaxFailureCount uint
|
|
BulkLimit uint64
|
|
|
|
Reducers []handler.AggregateReducer
|
|
InitCheck *handler.Check
|
|
}
|
|
|
|
type StatementHandler struct {
|
|
*handler.ProjectionHandler
|
|
Locker
|
|
|
|
client *database.DB
|
|
sequenceTable string
|
|
currentSequenceStmt string
|
|
currentSequenceWithoutLockStmt string
|
|
updateSequencesBaseStmt string
|
|
maxFailureCount uint
|
|
failureCountStmt string
|
|
setFailureCountStmt string
|
|
|
|
aggregates []eventstore.AggregateType
|
|
reduces map[eventstore.EventType]handler.Reduce
|
|
initCheck *handler.Check
|
|
initialized chan bool
|
|
|
|
bulkLimit uint64
|
|
|
|
reduceScheduledPseudoEvent bool
|
|
}
|
|
|
|
func NewStatementHandler(
|
|
ctx context.Context,
|
|
config StatementHandlerConfig,
|
|
) StatementHandler {
|
|
aggregateTypes := make([]eventstore.AggregateType, 0, len(config.Reducers))
|
|
reduces := make(map[eventstore.EventType]handler.Reduce, len(config.Reducers))
|
|
reduceScheduledPseudoEvent := false
|
|
for _, aggReducer := range config.Reducers {
|
|
aggregateTypes = append(aggregateTypes, aggReducer.Aggregate)
|
|
if aggReducer.Aggregate == pseudo.AggregateType {
|
|
reduceScheduledPseudoEvent = true
|
|
if len(config.Reducers) != 1 ||
|
|
len(aggReducer.EventRedusers) != 1 ||
|
|
aggReducer.EventRedusers[0].Event != pseudo.ScheduledEventType {
|
|
panic("if a pseudo.AggregateType is reduced, exactly one event reducer for pseudo.ScheduledEventType is supported and no other aggregate can be reduced")
|
|
}
|
|
}
|
|
for _, eventReducer := range aggReducer.EventRedusers {
|
|
reduces[eventReducer.Event] = eventReducer.Reduce
|
|
}
|
|
}
|
|
|
|
h := StatementHandler{
|
|
client: config.Client,
|
|
sequenceTable: config.SequenceTable,
|
|
maxFailureCount: config.MaxFailureCount,
|
|
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
|
|
currentSequenceWithoutLockStmt: fmt.Sprintf(currentSequenceStmtWithoutLockFormat, config.SequenceTable),
|
|
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
|
|
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
|
|
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
|
|
aggregates: aggregateTypes,
|
|
reduces: reduces,
|
|
bulkLimit: config.BulkLimit,
|
|
Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName),
|
|
initCheck: config.InitCheck,
|
|
initialized: make(chan bool),
|
|
reduceScheduledPseudoEvent: reduceScheduledPseudoEvent,
|
|
}
|
|
|
|
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.searchQuery, h.Lock, h.Unlock, h.initialized, reduceScheduledPseudoEvent)
|
|
|
|
return h
|
|
}
|
|
|
|
func (h *StatementHandler) Start() {
|
|
h.initialized <- true
|
|
close(h.initialized)
|
|
if !h.reduceScheduledPseudoEvent {
|
|
h.Subscribe(h.aggregates...)
|
|
}
|
|
}
|
|
|
|
func (h *StatementHandler) searchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
|
|
if h.reduceScheduledPseudoEvent {
|
|
return nil, 1, nil
|
|
}
|
|
return h.dbSearchQuery(ctx, instanceIDs)
|
|
}
|
|
|
|
func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
|
|
sequences, err := h.currentSequences(ctx, false, h.client.QueryContext, instanceIDs)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit).AllowTimeTravel()
|
|
|
|
for _, aggregateType := range h.aggregates {
|
|
for _, instanceID := range instanceIDs {
|
|
var seq uint64
|
|
for _, sequence := range sequences[aggregateType] {
|
|
if sequence.instanceID == instanceID {
|
|
seq = sequence.sequence
|
|
break
|
|
}
|
|
}
|
|
queryBuilder.
|
|
AddQuery().
|
|
AggregateTypes(aggregateType).
|
|
SequenceGreater(seq).
|
|
InstanceID(instanceID)
|
|
}
|
|
}
|
|
return queryBuilder, h.bulkLimit, nil
|
|
}
|
|
|
|
// transaction wraps a *sql.Tx. Its QueryContext method gives it the
// callback-style query signature that is passed to currentSequences in place
// of h.client.QueryContext.
type transaction struct{ *sql.Tx }
|
|
|
|
func (t *transaction) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
|
|
rows, err := t.Tx.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
closeErr := rows.Close()
|
|
logging.OnError(closeErr).Info("rows.Close failed")
|
|
}()
|
|
|
|
if err = scan(rows); err != nil {
|
|
return err
|
|
}
|
|
return rows.Err()
|
|
}
|
|
|
|
// Update implements handler.Update
|
|
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
|
|
if len(stmts) == 0 {
|
|
return -1, nil
|
|
}
|
|
instanceIDs := make([]string, 0, len(stmts))
|
|
for _, stmt := range stmts {
|
|
instanceIDs = appendToInstanceIDs(instanceIDs, stmt.InstanceID)
|
|
}
|
|
tx, err := h.client.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
|
|
}
|
|
|
|
sequences, err := h.currentSequences(ctx, true, (&transaction{Tx: tx}).QueryContext, instanceIDs)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return -1, err
|
|
}
|
|
|
|
//checks for events between create statement and current sequence
|
|
// because there could be events between current sequence and a creation event
|
|
// and we cannot check via stmt.PreviousSequence
|
|
if stmts[0].PreviousSequence == 0 {
|
|
previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return -1, err
|
|
}
|
|
stmts = append(previousStmts, stmts...)
|
|
}
|
|
|
|
lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences)
|
|
|
|
if lastSuccessfulIdx >= 0 {
|
|
err = h.updateCurrentSequences(tx, sequences)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return -1, err
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
if lastSuccessfulIdx < len(stmts)-1 {
|
|
return lastSuccessfulIdx, handler.ErrSomeStmtsFailed
|
|
}
|
|
|
|
return lastSuccessfulIdx, nil
|
|
}
|
|
|
|
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
|
|
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).SetTx(tx)
|
|
queriesAdded := false
|
|
for _, aggregateType := range h.aggregates {
|
|
for _, sequence := range sequences[aggregateType] {
|
|
if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID {
|
|
continue
|
|
}
|
|
|
|
query.
|
|
AddQuery().
|
|
AggregateTypes(aggregateType).
|
|
SequenceGreater(sequence.sequence).
|
|
SequenceLess(stmtSeq).
|
|
InstanceID(sequence.instanceID)
|
|
|
|
queriesAdded = true
|
|
}
|
|
}
|
|
|
|
if !queriesAdded {
|
|
return nil, nil
|
|
}
|
|
|
|
events, err := h.Eventstore.Filter(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, event := range events {
|
|
stmt, err := reduce(event)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
previousStmts = append(previousStmts, stmt)
|
|
}
|
|
return previousStmts, nil
|
|
}
|
|
|
|
func (h *StatementHandler) executeStmts(
|
|
tx *sql.Tx,
|
|
stmts *[]*handler.Statement,
|
|
sequences currentSequences,
|
|
) int {
|
|
|
|
lastSuccessfulIdx := -1
|
|
stmts:
|
|
for i := 0; i < len(*stmts); i++ {
|
|
stmt := (*stmts)[i]
|
|
for _, sequence := range sequences[stmt.AggregateType] {
|
|
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
|
logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped")
|
|
if i < len(*stmts)-1 {
|
|
copy((*stmts)[i:], (*stmts)[i+1:])
|
|
}
|
|
*stmts = (*stmts)[:len(*stmts)-1]
|
|
i--
|
|
continue stmts
|
|
}
|
|
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
|
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match")
|
|
break stmts
|
|
}
|
|
}
|
|
err := h.executeStmt(tx, stmt)
|
|
if err == nil {
|
|
updateSequences(sequences, stmt)
|
|
lastSuccessfulIdx = i
|
|
continue
|
|
}
|
|
|
|
shouldContinue := h.handleFailedStmt(tx, stmt, err)
|
|
if !shouldContinue {
|
|
break
|
|
}
|
|
|
|
updateSequences(sequences, stmt)
|
|
lastSuccessfulIdx = i
|
|
continue
|
|
}
|
|
return lastSuccessfulIdx
|
|
}
|
|
|
|
// executeStmt handles sql statements
|
|
// an error is returned if the statement could not be inserted properly
|
|
func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) error {
|
|
if stmt.IsNoop() {
|
|
return nil
|
|
}
|
|
_, err := tx.Exec("SAVEPOINT push_stmt")
|
|
if err != nil {
|
|
return errors.ThrowInternal(err, "CRDB-i1wp6", "unable to create savepoint")
|
|
}
|
|
err = stmt.Execute(tx, h.ProjectionName)
|
|
if err != nil {
|
|
logging.WithError(err).Error()
|
|
_, rollbackErr := tx.Exec("ROLLBACK TO SAVEPOINT push_stmt")
|
|
if rollbackErr != nil {
|
|
return errors.ThrowInternal(rollbackErr, "CRDB-zzp3P", "rollback to savepoint failed")
|
|
}
|
|
return errors.ThrowInternal(err, "CRDB-oRkaN", "unable execute stmt")
|
|
}
|
|
_, err = tx.Exec("RELEASE push_stmt")
|
|
if err != nil {
|
|
return errors.ThrowInternal(err, "CRDB-qWgwT", "unable to release savepoint")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
|
|
for _, sequence := range sequences[stmt.AggregateType] {
|
|
if sequence.instanceID == stmt.InstanceID {
|
|
sequence.sequence = stmt.Sequence
|
|
return
|
|
}
|
|
}
|
|
sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{
|
|
instanceID: stmt.InstanceID,
|
|
sequence: stmt.Sequence,
|
|
})
|
|
}
|
|
|
|
// appendToInstanceIDs returns instances with id appended, unless id is already
// present. In that case instances is returned unchanged.
func appendToInstanceIDs(instances []string, id string) []string {
	for i := range instances {
		if instances[i] == id {
			return instances
		}
	}
	return append(instances, id)
}
|