feat: handle instanceID in projections (#3442)

* feat: handle instanceID in projections

* rename functions

* fix key lock

* fix import
This commit is contained in:
Livio Amstutz
2022-04-19 08:26:12 +02:00
committed by GitHub
parent c25d853820
commit 1305c14e49
120 changed files with 2078 additions and 1209 deletions

View File

@@ -10,11 +10,16 @@ import (
)
const (
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type FROM %s WHERE projection_name = $1 FOR UPDATE`
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, timestamp) VALUES `
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 FOR UPDATE`
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
)
type currentSequences map[eventstore.AggregateType]uint64
type currentSequences map[eventstore.AggregateType][]*instanceSequence
type instanceSequence struct {
instanceID string
sequence uint64
}
func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
rows, err := query(h.currentSequenceStmt, h.ProjectionName)
@@ -29,14 +34,18 @@ func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (
var (
aggregateType eventstore.AggregateType
sequence uint64
instanceID string
)
err = rows.Scan(&sequence, &aggregateType)
err = rows.Scan(&sequence, &aggregateType, &instanceID)
if err != nil {
return nil, errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
}
sequences[aggregateType] = sequence
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
sequence: sequence,
instanceID: instanceID,
})
}
if err = rows.Close(); err != nil {
@@ -54,10 +63,12 @@ func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentS
valueQueries := make([]string, 0, len(sequences))
valueCounter := 0
values := make([]interface{}, 0, len(sequences)*3)
for aggregate, sequence := range sequences {
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", NOW())")
valueCounter += 3
values = append(values, h.ProjectionName, aggregate, sequence)
for aggregate, instanceSequence := range sequences {
for _, sequence := range instanceSequence {
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", $"+strconv.Itoa(valueCounter+4)+", NOW())")
valueCounter += 4
values = append(values, h.ProjectionName, aggregate, sequence.sequence, sequence.instanceID)
}
}
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", "), values...)

View File

@@ -3,7 +3,7 @@ package crdb
import (
"database/sql"
"database/sql/driver"
"log"
"sort"
"strings"
"time"
@@ -123,22 +123,22 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) {
}
}
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
WithArgs(
projection,
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}).
AddRow(seq, aggregateType),
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
AddRow(seq, aggregateType, instanceID),
)
}
}
func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
WithArgs(
projection,
).
@@ -148,37 +148,38 @@ func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlm
func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
WithArgs(
projection,
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}),
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
)
}
}
func expectCurrentSequenceScanErr(tableName, projection string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
WithArgs(
projection,
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}).
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
RowError(0, sql.ErrTxDone).
AddRow(0, "agg"),
AddRow(0, "agg", "instanceID"),
)
}
}
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -187,16 +188,26 @@ func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggre
}
func expectUpdateTwoCurrentSequence(tableName, projection string, sequences currentSequences) func(sqlmock.Sqlmock) {
//sort them so the args will always have the same order
keys := make([]string, 0, len(sequences))
for k := range sequences {
keys = append(keys, string(k))
}
sort.Strings(keys)
args := make([]driver.Value, len(keys)*4)
for i, k := range keys {
aggregateType := eventstore.AggregateType(k)
for _, sequence := range sequences[aggregateType] {
args[i*4] = projection
args[i*4+1] = aggregateType
args[i*4+2] = sequence.sequence
args[i*4+3] = sequence.instanceID
}
}
return func(m sqlmock.Sqlmock) {
matcher := &currentSequenceMatcher{seq: sequences}
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\), \(\$4, \$5, \$6, NOW\(\)\)`).
m.ExpectExec("UPSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\)`).
WithArgs(
projection,
matcher,
matcher,
projection,
matcher,
matcher,
args...,
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -204,51 +215,27 @@ func expectUpdateTwoCurrentSequence(tableName, projection string, sequences curr
}
}
type currentSequenceMatcher struct {
seq currentSequences
currentAggregate eventstore.AggregateType
}
func (m *currentSequenceMatcher) Match(value driver.Value) bool {
switch v := value.(type) {
case string:
if m.currentAggregate != "" {
log.Printf("expected sequence of %s but got next aggregate type %s", m.currentAggregate, value)
return false
}
_, ok := m.seq[eventstore.AggregateType(v)]
if !ok {
return false
}
m.currentAggregate = eventstore.AggregateType(v)
return true
default:
seq := m.seq[m.currentAggregate]
m.currentAggregate = ""
delete(m.seq, m.currentAggregate)
return int64(seq) == value.(int64)
}
}
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType string) func(sqlmock.Sqlmock) {
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnError(err)
}
}
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnResult(
sqlmock.NewResult(0, 0),
@@ -256,17 +243,18 @@ func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64,
}
}
func expectLock(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlmock) {
func expectLock(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
` ON CONFLICT \(projection_name\)`+
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -274,33 +262,35 @@ func expectLock(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlm
}
}
func expectLockNoRows(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlmock) {
func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
` ON CONFLICT \(projection_name\)`+
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
).
WillReturnResult(driver.ResultNoRows)
}
}
func expectLockErr(lockTable, workerName string, d time.Duration, err error) func(sqlmock.Sqlmock) {
func expectLockErr(lockTable, workerName string, d time.Duration, instanceID string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
` ON CONFLICT \(projection_name\)`+
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
).
WillReturnError(err)
}

