mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-13 16:43:40 +00:00
dd2f31683c
* feat: directly specify factors on addCustomLoginPolicy and return on LoginPolicy responses * fix proto * update login policy * feat: directly specify idp on addCustomLoginPolicy and return on LoginPolicy responses * fix: tests * fix(projection): trigger bulk * refactor: clean projection pkg * instance should bulk * fix(query): should trigger bulk on id calls * tests * build prerelease * fix: add shouldTriggerBulk * fix: test Co-authored-by: Livio Amstutz <livio.a@gmail.com> Co-authored-by: Max Peintner <max@caos.ch>
319 lines
8.9 KiB
Go
319 lines
8.9 KiB
Go
package crdb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"github.com/zitadel/logging"
|
|
|
|
"github.com/zitadel/zitadel/internal/errors"
|
|
"github.com/zitadel/zitadel/internal/eventstore"
|
|
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
|
)
|
|
|
|
var (
|
|
errSeqNotUpdated = errors.ThrowInternal(nil, "CRDB-79GWt", "current sequence not updated")
|
|
)
|
|
|
|
type StatementHandlerConfig struct {
|
|
handler.ProjectionHandlerConfig
|
|
|
|
Client *sql.DB
|
|
SequenceTable string
|
|
LockTable string
|
|
FailedEventsTable string
|
|
MaxFailureCount uint
|
|
BulkLimit uint64
|
|
|
|
Reducers []handler.AggregateReducer
|
|
InitCheck *handler.Check
|
|
}
|
|
|
|
type StatementHandler struct {
|
|
*handler.ProjectionHandler
|
|
Locker
|
|
|
|
client *sql.DB
|
|
sequenceTable string
|
|
currentSequenceStmt string
|
|
updateSequencesBaseStmt string
|
|
maxFailureCount uint
|
|
failureCountStmt string
|
|
setFailureCountStmt string
|
|
|
|
aggregates []eventstore.AggregateType
|
|
reduces map[eventstore.EventType]handler.Reduce
|
|
|
|
bulkLimit uint64
|
|
}
|
|
|
|
func NewStatementHandler(
|
|
ctx context.Context,
|
|
config StatementHandlerConfig,
|
|
) StatementHandler {
|
|
aggregateTypes := make([]eventstore.AggregateType, 0, len(config.Reducers))
|
|
reduces := make(map[eventstore.EventType]handler.Reduce, len(config.Reducers))
|
|
for _, aggReducer := range config.Reducers {
|
|
aggregateTypes = append(aggregateTypes, aggReducer.Aggregate)
|
|
for _, eventReducer := range aggReducer.EventRedusers {
|
|
reduces[eventReducer.Event] = eventReducer.Reduce
|
|
}
|
|
}
|
|
|
|
h := StatementHandler{
|
|
client: config.Client,
|
|
sequenceTable: config.SequenceTable,
|
|
maxFailureCount: config.MaxFailureCount,
|
|
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
|
|
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
|
|
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
|
|
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
|
|
aggregates: aggregateTypes,
|
|
reduces: reduces,
|
|
bulkLimit: config.BulkLimit,
|
|
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
|
|
}
|
|
h.ProjectionHandler = handler.NewProjectionHandler(config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery)
|
|
|
|
err := h.Init(ctx, config.InitCheck)
|
|
logging.OnError(err).Fatal("unable to initialize projections")
|
|
|
|
go h.Process(
|
|
ctx,
|
|
h.reduce,
|
|
h.Update,
|
|
h.Lock,
|
|
h.Unlock,
|
|
h.SearchQuery,
|
|
)
|
|
|
|
h.Subscribe(h.aggregates...)
|
|
|
|
return h
|
|
}
|
|
|
|
func (h *StatementHandler) TriggerBulk(ctx context.Context) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
var err error
|
|
defer span.EndWithError(err)
|
|
|
|
err = h.ProjectionHandler.TriggerBulk(ctx, h.Lock, h.Unlock)
|
|
logging.OnError(err).WithField("projection", h.ProjectionName).Warn("unable to trigger bulk")
|
|
}
|
|
|
|
func (h *StatementHandler) SearchQuery(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) {
|
|
sequences, err := h.currentSequences(ctx, h.client.QueryContext)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
|
|
for _, aggregateType := range h.aggregates {
|
|
instances := make([]string, 0)
|
|
for _, sequence := range sequences[aggregateType] {
|
|
instances = appendToIgnoredInstances(instances, sequence.instanceID)
|
|
queryBuilder.
|
|
AddQuery().
|
|
AggregateTypes(aggregateType).
|
|
SequenceGreater(sequence.sequence).
|
|
InstanceID(sequence.instanceID)
|
|
}
|
|
queryBuilder.
|
|
AddQuery().
|
|
AggregateTypes(aggregateType).
|
|
SequenceGreater(0).
|
|
ExcludedInstanceID(instances...)
|
|
}
|
|
|
|
return queryBuilder, h.bulkLimit, nil
|
|
}
|
|
|
|
func appendToIgnoredInstances(instances []string, id string) []string {
|
|
for _, instance := range instances {
|
|
if instance == id {
|
|
return instances
|
|
}
|
|
}
|
|
return append(instances, id)
|
|
}
|
|
|
|
//Update implements handler.Update
|
|
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
|
|
if len(stmts) == 0 {
|
|
return nil, nil
|
|
}
|
|
tx, err := h.client.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
|
|
}
|
|
|
|
sequences, err := h.currentSequences(ctx, tx.QueryContext)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return stmts, err
|
|
}
|
|
|
|
//checks for events between create statement and current sequence
|
|
// because there could be events between current sequence and a creation event
|
|
// and we cannot check via stmt.PreviousSequence
|
|
if stmts[0].PreviousSequence == 0 {
|
|
previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return stmts, err
|
|
}
|
|
stmts = append(previousStmts, stmts...)
|
|
}
|
|
|
|
lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences)
|
|
|
|
if lastSuccessfulIdx >= 0 {
|
|
err = h.updateCurrentSequences(tx, sequences)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return stmts, err
|
|
}
|
|
}
|
|
|
|
if err = tx.Commit(); err != nil {
|
|
return stmts, err
|
|
}
|
|
|
|
if lastSuccessfulIdx == -1 && len(stmts) > 0 {
|
|
return stmts, handler.ErrSomeStmtsFailed
|
|
}
|
|
|
|
unexecutedStmts = make([]*handler.Statement, len(stmts)-(lastSuccessfulIdx+1))
|
|
copy(unexecutedStmts, stmts[lastSuccessfulIdx+1:])
|
|
stmts = nil
|
|
|
|
if len(unexecutedStmts) > 0 {
|
|
return unexecutedStmts, handler.ErrSomeStmtsFailed
|
|
}
|
|
|
|
return unexecutedStmts, nil
|
|
}
|
|
|
|
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
|
|
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).SetTx(tx)
|
|
queriesAdded := false
|
|
for _, aggregateType := range h.aggregates {
|
|
for _, sequence := range sequences[aggregateType] {
|
|
if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID {
|
|
continue
|
|
}
|
|
|
|
query.
|
|
AddQuery().
|
|
AggregateTypes(aggregateType).
|
|
SequenceGreater(sequence.sequence).
|
|
SequenceLess(stmtSeq).
|
|
InstanceID(sequence.instanceID)
|
|
|
|
queriesAdded = true
|
|
}
|
|
}
|
|
|
|
if !queriesAdded {
|
|
return nil, nil
|
|
}
|
|
|
|
events, err := h.Eventstore.Filter(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, event := range events {
|
|
stmt, err := reduce(event)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
previousStmts = append(previousStmts, stmt)
|
|
}
|
|
return previousStmts, nil
|
|
}
|
|
|
|
func (h *StatementHandler) executeStmts(
|
|
tx *sql.Tx,
|
|
stmts *[]*handler.Statement,
|
|
sequences currentSequences,
|
|
) int {
|
|
|
|
lastSuccessfulIdx := -1
|
|
stmts:
|
|
for i := 0; i < len(*stmts); i++ {
|
|
stmt := (*stmts)[i]
|
|
for _, sequence := range sequences[stmt.AggregateType] {
|
|
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
|
logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped")
|
|
if i < len(*stmts)-1 {
|
|
copy((*stmts)[i:], (*stmts)[i+1:])
|
|
}
|
|
*stmts = (*stmts)[:len(*stmts)-1]
|
|
i--
|
|
continue stmts
|
|
}
|
|
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
|
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match")
|
|
break stmts
|
|
}
|
|
}
|
|
err := h.executeStmt(tx, stmt)
|
|
if err == nil {
|
|
updateSequences(sequences, stmt)
|
|
lastSuccessfulIdx = i
|
|
continue
|
|
}
|
|
|
|
shouldContinue := h.handleFailedStmt(tx, stmt, err)
|
|
if !shouldContinue {
|
|
break
|
|
}
|
|
|
|
updateSequences(sequences, stmt)
|
|
lastSuccessfulIdx = i
|
|
continue
|
|
}
|
|
return lastSuccessfulIdx
|
|
}
|
|
|
|
//executeStmt handles sql statements
|
|
//an error is returned if the statement could not be inserted properly
|
|
func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) error {
|
|
if stmt.IsNoop() {
|
|
return nil
|
|
}
|
|
_, err := tx.Exec("SAVEPOINT push_stmt")
|
|
if err != nil {
|
|
return errors.ThrowInternal(err, "CRDB-i1wp6", "unable to create savepoint")
|
|
}
|
|
err = stmt.Execute(tx, h.ProjectionName)
|
|
if err != nil {
|
|
_, rollbackErr := tx.Exec("ROLLBACK TO SAVEPOINT push_stmt")
|
|
if rollbackErr != nil {
|
|
return errors.ThrowInternal(rollbackErr, "CRDB-zzp3P", "rollback to savepoint failed")
|
|
}
|
|
return errors.ThrowInternal(err, "CRDB-oRkaN", "unable execute stmt")
|
|
}
|
|
_, err = tx.Exec("RELEASE push_stmt")
|
|
if err != nil {
|
|
return errors.ThrowInternal(err, "CRDB-qWgwT", "unable to release savepoint")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
|
|
for _, sequence := range sequences[stmt.AggregateType] {
|
|
if sequence.instanceID == stmt.InstanceID {
|
|
sequence.sequence = stmt.Sequence
|
|
return
|
|
}
|
|
}
|
|
sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{
|
|
instanceID: stmt.InstanceID,
|
|
sequence: stmt.Sequence,
|
|
})
|
|
}
|