mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 08:37:32 +00:00
feat: handle instanceID in projections (#3442)
* feat: handle instanceID in projections * rename functions * fix key lock * fix import
This commit is contained in:
@@ -10,11 +10,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type FROM %s WHERE projection_name = $1 FOR UPDATE`
|
||||
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, timestamp) VALUES `
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 FOR UPDATE`
|
||||
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
)
|
||||
|
||||
type currentSequences map[eventstore.AggregateType]uint64
|
||||
type currentSequences map[eventstore.AggregateType][]*instanceSequence
|
||||
|
||||
type instanceSequence struct {
|
||||
instanceID string
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
|
||||
rows, err := query(h.currentSequenceStmt, h.ProjectionName)
|
||||
@@ -29,14 +34,18 @@ func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (
|
||||
var (
|
||||
aggregateType eventstore.AggregateType
|
||||
sequence uint64
|
||||
instanceID string
|
||||
)
|
||||
|
||||
err = rows.Scan(&sequence, &aggregateType)
|
||||
err = rows.Scan(&sequence, &aggregateType, &instanceID)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
|
||||
}
|
||||
|
||||
sequences[aggregateType] = sequence
|
||||
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
|
||||
sequence: sequence,
|
||||
instanceID: instanceID,
|
||||
})
|
||||
}
|
||||
|
||||
if err = rows.Close(); err != nil {
|
||||
@@ -54,10 +63,12 @@ func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentS
|
||||
valueQueries := make([]string, 0, len(sequences))
|
||||
valueCounter := 0
|
||||
values := make([]interface{}, 0, len(sequences)*3)
|
||||
for aggregate, sequence := range sequences {
|
||||
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", NOW())")
|
||||
valueCounter += 3
|
||||
values = append(values, h.ProjectionName, aggregate, sequence)
|
||||
for aggregate, instanceSequence := range sequences {
|
||||
for _, sequence := range instanceSequence {
|
||||
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", $"+strconv.Itoa(valueCounter+4)+", NOW())")
|
||||
valueCounter += 4
|
||||
values = append(values, h.ProjectionName, aggregate, sequence.sequence, sequence.instanceID)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", "), values...)
|
||||
|
@@ -3,7 +3,7 @@ package crdb
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -123,22 +123,22 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) {
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
|
||||
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}).
|
||||
AddRow(seq, aggregateType),
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
|
||||
AddRow(seq, aggregateType, instanceID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
).
|
||||
@@ -148,37 +148,38 @@ func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlm
|
||||
|
||||
func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}),
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceScanErr(tableName, projection string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}).
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
|
||||
RowError(0, sql.ErrTxDone).
|
||||
AddRow(0, "agg"),
|
||||
AddRow(0, "agg", "instanceID"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
|
||||
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -187,16 +188,26 @@ func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggre
|
||||
}
|
||||
|
||||
func expectUpdateTwoCurrentSequence(tableName, projection string, sequences currentSequences) func(sqlmock.Sqlmock) {
|
||||
//sort them so the args will always have the same order
|
||||
keys := make([]string, 0, len(sequences))
|
||||
for k := range sequences {
|
||||
keys = append(keys, string(k))
|
||||
}
|
||||
sort.Strings(keys)
|
||||
args := make([]driver.Value, len(keys)*4)
|
||||
for i, k := range keys {
|
||||
aggregateType := eventstore.AggregateType(k)
|
||||
for _, sequence := range sequences[aggregateType] {
|
||||
args[i*4] = projection
|
||||
args[i*4+1] = aggregateType
|
||||
args[i*4+2] = sequence.sequence
|
||||
args[i*4+3] = sequence.instanceID
|
||||
}
|
||||
}
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
matcher := ¤tSequenceMatcher{seq: sequences}
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\), \(\$4, \$5, \$6, NOW\(\)\)`).
|
||||
m.ExpectExec("UPSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\)`).
|
||||
WithArgs(
|
||||
projection,
|
||||
matcher,
|
||||
matcher,
|
||||
projection,
|
||||
matcher,
|
||||
matcher,
|
||||
args...,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -204,51 +215,27 @@ func expectUpdateTwoCurrentSequence(tableName, projection string, sequences curr
|
||||
}
|
||||
}
|
||||
|
||||
type currentSequenceMatcher struct {
|
||||
seq currentSequences
|
||||
currentAggregate eventstore.AggregateType
|
||||
}
|
||||
|
||||
func (m *currentSequenceMatcher) Match(value driver.Value) bool {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if m.currentAggregate != "" {
|
||||
log.Printf("expected sequence of %s but got next aggregate type %s", m.currentAggregate, value)
|
||||
return false
|
||||
}
|
||||
_, ok := m.seq[eventstore.AggregateType(v)]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
m.currentAggregate = eventstore.AggregateType(v)
|
||||
return true
|
||||
default:
|
||||
seq := m.seq[m.currentAggregate]
|
||||
m.currentAggregate = ""
|
||||
delete(m.seq, m.currentAggregate)
|
||||
return int64(seq) == value.(int64)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType string) func(sqlmock.Sqlmock) {
|
||||
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
|
||||
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(0, 0),
|
||||
@@ -256,17 +243,18 @@ func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64,
|
||||
}
|
||||
}
|
||||
|
||||
func expectLock(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlmock) {
|
||||
func expectLock(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`INSERT INTO `+lockTable+
|
||||
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
|
||||
` ON CONFLICT \(projection_name\)`+
|
||||
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
|
||||
` ON CONFLICT \(projection_name, instance_id\)`+
|
||||
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
projectionName,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -274,33 +262,35 @@ func expectLock(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlm
|
||||
}
|
||||
}
|
||||
|
||||
func expectLockNoRows(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlmock) {
|
||||
func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`INSERT INTO `+lockTable+
|
||||
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
|
||||
` ON CONFLICT \(projection_name\)`+
|
||||
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
|
||||
` ON CONFLICT \(projection_name, instance_id\)`+
|
||||
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
projectionName,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
}
|
||||
}
|
||||
|
||||
func expectLockErr(lockTable, workerName string, d time.Duration, err error) func(sqlmock.Sqlmock) {
|
||||
func expectLockErr(lockTable, workerName string, d time.Duration, instanceID string, err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`INSERT INTO `+lockTable+
|
||||
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
|
||||
` ON CONFLICT \(projection_name\)`+
|
||||
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
|
||||
` ON CONFLICT \(projection_name, instance_id\)`+
|
||||
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
projectionName,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
|
@@ -101,15 +101,34 @@ func (h *StatementHandler) SearchQuery() (*eventstore.SearchQueryBuilder, uint64
|
||||
|
||||
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
|
||||
for _, aggregateType := range h.aggregates {
|
||||
instances := make([]string, 0)
|
||||
for _, sequence := range sequences[aggregateType] {
|
||||
instances = appendToIgnoredInstances(instances, sequence.instanceID)
|
||||
queryBuilder.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
SequenceGreater(sequence.sequence).
|
||||
InstanceID(sequence.instanceID)
|
||||
}
|
||||
queryBuilder.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
SequenceGreater(sequences[aggregateType])
|
||||
SequenceGreater(0).
|
||||
ExcludedInstanceID(instances...)
|
||||
}
|
||||
|
||||
return queryBuilder, h.bulkLimit, nil
|
||||
}
|
||||
|
||||
func appendToIgnoredInstances(instances []string, id string) []string {
|
||||
for _, instance := range instances {
|
||||
if instance == id {
|
||||
return instances
|
||||
}
|
||||
}
|
||||
return append(instances, id)
|
||||
}
|
||||
|
||||
//Update implements handler.Update
|
||||
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
|
||||
tx, err := h.client.BeginTx(ctx, nil)
|
||||
@@ -127,7 +146,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
|
||||
// because there could be events between current sequence and a creation event
|
||||
// and we cannot check via stmt.PreviousSequence
|
||||
if stmts[0].PreviousSequence == 0 {
|
||||
previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, sequences, reduce)
|
||||
previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return stmts, err
|
||||
@@ -164,27 +183,25 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
|
||||
return unexecutedStmts, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) fetchPreviousStmts(
|
||||
ctx context.Context,
|
||||
stmtSeq uint64,
|
||||
sequences currentSequences,
|
||||
reduce handler.Reduce,
|
||||
) (previousStmts []*handler.Statement, err error) {
|
||||
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
|
||||
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
|
||||
queriesAdded := false
|
||||
for _, aggregateType := range h.aggregates {
|
||||
if stmtSeq <= sequences[aggregateType] {
|
||||
continue
|
||||
for _, sequence := range sequences[aggregateType] {
|
||||
if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID {
|
||||
continue
|
||||
}
|
||||
|
||||
query.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
SequenceGreater(sequence.sequence).
|
||||
SequenceLess(stmtSeq).
|
||||
InstanceID(sequence.instanceID)
|
||||
|
||||
queriesAdded = true
|
||||
}
|
||||
|
||||
query.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
SequenceGreater(sequences[aggregateType]).
|
||||
SequenceLess(stmtSeq)
|
||||
|
||||
queriesAdded = true
|
||||
}
|
||||
|
||||
if !queriesAdded {
|
||||
@@ -214,16 +231,19 @@ func (h *StatementHandler) executeStmts(
|
||||
|
||||
lastSuccessfulIdx := -1
|
||||
for i, stmt := range stmts {
|
||||
if stmt.Sequence <= sequences[stmt.AggregateType] {
|
||||
continue
|
||||
}
|
||||
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequences[stmt.AggregateType] {
|
||||
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
|
||||
break
|
||||
for _, sequence := range sequences[stmt.AggregateType] {
|
||||
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
||||
continue
|
||||
}
|
||||
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
||||
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
|
||||
break
|
||||
}
|
||||
}
|
||||
err := h.executeStmt(tx, stmt)
|
||||
if err == nil {
|
||||
sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i
|
||||
updateSequences(sequences, stmt)
|
||||
lastSuccessfulIdx = i
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -232,7 +252,9 @@ func (h *StatementHandler) executeStmts(
|
||||
break
|
||||
}
|
||||
|
||||
sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i
|
||||
updateSequences(sequences, stmt)
|
||||
lastSuccessfulIdx = i
|
||||
continue
|
||||
}
|
||||
return lastSuccessfulIdx
|
||||
}
|
||||
@@ -261,3 +283,16 @@ func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) erro
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
|
||||
for _, sequence := range sequences[stmt.AggregateType] {
|
||||
if sequence.instanceID == stmt.InstanceID {
|
||||
sequence.sequence = stmt.Sequence
|
||||
return
|
||||
}
|
||||
}
|
||||
sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{
|
||||
instanceID: stmt.InstanceID,
|
||||
sequence: stmt.Sequence,
|
||||
})
|
||||
}
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/caos/zitadel/internal/eventstore"
|
||||
"github.com/caos/zitadel/internal/eventstore/handler"
|
||||
@@ -97,13 +98,18 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
|
||||
return err == nil
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
|
||||
},
|
||||
SearchQueryBuilder: eventstore.
|
||||
NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes("testAgg").
|
||||
SequenceGreater(5).
|
||||
InstanceID("instanceID").
|
||||
Or().
|
||||
AggregateTypes("testAgg").
|
||||
SequenceGreater(0).
|
||||
ExcludedInstanceID("instanceID").
|
||||
Builder().
|
||||
Limit(5),
|
||||
},
|
||||
@@ -225,7 +231,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
|
||||
expectRollback(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -262,7 +268,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
|
||||
expectCommit(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -287,6 +293,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
aggregateType: "agg",
|
||||
sequence: 7,
|
||||
previousSequence: 5,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
[]handler.Column{
|
||||
{
|
||||
@@ -299,11 +306,11 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "agg"),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
|
||||
expectSavePoint(),
|
||||
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
|
||||
expectSavePointRelease(),
|
||||
expectUpdateCurrentSequenceNoRows("my_sequences", "my_projection", 7, "agg"),
|
||||
expectUpdateCurrentSequenceNoRows("my_sequences", "my_projection", 7, "agg", "instanceID"),
|
||||
expectRollback(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -328,6 +335,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
aggregateType: "agg",
|
||||
sequence: 7,
|
||||
previousSequence: 5,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
[]handler.Column{
|
||||
{
|
||||
@@ -340,11 +348,11 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "agg"),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
|
||||
expectSavePoint(),
|
||||
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
|
||||
expectSavePointRelease(),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "agg"),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "agg", "instanceID"),
|
||||
expectCommitErr(sql.ErrConnDone),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -368,14 +376,15 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
aggregateType: "testAgg",
|
||||
sequence: 7,
|
||||
previousSequence: 5,
|
||||
instanceID: "instanceID",
|
||||
}),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
|
||||
expectCommit(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -399,14 +408,15 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
aggregateType: "testAgg",
|
||||
sequence: 7,
|
||||
previousSequence: 0,
|
||||
instanceID: "instanceID",
|
||||
}),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
|
||||
expectCommit(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -423,6 +433,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
AggregateType: "testAgg",
|
||||
Sequence: 6,
|
||||
PreviousAggregateSequence: 5,
|
||||
InstanceID: "instanceID",
|
||||
},
|
||||
),
|
||||
),
|
||||
@@ -435,6 +446,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
aggregateType: "testAgg",
|
||||
sequence: 7,
|
||||
previousSequence: 0,
|
||||
instanceID: "instanceID",
|
||||
}),
|
||||
},
|
||||
reduce: testReduce(),
|
||||
@@ -442,8 +454,8 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
|
||||
expectCommit(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -537,7 +549,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
reduce: testReduce(),
|
||||
sequences: currentSequences{
|
||||
"testAgg": 5,
|
||||
"testAgg": []*instanceSequence{
|
||||
{sequence: 5},
|
||||
},
|
||||
},
|
||||
stmtSeq: 6,
|
||||
},
|
||||
@@ -560,7 +574,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
reduce: testReduce(),
|
||||
sequences: currentSequences{
|
||||
"testAgg": 5,
|
||||
"testAgg": []*instanceSequence{
|
||||
{sequence: 5},
|
||||
},
|
||||
},
|
||||
stmtSeq: 6,
|
||||
},
|
||||
@@ -582,7 +598,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
reduce: testReduce(),
|
||||
sequences: currentSequences{
|
||||
"testAgg": 5,
|
||||
"testAgg": []*instanceSequence{
|
||||
{sequence: 5},
|
||||
},
|
||||
},
|
||||
stmtSeq: 10,
|
||||
},
|
||||
@@ -626,7 +644,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
reduce: testReduceErr(errReduce),
|
||||
sequences: currentSequences{
|
||||
"testAgg": 5,
|
||||
"testAgg": []*instanceSequence{
|
||||
{sequence: 5},
|
||||
},
|
||||
},
|
||||
stmtSeq: 10,
|
||||
},
|
||||
@@ -667,7 +687,7 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
|
||||
}),
|
||||
aggregates: tt.fields.aggregates,
|
||||
}
|
||||
stmts, err := h.fetchPreviousStmts(tt.args.ctx, tt.args.stmtSeq, tt.args.sequences, tt.args.reduce)
|
||||
stmts, err := h.fetchPreviousStmts(tt.args.ctx, tt.args.stmtSeq, "", tt.args.sequences, tt.args.reduce)
|
||||
if !tt.want.isErr(err) {
|
||||
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
|
||||
return
|
||||
@@ -720,7 +740,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": 5,
|
||||
"agg": []*instanceSequence{
|
||||
{sequence: 5},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -762,7 +784,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": 2,
|
||||
"agg": []*instanceSequence{
|
||||
{sequence: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -824,7 +848,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": 2,
|
||||
"agg": []*instanceSequence{
|
||||
{sequence: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -891,7 +917,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": 2,
|
||||
"agg": []*instanceSequence{
|
||||
{sequence: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -979,7 +1007,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
|
||||
),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": 2,
|
||||
"agg": []*instanceSequence{
|
||||
{sequence: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -1309,9 +1339,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequenceNoRows("my_table", "my_projection"),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": 0,
|
||||
},
|
||||
sequences: currentSequences{},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -1331,9 +1359,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequenceScanErr("my_table", "my_projection"),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": 0,
|
||||
},
|
||||
sequences: currentSequences{},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -1351,10 +1377,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
return errors.Is(err, nil)
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequence("my_table", "my_projection", 5, "agg"),
|
||||
expectCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": 5,
|
||||
"agg": []*instanceSequence{
|
||||
{
|
||||
sequence: 5,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1404,9 +1435,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, aggregateType := range tt.fields.aggregates {
|
||||
if seq[aggregateType] != tt.want.sequences[aggregateType] {
|
||||
t.Errorf("unexpected sequence in aggregate type %s: want %d got %d", aggregateType, tt.want.sequences[aggregateType], seq[aggregateType])
|
||||
}
|
||||
assert.Equal(t, tt.want.sequences[aggregateType], seq[aggregateType])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1440,7 +1469,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
sequences: currentSequences{
|
||||
"agg": 5,
|
||||
"agg": []*instanceSequence{
|
||||
{
|
||||
sequence: 5,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -1448,7 +1482,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
|
||||
return errors.Is(err, sql.ErrConnDone)
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectUpdateCurrentSequenceErr("my_table", "my_projection", 5, sql.ErrConnDone, "agg"),
|
||||
expectUpdateCurrentSequenceErr("my_table", "my_projection", 5, sql.ErrConnDone, "agg", "instanceID"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1461,7 +1495,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
sequences: currentSequences{
|
||||
"agg": 5,
|
||||
"agg": []*instanceSequence{
|
||||
{
|
||||
sequence: 5,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -1469,7 +1508,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
|
||||
return errors.As(err, &errSeqNotUpdated)
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectUpdateCurrentSequenceNoRows("my_table", "my_projection", 5, "agg"),
|
||||
expectUpdateCurrentSequenceNoRows("my_table", "my_projection", 5, "agg", "instanceID"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1482,7 +1521,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
sequences: currentSequences{
|
||||
"agg": 5,
|
||||
"agg": []*instanceSequence{
|
||||
{
|
||||
sequence: 5,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -1490,7 +1534,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
|
||||
return err == nil
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectUpdateCurrentSequence("my_table", "my_projection", 5, "agg"),
|
||||
expectUpdateCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1503,8 +1547,18 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
sequences: currentSequences{
|
||||
"agg": 5,
|
||||
"agg2": 6,
|
||||
"agg": []*instanceSequence{
|
||||
{
|
||||
sequence: 5,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
"agg2": []*instanceSequence{
|
||||
{
|
||||
sequence: 6,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
@@ -1513,9 +1567,19 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectUpdateTwoCurrentSequence("my_table", "my_projection", currentSequences{
|
||||
"agg": 5,
|
||||
"agg2": 6},
|
||||
),
|
||||
"agg": []*instanceSequence{
|
||||
{
|
||||
sequence: 5,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
"agg2": []*instanceSequence{
|
||||
{
|
||||
sequence: 6,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@@ -15,15 +15,15 @@ import (
|
||||
|
||||
const (
|
||||
lockStmtFormat = "INSERT INTO %[1]s" +
|
||||
" (locker_id, locked_until, projection_name) VALUES ($1, now()+$2::INTERVAL, $3)" +
|
||||
" ON CONFLICT (projection_name)" +
|
||||
" (locker_id, locked_until, projection_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
|
||||
" ON CONFLICT (projection_name, instance_id)" +
|
||||
" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
|
||||
" WHERE %[1]s.projection_name = $3 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
|
||||
" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = $4 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
|
||||
)
|
||||
|
||||
type Locker interface {
|
||||
Lock(ctx context.Context, lockDuration time.Duration) <-chan error
|
||||
Unlock() error
|
||||
Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error
|
||||
Unlock(instanceID string) error
|
||||
}
|
||||
|
||||
type locker struct {
|
||||
@@ -47,18 +47,18 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration) <-chan error {
|
||||
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error {
|
||||
errs := make(chan error)
|
||||
go h.handleLock(ctx, errs, lockDuration)
|
||||
go h.handleLock(ctx, errs, lockDuration, instanceID)
|
||||
return errs
|
||||
}
|
||||
|
||||
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration) {
|
||||
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceID string) {
|
||||
renewLock := time.NewTimer(0)
|
||||
for {
|
||||
select {
|
||||
case <-renewLock.C:
|
||||
errs <- h.renewLock(ctx, lockDuration)
|
||||
errs <- h.renewLock(ctx, lockDuration, instanceID)
|
||||
//refresh the lock 500ms before it times out. 500ms should be enough for one transaction
|
||||
renewLock.Reset(lockDuration - (500 * time.Millisecond))
|
||||
case <-ctx.Done():
|
||||
@@ -69,9 +69,9 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t
|
||||
}
|
||||
}
|
||||
|
||||
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) error {
|
||||
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error {
|
||||
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
|
||||
res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName)
|
||||
res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
|
||||
}
|
||||
@@ -83,8 +83,8 @@ func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *locker) Unlock() error {
|
||||
_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName)
|
||||
func (h *locker) Unlock(instanceID string) error {
|
||||
_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName, instanceID)
|
||||
if err != nil {
|
||||
return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed")
|
||||
}
|
||||
|
@@ -32,6 +32,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
lockDuration time.Duration
|
||||
ctx context.Context
|
||||
errMock *errsMock
|
||||
instanceID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -42,9 +43,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
name: "lock fails",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 2),
|
||||
expectLock(lockTable, workerName, 2),
|
||||
expectLockErr(lockTable, workerName, 2, errLock),
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLockErr(lockTable, workerName, 2, "instanceID", errLock),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
@@ -55,14 +56,15 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
successfulIters: 2,
|
||||
shouldErr: true,
|
||||
},
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 2),
|
||||
expectLock(lockTable, workerName, 2),
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
@@ -72,6 +74,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
errs: make(chan error),
|
||||
successfulIters: 2,
|
||||
},
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -96,7 +99,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
|
||||
go tt.args.errMock.handleErrs(t, cancel)
|
||||
|
||||
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration)
|
||||
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceID)
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
@@ -115,6 +118,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
}
|
||||
type args struct {
|
||||
lockDuration time.Duration
|
||||
instanceID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -125,7 +129,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
name: "lock fails",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockErr(lockTable, workerName, 1, sql.ErrTxDone),
|
||||
expectLockErr(lockTable, workerName, 1, "instanceID", sql.ErrTxDone),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrTxDone)
|
||||
@@ -133,13 +137,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 1 * time.Second,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "lock no rows",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockNoRows(lockTable, workerName, 2),
|
||||
expectLockNoRows(lockTable, workerName, 2, "instanceID"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.As(err, &renewNoRowsAffectedErr)
|
||||
@@ -147,13 +152,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 2 * time.Second,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 3),
|
||||
expectLock(lockTable, workerName, 3, "instanceID"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, nil)
|
||||
@@ -161,6 +167,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 3 * time.Second,
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -181,7 +188,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
expectation(mock)
|
||||
}
|
||||
|
||||
err = h.renewLock(context.Background(), tt.args.lockDuration)
|
||||
err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceID)
|
||||
if !tt.want.isErr(err) {
|
||||
t.Errorf("unexpected error = %v", err)
|
||||
}
|
||||
@@ -199,15 +206,22 @@ func TestStatementHandler_Unlock(t *testing.T) {
|
||||
expectations []mockExpectation
|
||||
isErr func(err error) bool
|
||||
}
|
||||
type args struct {
|
||||
instanceID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unlock fails",
|
||||
args: args{
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockErr(lockTable, workerName, 0, sql.ErrTxDone),
|
||||
expectLockErr(lockTable, workerName, 0, "instanceID", sql.ErrTxDone),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrTxDone)
|
||||
@@ -216,9 +230,12 @@ func TestStatementHandler_Unlock(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
instanceID: "instanceID",
|
||||
},
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 0),
|
||||
expectLock(lockTable, workerName, 0, "instanceID"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, nil)
|
||||
@@ -243,7 +260,7 @@ func TestStatementHandler_Unlock(t *testing.T) {
|
||||
expectation(mock)
|
||||
}
|
||||
|
||||
err = h.Unlock()
|
||||
err = h.Unlock(tt.args.instanceID)
|
||||
if !tt.want.isErr(err) {
|
||||
t.Errorf("unexpected error = %v", err)
|
||||
}
|
||||
|
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/caos/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const systemID = "system"
|
||||
|
||||
type ProjectionHandlerConfig struct {
|
||||
HandlerConfig
|
||||
ProjectionName string
|
||||
@@ -27,10 +29,10 @@ type Update func(context.Context, []*Statement, Reduce) (unexecutedStmts []*Stat
|
||||
type Reduce func(eventstore.Event) (*Statement, error)
|
||||
|
||||
//Lock is used for mutex handling if needed on the projection
|
||||
type Lock func(context.Context, time.Duration) <-chan error
|
||||
type Lock func(context.Context, time.Duration, string) <-chan error
|
||||
|
||||
//Unlock releases the mutex of the projection
|
||||
type Unlock func() error
|
||||
type Unlock func(string) error
|
||||
|
||||
//SearchQuery generates the search query to lookup for events
|
||||
type SearchQuery func() (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
|
||||
@@ -183,7 +185,7 @@ func (h *ProjectionHandler) bulk(
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
errs := lock(ctx, h.requeueAfter)
|
||||
errs := lock(ctx, h.requeueAfter, systemID)
|
||||
//wait until projection is locked
|
||||
if err, ok := <-errs; err != nil || !ok {
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("initial lock failed")
|
||||
@@ -194,7 +196,7 @@ func (h *ProjectionHandler) bulk(
|
||||
execErr := executeBulk(ctx)
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(execErr).Warn("unable to execute")
|
||||
|
||||
unlockErr := unlock()
|
||||
unlockErr := unlock(systemID)
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")
|
||||
|
||||
if execErr != nil {
|
||||
|
@@ -912,7 +912,7 @@ type lockMock struct {
|
||||
}
|
||||
|
||||
func (m *lockMock) lock() Lock {
|
||||
return func(ctx context.Context, _ time.Duration) <-chan error {
|
||||
return func(ctx context.Context, _ time.Duration, _ string) <-chan error {
|
||||
m.callCount++
|
||||
errs := make(chan error)
|
||||
go func() {
|
||||
@@ -955,7 +955,7 @@ type unlockMock struct {
|
||||
}
|
||||
|
||||
func (m *unlockMock) unlock() Unlock {
|
||||
return func() error {
|
||||
return func(instanceID string) error {
|
||||
m.callCount++
|
||||
return m.err
|
||||
}
|
||||
|
Reference in New Issue
Block a user