View File

@@ -101,15 +101,34 @@ func (h *StatementHandler) SearchQuery() (*eventstore.SearchQueryBuilder, uint64
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
for _, aggregateType := range h.aggregates {
instances := make([]string, 0)
for _, sequence := range sequences[aggregateType] {
instances = appendToIgnoredInstances(instances, sequence.instanceID)
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequence.sequence).
InstanceID(sequence.instanceID)
}
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequences[aggregateType])
SequenceGreater(0).
ExcludedInstanceID(instances...)
}
return queryBuilder, h.bulkLimit, nil
}
func appendToIgnoredInstances(instances []string, id string) []string {
for _, instance := range instances {
if instance == id {
return instances
}
}
return append(instances, id)
}
//Update implements handler.Update
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
tx, err := h.client.BeginTx(ctx, nil)
@@ -127,7 +146,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
// because there could be events between current sequence and a creation event
// and we cannot check via stmt.PreviousSequence
if stmts[0].PreviousSequence == 0 {
previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, sequences, reduce)
previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
if err != nil {
tx.Rollback()
return stmts, err
@@ -164,27 +183,25 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
return unexecutedStmts, nil
}
func (h *StatementHandler) fetchPreviousStmts(
ctx context.Context,
stmtSeq uint64,
sequences currentSequences,
reduce handler.Reduce,
) (previousStmts []*handler.Statement, err error) {
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
queriesAdded := false
for _, aggregateType := range h.aggregates {
if stmtSeq <= sequences[aggregateType] {
continue
for _, sequence := range sequences[aggregateType] {
if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID {
continue
}
query.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequence.sequence).
SequenceLess(stmtSeq).
InstanceID(sequence.instanceID)
queriesAdded = true
}
query.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequences[aggregateType]).
SequenceLess(stmtSeq)
queriesAdded = true
}
if !queriesAdded {
@@ -214,16 +231,19 @@ func (h *StatementHandler) executeStmts(
lastSuccessfulIdx := -1
for i, stmt := range stmts {
if stmt.Sequence <= sequences[stmt.AggregateType] {
continue
}
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequences[stmt.AggregateType] {
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
break
for _, sequence := range sequences[stmt.AggregateType] {
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
continue
}
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
break
}
}
err := h.executeStmt(tx, stmt)
if err == nil {
sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i
updateSequences(sequences, stmt)
lastSuccessfulIdx = i
continue
}
@@ -232,7 +252,9 @@ func (h *StatementHandler) executeStmts(
break
}
sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i
updateSequences(sequences, stmt)
lastSuccessfulIdx = i
continue
}
return lastSuccessfulIdx
}
@@ -261,3 +283,16 @@ func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) erro
}
return nil
}
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
for _, sequence := range sequences[stmt.AggregateType] {
if sequence.instanceID == stmt.InstanceID {
sequence.sequence = stmt.Sequence
return
}
}
sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{
instanceID: stmt.InstanceID,
sequence: stmt.Sequence,
})
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/handler"
@@ -97,13 +98,18 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
return err == nil
},
expectations: []mockExpectation{
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
},
SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes("testAgg").
SequenceGreater(5).
InstanceID("instanceID").
Or().
AggregateTypes("testAgg").
SequenceGreater(0).
ExcludedInstanceID("instanceID").
Builder().
Limit(5),
},
@@ -225,7 +231,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectRollback(),
},
isErr: func(err error) bool {
@@ -262,7 +268,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {
@@ -287,6 +293,7 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "agg",
sequence: 7,
previousSequence: 5,
instanceID: "instanceID",
},
[]handler.Column{
{
@@ -299,11 +306,11 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(),
expectUpdateCurrentSequenceNoRows("my_sequences", "my_projection", 7, "agg"),
expectUpdateCurrentSequenceNoRows("my_sequences", "my_projection", 7, "agg", "instanceID"),
expectRollback(),
},
isErr: func(err error) bool {
@@ -328,6 +335,7 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "agg",
sequence: 7,
previousSequence: 5,
instanceID: "instanceID",
},
[]handler.Column{
{
@@ -340,11 +348,11 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "agg"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "agg", "instanceID"),
expectCommitErr(sql.ErrConnDone),
},
isErr: func(err error) bool {
@@ -368,14 +376,15 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "testAgg",
sequence: 7,
previousSequence: 5,
instanceID: "instanceID",
}),
},
},
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {
@@ -399,14 +408,15 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "testAgg",
sequence: 7,
previousSequence: 0,
instanceID: "instanceID",
}),
},
},
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {
@@ -423,6 +433,7 @@ func TestStatementHandler_Update(t *testing.T) {
AggregateType: "testAgg",
Sequence: 6,
PreviousAggregateSequence: 5,
InstanceID: "instanceID",
},
),
),
@@ -435,6 +446,7 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "testAgg",
sequence: 7,
previousSequence: 0,
instanceID: "instanceID",
}),
},
reduce: testReduce(),
@@ -442,8 +454,8 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {
@@ -537,7 +549,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
ctx: context.Background(),
reduce: testReduce(),
sequences: currentSequences{
"testAgg": 5,
"testAgg": []*instanceSequence{
{sequence: 5},
},
},
stmtSeq: 6,
},
@@ -560,7 +574,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
ctx: context.Background(),
reduce: testReduce(),
sequences: currentSequences{
"testAgg": 5,
"testAgg": []*instanceSequence{
{sequence: 5},
},
},
stmtSeq: 6,
},
@@ -582,7 +598,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
ctx: context.Background(),
reduce: testReduce(),
sequences: currentSequences{
"testAgg": 5,
"testAgg": []*instanceSequence{
{sequence: 5},
},
},
stmtSeq: 10,
},
@@ -626,7 +644,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
ctx: context.Background(),
reduce: testReduceErr(errReduce),
sequences: currentSequences{
"testAgg": 5,
"testAgg": []*instanceSequence{
{sequence: 5},
},
},
stmtSeq: 10,
},
@@ -667,7 +687,7 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
}),
aggregates: tt.fields.aggregates,
}
stmts, err := h.fetchPreviousStmts(tt.args.ctx, tt.args.stmtSeq, tt.args.sequences, tt.args.reduce)
stmts, err := h.fetchPreviousStmts(tt.args.ctx, tt.args.stmtSeq, "", tt.args.sequences, tt.args.reduce)
if !tt.want.isErr(err) {
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
return
@@ -720,7 +740,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
}),
},
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{sequence: 5},
},
},
},
want: want{
@@ -762,7 +784,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
}),
},
sequences: currentSequences{
"agg": 2,
"agg": []*instanceSequence{
{sequence: 2},
},
},
},
want: want{
@@ -824,7 +848,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
}),
},
sequences: currentSequences{
"agg": 2,
"agg": []*instanceSequence{
{sequence: 2},
},
},
},
want: want{
@@ -891,7 +917,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
}),
},
sequences: currentSequences{
"agg": 2,
"agg": []*instanceSequence{
{sequence: 2},
},
},
},
want: want{
@@ -979,7 +1007,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
),
},
sequences: currentSequences{
"agg": 2,
"agg": []*instanceSequence{
{sequence: 2},
},
},
},
want: want{
@@ -1309,9 +1339,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
expectations: []mockExpectation{
expectCurrentSequenceNoRows("my_table", "my_projection"),
},
sequences: currentSequences{
"agg": 0,
},
sequences: currentSequences{},
},
},
{
@@ -1331,9 +1359,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
expectations: []mockExpectation{
expectCurrentSequenceScanErr("my_table", "my_projection"),
},
sequences: currentSequences{
"agg": 0,
},
sequences: currentSequences{},
},
},
{
@@ -1351,10 +1377,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
return errors.Is(err, nil)
},
expectations: []mockExpectation{
expectCurrentSequence("my_table", "my_projection", 5, "agg"),
expectCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
},
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
},
},
},
@@ -1404,9 +1435,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
}
for _, aggregateType := range tt.fields.aggregates {
if seq[aggregateType] != tt.want.sequences[aggregateType] {
t.Errorf("unexpected sequence in aggregate type %s: want %d got %d", aggregateType, tt.want.sequences[aggregateType], seq[aggregateType])
}
assert.Equal(t, tt.want.sequences[aggregateType], seq[aggregateType])
}
})
}
@@ -1440,7 +1469,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
args: args{
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
},
},
want: want{
@@ -1448,7 +1482,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
return errors.Is(err, sql.ErrConnDone)
},
expectations: []mockExpectation{
expectUpdateCurrentSequenceErr("my_table", "my_projection", 5, sql.ErrConnDone, "agg"),
expectUpdateCurrentSequenceErr("my_table", "my_projection", 5, sql.ErrConnDone, "agg", "instanceID"),
},
},
},
@@ -1461,7 +1495,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
args: args{
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
},
},
want: want{
@@ -1469,7 +1508,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
return errors.As(err, &errSeqNotUpdated)
},
expectations: []mockExpectation{
expectUpdateCurrentSequenceNoRows("my_table", "my_projection", 5, "agg"),
expectUpdateCurrentSequenceNoRows("my_table", "my_projection", 5, "agg", "instanceID"),
},
},
},
@@ -1482,7 +1521,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
args: args{
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
},
},
want: want{
@@ -1490,7 +1534,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
return err == nil
},
expectations: []mockExpectation{
expectUpdateCurrentSequence("my_table", "my_projection", 5, "agg"),
expectUpdateCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
},
},
},
@@ -1503,8 +1547,18 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
args: args{
sequences: currentSequences{
"agg": 5,
"agg2": 6,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
"agg2": []*instanceSequence{
{
sequence: 6,
instanceID: "instanceID",
},
},
},
},
want: want{
@@ -1513,9 +1567,19 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
expectations: []mockExpectation{
expectUpdateTwoCurrentSequence("my_table", "my_projection", currentSequences{
"agg": 5,
"agg2": 6},
),
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
"agg2": []*instanceSequence{
{
sequence: 6,
instanceID: "instanceID",
},
},
}),
},
},
},

