Silvan 99e1c654a3
feat(storage): read only transactions for queries (#6415)
* fix: tests

* bastle wie en grosse

* fix(database): scan as callback

* fix tests

* fix merge failures

* remove as of system time

* refactor: remove unused test

* refactor: remove unused lines
2023-08-22 10:49:22 +00:00

348 lines
10 KiB
Go

package crdb
import (
"context"
"database/sql"
"fmt"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/repository/pseudo"
)
// errSeqNotUpdated is the internal error reported when the current sequence
// was not updated.
var errSeqNotUpdated = errors.ThrowInternal(nil, "CRDB-79GWt", "current sequence not updated")
// StatementHandlerConfig configures a StatementHandler created by
// NewStatementHandler.
type StatementHandlerConfig struct {
	// ProjectionHandlerConfig is passed on to handler.NewProjectionHandler; its
	// ProjectionName is also used to create the Locker.
	handler.ProjectionHandlerConfig

	// Client is the database the statements and sequence queries run against;
	// its DB field is handed to NewLocker.
	Client *database.DB
	// SequenceTable is the table name formatted into the current-sequence read
	// and update statements.
	SequenceTable string
	// LockTable is the table name handed to NewLocker.
	LockTable string
	// FailedEventsTable is the table name formatted into the failure-count
	// statements.
	FailedEventsTable string
	// MaxFailureCount is copied onto the handler. NOTE(review): presumably the
	// number of failures after which a statement is given up — confirm where
	// it is consumed.
	MaxFailureCount uint
	// BulkLimit is the limit applied to each event search query.
	BulkLimit uint64
	// Reducers define the aggregates the handler subscribes to and the reduce
	// function for each event type. If one of them targets
	// pseudo.AggregateType, it must be the only reducer and must handle exactly
	// pseudo.ScheduledEventType; otherwise NewStatementHandler panics.
	Reducers []handler.AggregateReducer
	// InitCheck is stored on the handler; its use is not visible in this part
	// of the file.
	InitCheck *handler.Check
}
// StatementHandler applies reduced projection statements to a SQL database.
// It tracks the current sequence per aggregate type and instance in
// sequenceTable.
type StatementHandler struct {
	// ProjectionHandler is created in NewStatementHandler with this handler's
	// reduce/update/search/lock callbacks.
	*handler.ProjectionHandler
	// Locker is created by NewLocker over the configured lock table; its Lock
	// and Unlock are handed to the ProjectionHandler.
	Locker

	// client is the database used for queries and transactions.
	client *database.DB
	// sequenceTable is the name of the table holding the current sequences.
	sequenceTable string
	// currentSequenceStmt is currentSequenceStmtFormat applied to
	// sequenceTable. NOTE(review): presumably the locking variant, cf.
	// currentSequenceWithoutLockStmt — confirm.
	currentSequenceStmt string
	// currentSequenceWithoutLockStmt is currentSequenceStmtWithoutLockFormat
	// applied to sequenceTable.
	currentSequenceWithoutLockStmt string
	// updateSequencesBaseStmt is updateCurrentSequencesStmtFormat applied to
	// sequenceTable.
	updateSequencesBaseStmt string
	// maxFailureCount is copied from the config.
	maxFailureCount uint
	// failureCountStmt and setFailureCountStmt are the failure-count statement
	// formats applied to the failed events table.
	failureCountStmt    string
	setFailureCountStmt string
	// aggregates lists the aggregate type of every reducer in config order
	// (duplicates are kept if reducers share an aggregate type).
	aggregates []eventstore.AggregateType
	// reduces maps an event type to its reduce function; a later reducer for
	// the same event type overrides an earlier one.
	reduces map[eventstore.EventType]handler.Reduce
	// initCheck is taken from the config.
	initCheck *handler.Check
	// initialized is an unbuffered channel shared with the ProjectionHandler;
	// Start sends true on it and then closes it.
	initialized chan bool
	// bulkLimit is the limit used for event search queries.
	bulkLimit uint64
	// reduceScheduledPseudoEvent is true when a reducer targets
	// pseudo.AggregateType. In that mode, Start does not subscribe and
	// searchQuery builds no database query.
	reduceScheduledPseudoEvent bool
}
// NewStatementHandler builds a StatementHandler from config and wires it into a
// new handler.ProjectionHandler.
//
// Every reducer contributes its aggregate type to the subscription list, and
// each of its event reducers to the event-type → reduce lookup. A reducer for
// pseudo.AggregateType switches the handler into scheduled-pseudo-event mode.
// That mode is only valid if it is the sole reducer and handles exactly
// pseudo.ScheduledEventType; any other combination panics.
//
// NOTE(review): the callbacks handed to NewProjectionHandler are method values
// bound to the local h, while the caller receives a copy of h. This works only
// as long as neither copy's fields are changed after construction — confirm
// that this is intended.
func NewStatementHandler(
	ctx context.Context,
	config StatementHandlerConfig,
) StatementHandler {
	var (
		aggregateTypes             = make([]eventstore.AggregateType, 0, len(config.Reducers))
		reduces                    = make(map[eventstore.EventType]handler.Reduce, len(config.Reducers))
		reduceScheduledPseudoEvent bool
	)
	for _, aggregateReducer := range config.Reducers {
		aggregateTypes = append(aggregateTypes, aggregateReducer.Aggregate)

		if aggregateReducer.Aggregate == pseudo.AggregateType {
			reduceScheduledPseudoEvent = true
			isSoleScheduledReducer := len(config.Reducers) == 1 &&
				len(aggregateReducer.EventRedusers) == 1 &&
				aggregateReducer.EventRedusers[0].Event == pseudo.ScheduledEventType
			if !isSoleScheduledReducer {
				panic("if a pseudo.AggregateType is reduced, exactly one event reducer for pseudo.ScheduledEventType is supported and no other aggregate can be reduced")
			}
		}

		for _, eventReducer := range aggregateReducer.EventRedusers {
			reduces[eventReducer.Event] = eventReducer.Reduce
		}
	}

	h := StatementHandler{
		Locker:                         NewLocker(config.Client.DB, config.LockTable, config.ProjectionName),
		client:                         config.Client,
		sequenceTable:                  config.SequenceTable,
		currentSequenceStmt:            fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
		currentSequenceWithoutLockStmt: fmt.Sprintf(currentSequenceStmtWithoutLockFormat, config.SequenceTable),
		updateSequencesBaseStmt:        fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
		maxFailureCount:                config.MaxFailureCount,
		failureCountStmt:               fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
		setFailureCountStmt:            fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
		aggregates:                     aggregateTypes,
		reduces:                        reduces,
		initCheck:                      config.InitCheck,
		initialized:                    make(chan bool),
		bulkLimit:                      config.BulkLimit,
		reduceScheduledPseudoEvent:     reduceScheduledPseudoEvent,
	}
	h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.searchQuery, h.Lock, h.Unlock, h.initialized, reduceScheduledPseudoEvent)

	return h
}
// Start signals initialization and, outside scheduled-pseudo-event mode,
// subscribes to the handler's aggregates.
//
// initialized is unbuffered (see NewStatementHandler), so the send blocks until
// the value has been received; the channel is closed right after.
func (h *StatementHandler) Start() {
	h.initialized <- true
	close(h.initialized)

	if h.reduceScheduledPseudoEvent {
		return
	}
	h.Subscribe(h.aggregates...)
}
// searchQuery returns the event query for instanceIDs and its limit.
//
// Normally it delegates to dbSearchQuery. In scheduled-pseudo-event mode it
// builds no query and returns a limit of 1.
func (h *StatementHandler) searchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
	if !h.reduceScheduledPseudoEvent {
		return h.dbSearchQuery(ctx, instanceIDs)
	}
	return nil, 1, nil
}
// dbSearchQuery builds the event query for instanceIDs from the current
// sequences stored in the database. It reads them via h.client.QueryContext and
// passes false to currentSequences, whereas Update passes true.
//
// For every aggregate type of h and every instance in instanceIDs, a sub query
// is added. It selects that aggregate's events for that instance with a
// sequence greater than the stored current sequence (0 when none is stored).
// The builder is limited to h.bulkLimit and allows time travel; h.bulkLimit is
// also returned as the second value.
func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
	sequences, err := h.currentSequences(ctx, false, h.client.QueryContext, instanceIDs)
	if err != nil {
		return nil, 0, err
	}

	queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit).AllowTimeTravel()
	for _, aggregateType := range h.aggregates {
		// Index the stored sequences by instance once per aggregate instead of
		// rescanning the list for every instance, which was quadratic in the
		// number of instances. As with the linear scan, the first entry for an
		// instance wins.
		stored := sequences[aggregateType]
		seqByInstance := make(map[string]uint64, len(stored))
		for _, sequence := range stored {
			if _, seen := seqByInstance[sequence.instanceID]; !seen {
				seqByInstance[sequence.instanceID] = sequence.sequence
			}
		}

		for _, instanceID := range instanceIDs {
			queryBuilder.
				AddQuery().
				AggregateTypes(aggregateType).
				SequenceGreater(seqByInstance[instanceID]).
				InstanceID(instanceID)
		}
	}
	return queryBuilder, h.bulkLimit, nil
}
// transaction wraps a *sql.Tx so that its QueryContext has the callback-style
// signature currentSequences expects. Update passes it in the same position
// where dbSearchQuery passes h.client.QueryContext.
type transaction struct {
	*sql.Tx
}
// QueryContext runs query with args inside the wrapped transaction and hands
// the resulting rows to scan.
//
// The rows are always closed afterwards; a close error is only logged. The
// returned error is the query error, otherwise the scan error, otherwise
// rows.Err().
func (t *transaction) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
	rows, err := t.Tx.QueryContext(ctx, query, args...)
	if err != nil {
		return err
	}
	defer func() {
		logging.OnError(rows.Close()).Info("rows.Close failed")
	}()

	if scanErr := scan(rows); scanErr != nil {
		return scanErr
	}
	return rows.Err()
}
// Update implements handler.Update.
//
// All statements are executed inside one database transaction:
//  1. The current sequences of every instance referenced by stmts are read
//     through the transaction. It passes true to currentSequences, presumably
//     selecting the locking read; dbSearchQuery passes false.
//  2. If the first statement has no previous sequence, statements reduced from
//     the events between the stored sequences and that statement are
//     prepended.
//  3. The statements are executed. If at least one was applied, the new
//     current sequences are written before the transaction is committed.
//
// Update returns the index of the last applied statement (-1 if none). It also
// returns handler.ErrSomeStmtsFailed when not every statement was applied.
//
// NOTE(review): the index refers to the internal statement list, which may
// have been extended by prepended statements and shrunk by dropped ones. It
// therefore does not necessarily index into the caller's stmts — confirm that
// callers expect this.
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
	if len(stmts) == 0 {
		return -1, nil
	}

	instanceIDs := make([]string, 0, len(stmts))
	for _, stmt := range stmts {
		instanceIDs = appendToInstanceIDs(instanceIDs, stmt.InstanceID)
	}

	tx, err := h.client.BeginTx(ctx, nil)
	if err != nil {
		return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
	}
	// rollback aborts tx on an error path. The caller receives the error that
	// caused the abort, so a failing rollback is logged rather than silently
	// dropped.
	rollback := func() {
		logging.OnError(tx.Rollback()).Info("rollback failed")
	}

	sequences, err := h.currentSequences(ctx, true, (&transaction{Tx: tx}).QueryContext, instanceIDs)
	if err != nil {
		rollback()
		return -1, err
	}

	// Check for events between the create statement and the current sequence.
	// Such events can exist between the current sequence and a creation event,
	// and they cannot be detected via stmt.PreviousSequence.
	if stmts[0].PreviousSequence == 0 {
		previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
		if err != nil {
			rollback()
			return -1, err
		}
		stmts = append(previousStmts, stmts...)
	}

	lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences)
	if lastSuccessfulIdx >= 0 {
		if err = h.updateCurrentSequences(tx, sequences); err != nil {
			rollback()
			return -1, err
		}
	}

	if err = tx.Commit(); err != nil {
		return -1, err
	}

	if lastSuccessfulIdx < len(stmts)-1 {
		return lastSuccessfulIdx, handler.ErrSomeStmtsFailed
	}
	return lastSuccessfulIdx, nil
}
// fetchPreviousStmts reduces the events lying between the stored current
// sequences and stmtSeq, reading them through tx.
//
// For every aggregate type of h and every current sequence stored for it, a
// query is added for events with a sequence greater than the stored one and
// less than stmtSeq. The only exception is an entry for instanceID whose
// stored sequence already reaches stmtSeq. If no query was added, nil is
// returned without touching the eventstore. Otherwise each filtered event is
// passed to reduce, in the order Filter returned them, and the first reduce
// error aborts.
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
	query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).SetTx(tx)
	hasQuery := false
	for _, aggregateType := range h.aggregates {
		for _, current := range sequences[aggregateType] {
			alreadyCovered := current.instanceID == instanceID && current.sequence >= stmtSeq
			if alreadyCovered {
				continue
			}
			query.AddQuery().
				AggregateTypes(aggregateType).
				SequenceGreater(current.sequence).
				SequenceLess(stmtSeq).
				InstanceID(current.instanceID)
			hasQuery = true
		}
	}
	if !hasQuery {
		return nil, nil
	}

	events, err := h.Eventstore.Filter(ctx, query)
	if err != nil {
		return nil, err
	}
	for _, event := range events {
		stmt, reduceErr := reduce(event)
		if reduceErr != nil {
			return nil, reduceErr
		}
		previousStmts = append(previousStmts, stmt)
	}
	return previousStmts, nil
}
// executeStmts runs *stmts one by one inside tx. It returns the index of the
// last applied statement, or -1 if none was applied.
//
// Before a statement runs, it is compared with the current sequences stored
// for its aggregate type and instance:
//   - If its sequence is not newer than the stored one, it is removed from
//     *stmts in place (the slice shrinks) and skipped.
//   - If it declares a PreviousSequence that differs from the stored one,
//     processing stops.
//
// A statement whose execution fails is passed to handleFailedStmt. Processing
// continues only if that returns true, and the statement then counts as
// applied. For every applied statement, the in-memory sequences are advanced.
func (h *StatementHandler) executeStmts(
	tx *sql.Tx,
	stmts *[]*handler.Statement,
	sequences currentSequences,
) int {
	lastSuccessfulIdx := -1
nextStmt:
	for i := 0; i < len(*stmts); i++ {
		stmt := (*stmts)[i]
		for _, sequence := range sequences[stmt.AggregateType] {
			if sequence.instanceID != stmt.InstanceID {
				continue
			}
			if stmt.Sequence <= sequence.sequence {
				logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped")
				// Remove stmt in place; index i then holds its successor, so
				// step back to revisit it.
				*stmts = append((*stmts)[:i], (*stmts)[i+1:]...)
				i--
				continue nextStmt
			}
			if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence {
				logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match")
				break nextStmt
			}
		}

		if err := h.executeStmt(tx, stmt); err != nil && !h.handleFailedStmt(tx, stmt, err) {
			break
		}
		updateSequences(sequences, stmt)
		lastSuccessfulIdx = i
	}
	return lastSuccessfulIdx
}
// executeStmt executes a single statement inside tx, guarded by the savepoint
// push_stmt so that a failing statement can be rolled back on its own.
//
// Noop statements are skipped. If the statement fails, its error is logged,
// tx is rolled back to the savepoint, and the wrapped statement error is
// returned. If the savepoint cannot be created, rolled back to, or released,
// that error is returned instead.
func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) error {
	if stmt.IsNoop() {
		return nil
	}
	if _, err := tx.Exec("SAVEPOINT push_stmt"); err != nil {
		return errors.ThrowInternal(err, "CRDB-i1wp6", "unable to create savepoint")
	}

	if execErr := stmt.Execute(tx, h.ProjectionName); execErr != nil {
		logging.WithError(execErr).Error()
		if _, rollbackErr := tx.Exec("ROLLBACK TO SAVEPOINT push_stmt"); rollbackErr != nil {
			return errors.ThrowInternal(rollbackErr, "CRDB-zzp3P", "rollback to savepoint failed")
		}
		return errors.ThrowInternal(execErr, "CRDB-oRkaN", "unable execute stmt")
	}

	if _, err := tx.Exec("RELEASE push_stmt"); err != nil {
		return errors.ThrowInternal(err, "CRDB-qWgwT", "unable to release savepoint")
	}
	return nil
}
// updateSequences records stmt.Sequence as the current sequence of
// stmt.InstanceID for stmt.AggregateType.
//
// An existing entry is updated in place through its pointer. Otherwise a new
// entry is appended to the aggregate type's list.
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
	perInstance := sequences[stmt.AggregateType]
	for i := range perInstance {
		if perInstance[i].instanceID != stmt.InstanceID {
			continue
		}
		perInstance[i].sequence = stmt.Sequence
		return
	}
	sequences[stmt.AggregateType] = append(perInstance, &instanceSequence{
		instanceID: stmt.InstanceID,
		sequence:   stmt.Sequence,
	})
}
// appendToInstanceIDs returns instances with id appended, unless id is already
// present. In that case instances is returned unchanged. Existing entries keep
// their order.
func appendToInstanceIDs(instances []string, id string) []string {
	known := false
	for i := 0; i < len(instances) && !known; i++ {
		known = instances[i] == id
	}
	if known {
		return instances
	}
	return append(instances, id)
}