fix: improve context handling in projections (#3638)

* fix: improve context handling in projections

* fix tests

* use as of system time for current sequence

* use as of system time for current sequence

Co-authored-by: Silvan <silvan.reusser@gmail.com>
This commit is contained in:
Livio Amstutz 2022-05-19 10:25:19 +02:00 committed by GitHub
parent ed0aa7088b
commit c71ccc8a80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 47 additions and 21 deletions

View File

@ -1,6 +1,7 @@
package crdb package crdb
import ( import (
"context"
"database/sql" "database/sql"
"strconv" "strconv"
"strings" "strings"
@ -21,8 +22,8 @@ type instanceSequence struct {
sequence uint64 sequence uint64
} }
func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) { func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
rows, err := query(h.currentSequenceStmt, h.ProjectionName) rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -93,8 +93,8 @@ func NewStatementHandler(
return h return h
} }
func (h *StatementHandler) SearchQuery() (*eventstore.SearchQueryBuilder, uint64, error) { func (h *StatementHandler) SearchQuery(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) {
sequences, err := h.currentSequences(h.client.Query) sequences, err := h.currentSequences(ctx, h.client.QueryContext)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -131,12 +131,15 @@ func appendToIgnoredInstances(instances []string, id string) []string {
//Update implements handler.Update //Update implements handler.Update
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) { func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
if len(stmts) == 0 {
return nil, nil
}
tx, err := h.client.BeginTx(ctx, nil) tx, err := h.client.BeginTx(ctx, nil)
if err != nil { if err != nil {
return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed") return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
} }
sequences, err := h.currentSequences(tx.Query) sequences, err := h.currentSequences(ctx, tx.QueryContext)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return stmts, err return stmts, err
@ -154,7 +157,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
stmts = append(previousStmts, stmts...) stmts = append(previousStmts, stmts...)
} }
lastSuccessfulIdx := h.executeStmts(tx, stmts, sequences) lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences)
if lastSuccessfulIdx >= 0 { if lastSuccessfulIdx >= 0 {
err = h.updateCurrentSequences(tx, sequences) err = h.updateCurrentSequences(tx, sequences)
@ -168,7 +171,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
return stmts, err return stmts, err
} }
if lastSuccessfulIdx == -1 { if lastSuccessfulIdx == -1 && len(stmts) > 0 {
return stmts, handler.ErrSomeStmtsFailed return stmts, handler.ErrSomeStmtsFailed
} }
@ -225,19 +228,27 @@ func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, stmtSeq uint6
func (h *StatementHandler) executeStmts( func (h *StatementHandler) executeStmts(
tx *sql.Tx, tx *sql.Tx,
stmts []*handler.Statement, stmts *[]*handler.Statement,
sequences currentSequences, sequences currentSequences,
) int { ) int {
lastSuccessfulIdx := -1 lastSuccessfulIdx := -1
for i, stmt := range stmts { stmts:
for i := 0; i < len(*stmts); i++ {
stmt := (*stmts)[i]
for _, sequence := range sequences[stmt.AggregateType] { for _, sequence := range sequences[stmt.AggregateType] {
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID { if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
continue logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped")
if i < len(*stmts)-1 {
copy((*stmts)[i:], (*stmts)[i+1:])
}
*stmts = (*stmts)[:len(*stmts)-1]
i--
continue stmts
} }
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID { if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match") logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match")
break break stmts
} }
} }
err := h.executeStmt(tx, stmt) err := h.executeStmt(tx, stmt)

View File

@ -138,7 +138,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
expectation(mock) expectation(mock)
} }
query, limit, err := h.SearchQuery() query, limit, err := h.SearchQuery(context.Background())
if !tt.want.isErr(err) { if !tt.want.isErr(err) {
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err) t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
return return
@ -183,6 +183,13 @@ func TestStatementHandler_Update(t *testing.T) {
name: "begin fails", name: "begin fails",
args: args{ args: args{
ctx: context.Background(), ctx: context.Background(),
stmts: []*handler.Statement{
NewNoOpStatement(&testEvent{
aggregateType: "agg",
sequence: 6,
previousSequence: 0,
}),
},
}, },
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
@ -197,6 +204,13 @@ func TestStatementHandler_Update(t *testing.T) {
name: "current sequence fails", name: "current sequence fails",
args: args{ args: args{
ctx: context.Background(), ctx: context.Background(),
stmts: []*handler.Statement{
NewNoOpStatement(&testEvent{
aggregateType: "agg",
sequence: 6,
previousSequence: 0,
}),
},
}, },
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
@ -494,7 +508,7 @@ func TestStatementHandler_Update(t *testing.T) {
if !tt.want.isErr(err) { if !tt.want.isErr(err) {
t.Errorf("StatementHandler.Update() error = %v", err) t.Errorf("StatementHandler.Update() error = %v", err)
} }
if tt.want.stmtsLen != len(stmts) { if err == nil && tt.want.stmtsLen != len(stmts) {
t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, len(stmts)) t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, len(stmts))
} }
@ -1069,7 +1083,7 @@ func TestStatementHandler_executeStmts(t *testing.T) {
t.Fatalf("unexpected err in begin: %v", err) t.Fatalf("unexpected err in begin: %v", err)
} }
idx := h.executeStmts(tx, tt.args.stmts, tt.args.sequences) idx := h.executeStmts(tx, &tt.args.stmts, tt.args.sequences)
if idx != tt.want.idx { if idx != tt.want.idx {
t.Errorf("unexpected index want: %d got %d", tt.want.idx, idx) t.Errorf("unexpected index want: %d got %d", tt.want.idx, idx)
} }
@ -1420,7 +1434,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
t.Fatalf("unexpected err in begin: %v", err) t.Fatalf("unexpected err in begin: %v", err)
} }
seq, err := h.currentSequences(tx.Query) seq, err := h.currentSequences(context.Background(), tx.QueryContext)
if !tt.want.isErr(err) { if !tt.want.isErr(err) {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }

View File

@ -67,7 +67,7 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error { func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error {
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html). //the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID) res, err := h.client.ExecContext(ctx, h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
if err != nil { if err != nil {
return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock") return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
} }

View File

@ -35,7 +35,7 @@ type Lock func(context.Context, time.Duration, string) <-chan error
type Unlock func(string) error type Unlock func(string) error
//SearchQuery generates the search query to lookup for events //SearchQuery generates the search query to lookup for events
type SearchQuery func() (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error) type SearchQuery func(ctx context.Context) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
type ProjectionHandler struct { type ProjectionHandler struct {
Handler Handler
@ -259,7 +259,7 @@ func (h *ProjectionHandler) fetchBulkStmts(
query SearchQuery, query SearchQuery,
reduce Reduce, reduce Reduce,
) (limitExeeded bool, err error) { ) (limitExeeded bool, err error) {
eventQuery, eventsLimit, err := query() eventQuery, eventsLimit, err := query(ctx)
if err != nil { if err != nil {
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("unable to create event query") logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("unable to create event query")
return false, err return false, err

View File

@ -861,7 +861,7 @@ func testReduceErr(err error) Reduce {
} }
func testQuery(query *eventstore.SearchQueryBuilder, limit uint64, err error) SearchQuery { func testQuery(query *eventstore.SearchQueryBuilder, limit uint64, err error) SearchQuery {
return func() (*eventstore.SearchQueryBuilder, uint64, error) { return func(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) {
return query, limit, err return query, limit, err
} }
} }

View File

@ -195,7 +195,7 @@ func prepareLatestSequence() (sq.SelectBuilder, func(*sql.Row) (*LatestSequence,
return sq.Select( return sq.Select(
CurrentSequenceColCurrentSequence.identifier(), CurrentSequenceColCurrentSequence.identifier(),
CurrentSequenceColTimestamp.identifier()). CurrentSequenceColTimestamp.identifier()).
From(currentSequencesTable.identifier()).PlaceholderFormat(sq.Dollar), From(currentSequencesTable.identifier() + " AS OF SYSTEM TIME '-1ms'").PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*LatestSequence, error) { func(row *sql.Row) (*LatestSequence, error) {
seq := new(LatestSequence) seq := new(LatestSequence)
err := row.Scan( err := row.Scan(