View File

@@ -15,15 +15,15 @@ import (
const (
lockStmtFormat = "INSERT INTO %[1]s" +
" (locker_id, locked_until, projection_name) VALUES ($1, now()+$2::INTERVAL, $3)" +
" ON CONFLICT (projection_name)" +
" (locker_id, locked_until, projection_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
" ON CONFLICT (projection_name, instance_id)" +
" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
" WHERE %[1]s.projection_name = $3 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = $4 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
)
type Locker interface {
Lock(ctx context.Context, lockDuration time.Duration) <-chan error
Unlock() error
Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error
Unlock(instanceID string) error
}
type locker struct {
@@ -47,18 +47,18 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
}
}
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration) <-chan error {
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error {
errs := make(chan error)
go h.handleLock(ctx, errs, lockDuration)
go h.handleLock(ctx, errs, lockDuration, instanceID)
return errs
}
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration) {
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceID string) {
renewLock := time.NewTimer(0)
for {
select {
case <-renewLock.C:
errs <- h.renewLock(ctx, lockDuration)
errs <- h.renewLock(ctx, lockDuration, instanceID)
//refresh the lock 500ms before it times out. 500ms should be enough for one transaction
renewLock.Reset(lockDuration - (500 * time.Millisecond))
case <-ctx.Done():
@@ -69,9 +69,9 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t
}
}
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) error {
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error {
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName)
res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
if err != nil {
return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
}
@@ -83,8 +83,8 @@ func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) erro
return nil
}
func (h *locker) Unlock() error {
_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName)
func (h *locker) Unlock(instanceID string) error {
_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName, instanceID)
if err != nil {
return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed")
}

