mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-06 13:57:41 +00:00
fix: improve context handling in projections (#3638)
* fix: improve context handling in projections * fix tests * use as of system time for current sequence * use as of system time for current sequence Co-authored-by: Silvan <silvan.reusser@gmail.com>
This commit is contained in:
parent
ed0aa7088b
commit
c71ccc8a80
@ -1,6 +1,7 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -21,8 +22,8 @@ type instanceSequence struct {
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
|
||||
rows, err := query(h.currentSequenceStmt, h.ProjectionName)
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
|
||||
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -93,8 +93,8 @@ func NewStatementHandler(
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *StatementHandler) SearchQuery() (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
sequences, err := h.currentSequences(h.client.Query)
|
||||
func (h *StatementHandler) SearchQuery(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
sequences, err := h.currentSequences(ctx, h.client.QueryContext)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -131,12 +131,15 @@ func appendToIgnoredInstances(instances []string, id string) []string {
|
||||
|
||||
//Update implements handler.Update
|
||||
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
|
||||
if len(stmts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
tx, err := h.client.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
|
||||
}
|
||||
|
||||
sequences, err := h.currentSequences(tx.Query)
|
||||
sequences, err := h.currentSequences(ctx, tx.QueryContext)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return stmts, err
|
||||
@ -154,7 +157,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
|
||||
stmts = append(previousStmts, stmts...)
|
||||
}
|
||||
|
||||
lastSuccessfulIdx := h.executeStmts(tx, stmts, sequences)
|
||||
lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences)
|
||||
|
||||
if lastSuccessfulIdx >= 0 {
|
||||
err = h.updateCurrentSequences(tx, sequences)
|
||||
@ -168,7 +171,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
|
||||
return stmts, err
|
||||
}
|
||||
|
||||
if lastSuccessfulIdx == -1 {
|
||||
if lastSuccessfulIdx == -1 && len(stmts) > 0 {
|
||||
return stmts, handler.ErrSomeStmtsFailed
|
||||
}
|
||||
|
||||
@ -225,19 +228,27 @@ func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, stmtSeq uint6
|
||||
|
||||
func (h *StatementHandler) executeStmts(
|
||||
tx *sql.Tx,
|
||||
stmts []*handler.Statement,
|
||||
stmts *[]*handler.Statement,
|
||||
sequences currentSequences,
|
||||
) int {
|
||||
|
||||
lastSuccessfulIdx := -1
|
||||
for i, stmt := range stmts {
|
||||
stmts:
|
||||
for i := 0; i < len(*stmts); i++ {
|
||||
stmt := (*stmts)[i]
|
||||
for _, sequence := range sequences[stmt.AggregateType] {
|
||||
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
||||
continue
|
||||
logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped")
|
||||
if i < len(*stmts)-1 {
|
||||
copy((*stmts)[i:], (*stmts)[i+1:])
|
||||
}
|
||||
*stmts = (*stmts)[:len(*stmts)-1]
|
||||
i--
|
||||
continue stmts
|
||||
}
|
||||
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
||||
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
|
||||
break
|
||||
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match")
|
||||
break stmts
|
||||
}
|
||||
}
|
||||
err := h.executeStmt(tx, stmt)
|
||||
|
@ -138,7 +138,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
|
||||
expectation(mock)
|
||||
}
|
||||
|
||||
query, limit, err := h.SearchQuery()
|
||||
query, limit, err := h.SearchQuery(context.Background())
|
||||
if !tt.want.isErr(err) {
|
||||
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
|
||||
return
|
||||
@ -183,6 +183,13 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
name: "begin fails",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
stmts: []*handler.Statement{
|
||||
NewNoOpStatement(&testEvent{
|
||||
aggregateType: "agg",
|
||||
sequence: 6,
|
||||
previousSequence: 0,
|
||||
}),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
@ -197,6 +204,13 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
name: "current sequence fails",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
stmts: []*handler.Statement{
|
||||
NewNoOpStatement(&testEvent{
|
||||
aggregateType: "agg",
|
||||
sequence: 6,
|
||||
previousSequence: 0,
|
||||
}),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
@ -494,7 +508,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
if !tt.want.isErr(err) {
|
||||
t.Errorf("StatementHandler.Update() error = %v", err)
|
||||
}
|
||||
if tt.want.stmtsLen != len(stmts) {
|
||||
if err == nil && tt.want.stmtsLen != len(stmts) {
|
||||
t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, len(stmts))
|
||||
}
|
||||
|
||||
@ -1069,7 +1083,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
|
||||
t.Fatalf("unexpected err in begin: %v", err)
|
||||
}
|
||||
|
||||
idx := h.executeStmts(tx, tt.args.stmts, tt.args.sequences)
|
||||
idx := h.executeStmts(tx, &tt.args.stmts, tt.args.sequences)
|
||||
if idx != tt.want.idx {
|
||||
t.Errorf("unexpected index want: %d got %d", tt.want.idx, idx)
|
||||
}
|
||||
@ -1420,7 +1434,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
t.Fatalf("unexpected err in begin: %v", err)
|
||||
}
|
||||
|
||||
seq, err := h.currentSequences(tx.Query)
|
||||
seq, err := h.currentSequences(context.Background(), tx.QueryContext)
|
||||
if !tt.want.isErr(err) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t
|
||||
|
||||
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error {
|
||||
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
|
||||
res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
|
||||
res, err := h.client.ExecContext(ctx, h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ type Lock func(context.Context, time.Duration, string) <-chan error
|
||||
type Unlock func(string) error
|
||||
|
||||
//SearchQuery generates the search query to lookup for events
|
||||
type SearchQuery func() (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
|
||||
type SearchQuery func(ctx context.Context) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
|
||||
|
||||
type ProjectionHandler struct {
|
||||
Handler
|
||||
@ -259,7 +259,7 @@ func (h *ProjectionHandler) fetchBulkStmts(
|
||||
query SearchQuery,
|
||||
reduce Reduce,
|
||||
) (limitExeeded bool, err error) {
|
||||
eventQuery, eventsLimit, err := query()
|
||||
eventQuery, eventsLimit, err := query(ctx)
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("unable to create event query")
|
||||
return false, err
|
||||
|
@ -861,7 +861,7 @@ func testReduceErr(err error) Reduce {
|
||||
}
|
||||
|
||||
func testQuery(query *eventstore.SearchQueryBuilder, limit uint64, err error) SearchQuery {
|
||||
return func() (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
return func(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
return query, limit, err
|
||||
}
|
||||
}
|
||||
|
@ -195,7 +195,7 @@ func prepareLatestSequence() (sq.SelectBuilder, func(*sql.Row) (*LatestSequence,
|
||||
return sq.Select(
|
||||
CurrentSequenceColCurrentSequence.identifier(),
|
||||
CurrentSequenceColTimestamp.identifier()).
|
||||
From(currentSequencesTable.identifier()).PlaceholderFormat(sq.Dollar),
|
||||
From(currentSequencesTable.identifier() + " AS OF SYSTEM TIME '-1ms'").PlaceholderFormat(sq.Dollar),
|
||||
func(row *sql.Row) (*LatestSequence, error) {
|
||||
seq := new(LatestSequence)
|
||||
err := row.Scan(
|
||||
|
Loading…
x
Reference in New Issue
Block a user