From c71ccc8a800819f7dac3a4cc8681f2990410435c Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Thu, 19 May 2022 10:25:19 +0200 Subject: [PATCH] fix: improve context handling in projections (#3638) * fix: improve context handling in projections * fix tests * use as of system time for current sequence * use as of system time for current sequence Co-authored-by: Silvan --- .../handler/crdb/current_sequence.go | 5 +-- .../eventstore/handler/crdb/handler_stmt.go | 31 +++++++++++++------ .../handler/crdb/handler_stmt_test.go | 22 ++++++++++--- internal/eventstore/handler/crdb/lock.go | 2 +- .../eventstore/handler/handler_projection.go | 4 +-- .../handler/handler_projection_test.go | 2 +- internal/query/current_sequence.go | 2 +- 7 files changed, 47 insertions(+), 21 deletions(-) diff --git a/internal/eventstore/handler/crdb/current_sequence.go b/internal/eventstore/handler/crdb/current_sequence.go index 2b14ecccf5..5f381c922b 100644 --- a/internal/eventstore/handler/crdb/current_sequence.go +++ b/internal/eventstore/handler/crdb/current_sequence.go @@ -1,6 +1,7 @@ package crdb import ( + "context" "database/sql" "strconv" "strings" @@ -21,8 +22,8 @@ type instanceSequence struct { sequence uint64 } -func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) { - rows, err := query(h.currentSequenceStmt, h.ProjectionName) +func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) { + rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName) if err != nil { return nil, err } diff --git a/internal/eventstore/handler/crdb/handler_stmt.go b/internal/eventstore/handler/crdb/handler_stmt.go index fd8cc5f17a..6a50ecd1c9 100644 --- a/internal/eventstore/handler/crdb/handler_stmt.go +++ b/internal/eventstore/handler/crdb/handler_stmt.go @@ -93,8 +93,8 @@ func NewStatementHandler( return h } -func (h *StatementHandler) SearchQuery() (*eventstore.SearchQueryBuilder, uint64, error) { - sequences, err := h.currentSequences(h.client.Query) +func (h *StatementHandler) SearchQuery(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) { + sequences, err := h.currentSequences(ctx, h.client.QueryContext) if err != nil { return nil, 0, err } @@ -131,12 +131,15 @@ func appendToIgnoredInstances(instances []string, id string) []string { //Update implements handler.Update func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) { + if len(stmts) == 0 { + return nil, nil + } tx, err := h.client.BeginTx(ctx, nil) if err != nil { return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed") } - sequences, err := h.currentSequences(tx.Query) + sequences, err := h.currentSequences(ctx, tx.QueryContext) if err != nil { tx.Rollback() return stmts, err @@ -154,7 +157,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen stmts = append(previousStmts, stmts...) } - lastSuccessfulIdx := h.executeStmts(tx, stmts, sequences) + lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences) if lastSuccessfulIdx >= 0 { err = h.updateCurrentSequences(tx, sequences) @@ -168,7 +171,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen return stmts, err } - if lastSuccessfulIdx == -1 { + if lastSuccessfulIdx == -1 && len(stmts) > 0 { return stmts, handler.ErrSomeStmtsFailed } @@ -225,19 +228,27 @@ func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, stmtSeq uint6 func (h *StatementHandler) executeStmts( tx *sql.Tx, - stmts []*handler.Statement, + stmts *[]*handler.Statement, sequences currentSequences, ) int { lastSuccessfulIdx := -1 - for i, stmt := range stmts { +stmts: + for i := 0; i < len(*stmts); i++ { + stmt := (*stmts)[i] for _, sequence := range sequences[stmt.AggregateType] { if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID { - continue + logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped") + if i < len(*stmts)-1 { + copy((*stmts)[i:], (*stmts)[i+1:]) + } + *stmts = (*stmts)[:len(*stmts)-1] + i-- + continue stmts } if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID { - logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match") - break + logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match") + break stmts } } err := h.executeStmt(tx, stmt) diff --git a/internal/eventstore/handler/crdb/handler_stmt_test.go b/internal/eventstore/handler/crdb/handler_stmt_test.go index 009e443af9..f2bb26428d 100644 --- a/internal/eventstore/handler/crdb/handler_stmt_test.go +++ b/internal/eventstore/handler/crdb/handler_stmt_test.go @@ -138,7 +138,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) { expectation(mock) } - query, limit, err := h.SearchQuery() + query, limit, err := h.SearchQuery(context.Background()) if !tt.want.isErr(err) { t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err) return @@ -183,6 +183,13 @@ func TestStatementHandler_Update(t *testing.T) { name: "begin fails", args: args{ ctx: context.Background(), + stmts: []*handler.Statement{ + NewNoOpStatement(&testEvent{ + aggregateType: "agg", + sequence: 6, + previousSequence: 0, + }), + }, }, want: want{ expectations: []mockExpectation{ @@ -197,6 +204,13 @@ func TestStatementHandler_Update(t *testing.T) { name: "current sequence fails", args: args{ ctx: context.Background(), + stmts: []*handler.Statement{ + NewNoOpStatement(&testEvent{ + aggregateType: "agg", + sequence: 6, + previousSequence: 0, + }), + }, }, want: want{ expectations: []mockExpectation{ @@ -494,7 +508,7 @@ func TestStatementHandler_Update(t *testing.T) { if !tt.want.isErr(err) { t.Errorf("StatementHandler.Update() error = %v", err) } - if tt.want.stmtsLen != len(stmts) { + if err == nil && tt.want.stmtsLen != len(stmts) { t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, len(stmts)) } @@ -1069,7 +1083,7 @@ func TestStatementHandler_executeStmts(t *testing.T) { t.Fatalf("unexpected err in begin: %v", err) } - idx := h.executeStmts(tx, tt.args.stmts, tt.args.sequences) + idx := h.executeStmts(tx, &tt.args.stmts, tt.args.sequences) if idx != tt.want.idx { t.Errorf("unexpected index want: %d got %d", tt.want.idx, idx) } @@ -1420,7 +1434,7 @@ func TestStatementHandler_currentSequence(t *testing.T) { t.Fatalf("unexpected err in begin: %v", err) } - seq, err := h.currentSequences(tx.Query) + seq, err := h.currentSequences(context.Background(), tx.QueryContext) if !tt.want.isErr(err) { t.Errorf("unexpected error: %v", err) } diff --git a/internal/eventstore/handler/crdb/lock.go b/internal/eventstore/handler/crdb/lock.go index 55af5168c1..ccd32f0494 100644 --- a/internal/eventstore/handler/crdb/lock.go +++ b/internal/eventstore/handler/crdb/lock.go @@ -67,7 +67,7 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error { //the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html). - res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID) + res, err := h.client.ExecContext(ctx, h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID) if err != nil { return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock") } diff --git a/internal/eventstore/handler/handler_projection.go b/internal/eventstore/handler/handler_projection.go index 7f54464b81..15ce71ac4d 100644 --- a/internal/eventstore/handler/handler_projection.go +++ b/internal/eventstore/handler/handler_projection.go @@ -35,7 +35,7 @@ type Lock func(context.Context, time.Duration, string) <-chan error type Unlock func(string) error //SearchQuery generates the search query to lookup for events -type SearchQuery func() (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error) +type SearchQuery func(ctx context.Context) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error) type ProjectionHandler struct { Handler @@ -259,7 +259,7 @@ func (h *ProjectionHandler) fetchBulkStmts( query SearchQuery, reduce Reduce, ) (limitExeeded bool, err error) { - eventQuery, eventsLimit, err := query() + eventQuery, eventsLimit, err := query(ctx) if err != nil { logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("unable to create event query") return false, err diff --git a/internal/eventstore/handler/handler_projection_test.go b/internal/eventstore/handler/handler_projection_test.go index 86954b9ab1..16601233a4 100644 --- a/internal/eventstore/handler/handler_projection_test.go +++ b/internal/eventstore/handler/handler_projection_test.go @@ -861,7 +861,7 @@ func testReduceErr(err error) Reduce { } func testQuery(query *eventstore.SearchQueryBuilder, limit uint64, err error) SearchQuery { - return func() (*eventstore.SearchQueryBuilder, uint64, error) { + return func(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) { return query, limit, err } } diff --git a/internal/query/current_sequence.go b/internal/query/current_sequence.go index d0d034414a..2f71e6d45b 100644 --- a/internal/query/current_sequence.go +++ b/internal/query/current_sequence.go @@ -195,7 +195,7 @@ func prepareLatestSequence() (sq.SelectBuilder, func(*sql.Row) (*LatestSequence, return sq.Select( CurrentSequenceColCurrentSequence.identifier(), CurrentSequenceColTimestamp.identifier()). - From(currentSequencesTable.identifier()).PlaceholderFormat(sq.Dollar), + From(currentSequencesTable.identifier() + " AS OF SYSTEM TIME '-1ms'").PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*LatestSequence, error) { seq := new(LatestSequence) err := row.Scan(