View File

@@ -32,6 +32,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
lockDuration time.Duration
ctx context.Context
errMock *errsMock
instanceID string
}
tests := []struct {
name string
@@ -42,9 +43,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
name: "lock fails",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 2),
expectLock(lockTable, workerName, 2),
expectLockErr(lockTable, workerName, 2, errLock),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLockErr(lockTable, workerName, 2, "instanceID", errLock),
},
},
args: args{
@@ -55,14 +56,15 @@ func TestStatementHandler_handleLock(t *testing.T) {
successfulIters: 2,
shouldErr: true,
},
instanceID: "instanceID",
},
},
{
name: "success",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 2),
expectLock(lockTable, workerName, 2),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2, "instanceID"),
},
},
args: args{
@@ -72,6 +74,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
errs: make(chan error),
successfulIters: 2,
},
instanceID: "instanceID",
},
},
}
@@ -96,7 +99,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
go tt.args.errMock.handleErrs(t, cancel)
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration)
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceID)
<-ctx.Done()
@@ -115,6 +118,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
}
type args struct {
lockDuration time.Duration
instanceID string
}
tests := []struct {
name string
@@ -125,7 +129,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
name: "lock fails",
want: want{
expectations: []mockExpectation{
expectLockErr(lockTable, workerName, 1, sql.ErrTxDone),
expectLockErr(lockTable, workerName, 1, "instanceID", sql.ErrTxDone),
},
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)
@@ -133,13 +137,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 1 * time.Second,
instanceID: "instanceID",
},
},
{
name: "lock no rows",
want: want{
expectations: []mockExpectation{
expectLockNoRows(lockTable, workerName, 2),
expectLockNoRows(lockTable, workerName, 2, "instanceID"),
},
isErr: func(err error) bool {
return errors.As(err, &renewNoRowsAffectedErr)
@@ -147,13 +152,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 2 * time.Second,
instanceID: "instanceID",
},
},
{
name: "success",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 3),
expectLock(lockTable, workerName, 3, "instanceID"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)
@@ -161,6 +167,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 3 * time.Second,
instanceID: "instanceID",
},
},
}
@@ -181,7 +188,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
expectation(mock)
}
err = h.renewLock(context.Background(), tt.args.lockDuration)
err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceID)
if !tt.want.isErr(err) {
t.Errorf("unexpected error = %v", err)
}
@@ -199,15 +206,22 @@ func TestStatementHandler_Unlock(t *testing.T) {
expectations []mockExpectation
isErr func(err error) bool
}
type args struct {
instanceID string
}
tests := []struct {
name string
args args
want want
}{
{
name: "unlock fails",
args: args{
instanceID: "instanceID",
},
want: want{
expectations: []mockExpectation{
expectLockErr(lockTable, workerName, 0, sql.ErrTxDone),
expectLockErr(lockTable, workerName, 0, "instanceID", sql.ErrTxDone),
},
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)
@@ -216,9 +230,12 @@ func TestStatementHandler_Unlock(t *testing.T) {
},
{
name: "success",
args: args{
instanceID: "instanceID",
},
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 0),
expectLock(lockTable, workerName, 0, "instanceID"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)
@@ -243,7 +260,7 @@ func TestStatementHandler_Unlock(t *testing.T) {
expectation(mock)
}
err = h.Unlock()
err = h.Unlock(tt.args.instanceID)
if !tt.want.isErr(err) {
t.Errorf("unexpected error = %v", err)
}

View File

@@ -12,6 +12,8 @@ import (
"github.com/caos/zitadel/internal/eventstore"
)
const systemID = "system"
type ProjectionHandlerConfig struct {
HandlerConfig
ProjectionName string
@@ -27,10 +29,10 @@ type Update func(context.Context, []*Statement, Reduce) (unexecutedStmts []*Stat
type Reduce func(eventstore.Event) (*Statement, error)
//Lock is used for mutex handling if needed on the projection
type Lock func(context.Context, time.Duration) <-chan error
type Lock func(context.Context, time.Duration, string) <-chan error
//Unlock releases the mutex of the projection
type Unlock func() error
type Unlock func(string) error
//SearchQuery generates the search query to lookup for events
type SearchQuery func() (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
@@ -183,7 +185,7 @@ func (h *ProjectionHandler) bulk(
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errs := lock(ctx, h.requeueAfter)
errs := lock(ctx, h.requeueAfter, systemID)
//wait until projection is locked
if err, ok := <-errs; err != nil || !ok {
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("initial lock failed")
@@ -194,7 +196,7 @@ func (h *ProjectionHandler) bulk(
execErr := executeBulk(ctx)
logging.WithFields("projection", h.ProjectionName).OnError(execErr).Warn("unable to execute")
unlockErr := unlock()
unlockErr := unlock(systemID)
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")
if execErr != nil {

View File

@@ -912,7 +912,7 @@ type lockMock struct {
}
func (m *lockMock) lock() Lock {
return func(ctx context.Context, _ time.Duration) <-chan error {
return func(ctx context.Context, _ time.Duration, _ string) <-chan error {
m.callCount++
errs := make(chan error)
go func() {
@@ -955,7 +955,7 @@ type unlockMock struct {
}
func (m *unlockMock) unlock() Unlock {
return func() error {
return func(instanceID string) error {
m.callCount++
return m.err
}