feat: handle instanceID in projections (#3442)
* feat: handle instanceID in projections
* rename functions
* fix key lock
* fix import
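In short, the commit makes projection bookkeeping multi-tenant aware: instead of one current sequence per aggregate type, the handler now tracks one per (aggregate type, instance). A minimal, self-contained sketch of the new shape (the instanceSequence type and the update logic mirror the diff below; eventstore.AggregateType is replaced by a local alias so the snippet stands alone):

```go
package example

// AggregateType stands in for eventstore.AggregateType.
type AggregateType string

// instanceSequence pairs an instance with its last processed sequence,
// as introduced in the diff below.
type instanceSequence struct {
	instanceID string
	sequence   uint64
}

// before this change: map[AggregateType]uint64
type currentSequences map[AggregateType][]*instanceSequence

// updateSequences mirrors the new helper: bump the entry for the statement's
// instance, or start tracking the instance the first time it is seen.
func updateSequences(sequences currentSequences, aggregate AggregateType, instanceID string, seq uint64) {
	for _, s := range sequences[aggregate] {
		if s.instanceID == instanceID {
			s.sequence = seq
			return
		}
	}
	sequences[aggregate] = append(sequences[aggregate], &instanceSequence{instanceID: instanceID, sequence: seq})
}
```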
@@ -10,11 +10,16 @@ import (
)
const (
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type FROM %s WHERE projection_name = $1 FOR UPDATE`
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, timestamp) VALUES `
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 FOR UPDATE`
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
)
type currentSequences map[eventstore.AggregateType]uint64
type currentSequences map[eventstore.AggregateType][]*instanceSequence
type instanceSequence struct {
instanceID string
sequence uint64
}
func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
rows, err := query(h.currentSequenceStmt, h.ProjectionName)

@@ -29,14 +34,18 @@ func (h *StatementHandler) currentSequences(query func(string, ...interface{}) (
var (
aggregateType eventstore.AggregateType
sequence uint64
instanceID string
)
err = rows.Scan(&sequence, &aggregateType)
err = rows.Scan(&sequence, &aggregateType, &instanceID)
if err != nil {
return nil, errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
}
sequences[aggregateType] = sequence
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
sequence: sequence,
instanceID: instanceID,
})
}
if err = rows.Close(); err != nil {

@@ -54,10 +63,12 @@ func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentS
valueQueries := make([]string, 0, len(sequences))
valueCounter := 0
values := make([]interface{}, 0, len(sequences)*3)
for aggregate, sequence := range sequences {
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", NOW())")
valueCounter += 3
values = append(values, h.ProjectionName, aggregate, sequence)
for aggregate, instanceSequence := range sequences {
for _, sequence := range instanceSequence {
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", $"+strconv.Itoa(valueCounter+4)+", NOW())")
valueCounter += 4
values = append(values, h.ProjectionName, aggregate, sequence.sequence, sequence.instanceID)
}
}
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", "), values...)
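Because every tracked row now carries an instance_id, the UPSERT grows from three to four bound parameters per row. A small sketch of the placeholder layout the rewritten loop produces (the table name projections.current_sequences is only illustrative; the column list is the one from the statement format above):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

func main() {
	// Two (projection, aggregate, sequence, instance) rows → four placeholders each.
	valueQueries := make([]string, 0, 2)
	valueCounter := 0
	for i := 0; i < 2; i++ {
		valueQueries = append(valueQueries,
			"($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", $"+strconv.Itoa(valueCounter+4)+", NOW())")
		valueCounter += 4
	}
	fmt.Println("UPSERT INTO projections.current_sequences (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES " +
		strings.Join(valueQueries, ", "))
	// ... VALUES ($1, $2, $3, $4, NOW()), ($5, $6, $7, $8, NOW())
}
```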
@@ -3,7 +3,7 @@ package crdb
import (
"database/sql"
"database/sql/driver"
"log"
"sort"
"strings"
"time"

@@ -123,22 +123,22 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) {
}
}
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
WithArgs(
projection,
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}).
AddRow(seq, aggregateType),
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
AddRow(seq, aggregateType, instanceID),
)
}
}
func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
WithArgs(
projection,
).

@@ -148,37 +148,38 @@ func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlm
func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
WithArgs(
projection,
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}),
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
)
}
}
func expectCurrentSequenceScanErr(tableName, projection string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
WithArgs(
projection,
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type"}).
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
RowError(0, sql.ErrTxDone).
AddRow(0, "agg"),
AddRow(0, "agg", "instanceID"),
)
}
}
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnResult(
sqlmock.NewResult(1, 1),

@@ -187,16 +188,26 @@ func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggre
}
func expectUpdateTwoCurrentSequence(tableName, projection string, sequences currentSequences) func(sqlmock.Sqlmock) {
//sort them so the args will always have the same order
keys := make([]string, 0, len(sequences))
for k := range sequences {
keys = append(keys, string(k))
}
sort.Strings(keys)
args := make([]driver.Value, len(keys)*4)
for i, k := range keys {
aggregateType := eventstore.AggregateType(k)
for _, sequence := range sequences[aggregateType] {
args[i*4] = projection
args[i*4+1] = aggregateType
args[i*4+2] = sequence.sequence
args[i*4+3] = sequence.instanceID
}
}
return func(m sqlmock.Sqlmock) {
matcher := &currentSequenceMatcher{seq: sequences}
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\), \(\$4, \$5, \$6, NOW\(\)\)`).
m.ExpectExec("UPSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\)`).
WithArgs(
projection,
matcher,
matcher,
projection,
matcher,
matcher,
args...,
).
WillReturnResult(
sqlmock.NewResult(1, 1),

@@ -204,51 +215,27 @@ func expectUpdateTwoCurrentSequence(tableName, projection string, sequences curr
}
}
type currentSequenceMatcher struct {
seq currentSequences
currentAggregate eventstore.AggregateType
}
func (m *currentSequenceMatcher) Match(value driver.Value) bool {
switch v := value.(type) {
case string:
if m.currentAggregate != "" {
log.Printf("expected sequence of %s but got next aggregate type %s", m.currentAggregate, value)
return false
}
_, ok := m.seq[eventstore.AggregateType(v)]
if !ok {
return false
}
m.currentAggregate = eventstore.AggregateType(v)
return true
default:
seq := m.seq[m.currentAggregate]
m.currentAggregate = ""
delete(m.seq, m.currentAggregate)
return int64(seq) == value.(int64)
}
}
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType string) func(sqlmock.Sqlmock) {
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnError(err)
}
}
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType string) func(sqlmock.Sqlmock) {
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, timestamp\) VALUES \(\$1, \$2, \$3, NOW\(\)\)`).
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
WithArgs(
projection,
aggregateType,
seq,
instanceID,
).
WillReturnResult(
sqlmock.NewResult(0, 0),

@@ -256,17 +243,18 @@ func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64,
}
}
func expectLock(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlmock) {
func expectLock(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
` ON CONFLICT \(projection_name\)`+
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
).
WillReturnResult(
sqlmock.NewResult(1, 1),

@@ -274,33 +262,35 @@ func expectLock(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlm
}
}
func expectLockNoRows(lockTable, workerName string, d time.Duration) func(sqlmock.Sqlmock) {
func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
` ON CONFLICT \(projection_name\)`+
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
).
WillReturnResult(driver.ResultNoRows)
}
}
func expectLockErr(lockTable, workerName string, d time.Duration, err error) func(sqlmock.Sqlmock) {
func expectLockErr(lockTable, workerName string, d time.Duration, instanceID string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
` \(locker_id, locked_until, projection_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\)`+
` ON CONFLICT \(projection_name\)`+
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
).
WillReturnError(err)
}
@@ -101,15 +101,34 @@ func (h *StatementHandler) SearchQuery() (*eventstore.SearchQueryBuilder, uint64
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
for _, aggregateType := range h.aggregates {
instances := make([]string, 0)
for _, sequence := range sequences[aggregateType] {
instances = appendToIgnoredInstances(instances, sequence.instanceID)
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequence.sequence).
InstanceID(sequence.instanceID)
}
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequences[aggregateType])
SequenceGreater(0).
ExcludedInstanceID(instances...)
}
return queryBuilder, h.bulkLimit, nil
}
func appendToIgnoredInstances(instances []string, id string) []string {
for _, instance := range instances {
if instance == id {
return instances
}
}
return append(instances, id)
}
//Update implements handler.Update
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
tx, err := h.client.BeginTx(ctx, nil)

@@ -127,7 +146,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
// because there could be events between current sequence and a creation event
// and we cannot check via stmt.PreviousSequence
if stmts[0].PreviousSequence == 0 {
previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, sequences, reduce)
previousStmts, err := h.fetchPreviousStmts(ctx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
if err != nil {
tx.Rollback()
return stmts, err

@@ -164,27 +183,25 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
return unexecutedStmts, nil
}
func (h *StatementHandler) fetchPreviousStmts(
ctx context.Context,
stmtSeq uint64,
sequences currentSequences,
reduce handler.Reduce,
) (previousStmts []*handler.Statement, err error) {
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
queriesAdded := false
for _, aggregateType := range h.aggregates {
if stmtSeq <= sequences[aggregateType] {
continue
for _, sequence := range sequences[aggregateType] {
if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID {
continue
}
query.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequence.sequence).
SequenceLess(stmtSeq).
InstanceID(sequence.instanceID)
queriesAdded = true
}
query.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequences[aggregateType]).
SequenceLess(stmtSeq)
queriesAdded = true
}
if !queriesAdded {

@@ -214,16 +231,19 @@ func (h *StatementHandler) executeStmts(
lastSuccessfulIdx := -1
for i, stmt := range stmts {
if stmt.Sequence <= sequences[stmt.AggregateType] {
continue
}
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequences[stmt.AggregateType] {
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
break
for _, sequence := range sequences[stmt.AggregateType] {
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
continue
}
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequences[stmt.AggregateType]).Warn("sequences do not match")
break
}
}
err := h.executeStmt(tx, stmt)
if err == nil {
sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i
updateSequences(sequences, stmt)
lastSuccessfulIdx = i
continue
}

@@ -232,7 +252,9 @@ func (h *StatementHandler) executeStmts(
break
}
sequences[stmt.AggregateType], lastSuccessfulIdx = stmt.Sequence, i
updateSequences(sequences, stmt)
lastSuccessfulIdx = i
continue
}
return lastSuccessfulIdx
}

@@ -261,3 +283,16 @@ func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) erro
}
return nil
}
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
for _, sequence := range sequences[stmt.AggregateType] {
if sequence.instanceID == stmt.InstanceID {
sequence.sequence = stmt.Sequence
return
}
}
sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{
instanceID: stmt.InstanceID,
sequence: stmt.Sequence,
})
}
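Taken together, SearchQuery now emits one sub-query per instance the projection has already seen for an aggregate type, plus a catch-all for instances it has not, instead of a single SequenceGreater per aggregate type. A hedged sketch of the equivalent builder chain, using only calls that appear in this diff (sequence numbers and instance IDs are illustrative):

```go
package example

import "github.com/caos/zitadel/internal/eventstore"

// searchQuerySketch shows what the SearchQuery loop builds when the projection
// has seen instance "A" at sequence 5 and instance "B" at sequence 9 for
// aggregate type "testAgg".
func searchQuerySketch(bulkLimit uint64) *eventstore.SearchQueryBuilder {
	return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		Limit(bulkLimit).
		AddQuery().AggregateTypes("testAgg").SequenceGreater(5).InstanceID("A").
		Or().AggregateTypes("testAgg").SequenceGreater(9).InstanceID("B").
		Or().AggregateTypes("testAgg").SequenceGreater(0).ExcludedInstanceID("A", "B").
		Builder()
}
```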
@@ -9,6 +9,7 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/handler"

@@ -97,13 +98,18 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
return err == nil
},
expectations: []mockExpectation{
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
},
SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes("testAgg").
SequenceGreater(5).
InstanceID("instanceID").
Or().
AggregateTypes("testAgg").
SequenceGreater(0).
ExcludedInstanceID("instanceID").
Builder().
Limit(5),
},

@@ -225,7 +231,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectRollback(),
},
isErr: func(err error) bool {

@@ -262,7 +268,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {

@@ -287,6 +293,7 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "agg",
sequence: 7,
previousSequence: 5,
instanceID: "instanceID",
},
[]handler.Column{
{

@@ -299,11 +306,11 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(),
expectUpdateCurrentSequenceNoRows("my_sequences", "my_projection", 7, "agg"),
expectUpdateCurrentSequenceNoRows("my_sequences", "my_projection", 7, "agg", "instanceID"),
expectRollback(),
},
isErr: func(err error) bool {

@@ -328,6 +335,7 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "agg",
sequence: 7,
previousSequence: 5,
instanceID: "instanceID",
},
[]handler.Column{
{

@@ -340,11 +348,11 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "agg"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "agg", "instanceID"),
expectCommitErr(sql.ErrConnDone),
},
isErr: func(err error) bool {

@@ -368,14 +376,15 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "testAgg",
sequence: 7,
previousSequence: 5,
instanceID: "instanceID",
}),
},
},
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {

@@ -399,14 +408,15 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "testAgg",
sequence: 7,
previousSequence: 0,
instanceID: "instanceID",
}),
},
},
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {

@@ -423,6 +433,7 @@ func TestStatementHandler_Update(t *testing.T) {
AggregateType: "testAgg",
Sequence: 6,
PreviousAggregateSequence: 5,
InstanceID: "instanceID",
},
),
),

@@ -435,6 +446,7 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "testAgg",
sequence: 7,
previousSequence: 0,
instanceID: "instanceID",
}),
},
reduce: testReduce(),

@@ -442,8 +454,8 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {

@@ -537,7 +549,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
ctx: context.Background(),
reduce: testReduce(),
sequences: currentSequences{
"testAgg": 5,
"testAgg": []*instanceSequence{
{sequence: 5},
},
},
stmtSeq: 6,
},

@@ -560,7 +574,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
ctx: context.Background(),
reduce: testReduce(),
sequences: currentSequences{
"testAgg": 5,
"testAgg": []*instanceSequence{
{sequence: 5},
},
},
stmtSeq: 6,
},

@@ -582,7 +598,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
ctx: context.Background(),
reduce: testReduce(),
sequences: currentSequences{
"testAgg": 5,
"testAgg": []*instanceSequence{
{sequence: 5},
},
},
stmtSeq: 10,
},

@@ -626,7 +644,9 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
ctx: context.Background(),
reduce: testReduceErr(errReduce),
sequences: currentSequences{
"testAgg": 5,
"testAgg": []*instanceSequence{
{sequence: 5},
},
},
stmtSeq: 10,
},

@@ -667,7 +687,7 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
}),
aggregates: tt.fields.aggregates,
}
stmts, err := h.fetchPreviousStmts(tt.args.ctx, tt.args.stmtSeq, tt.args.sequences, tt.args.reduce)
stmts, err := h.fetchPreviousStmts(tt.args.ctx, tt.args.stmtSeq, "", tt.args.sequences, tt.args.reduce)
if !tt.want.isErr(err) {
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
return

@@ -720,7 +740,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
}),
},
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{sequence: 5},
},
},
},
want: want{

@@ -762,7 +784,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
}),
},
sequences: currentSequences{
"agg": 2,
"agg": []*instanceSequence{
{sequence: 2},
},
},
},
want: want{

@@ -824,7 +848,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
}),
},
sequences: currentSequences{
"agg": 2,
"agg": []*instanceSequence{
{sequence: 2},
},
},
},
want: want{

@@ -891,7 +917,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
}),
},
sequences: currentSequences{
"agg": 2,
"agg": []*instanceSequence{
{sequence: 2},
},
},
},
want: want{

@@ -979,7 +1007,9 @@ func TestStatementHandler_executeStmts(t *testing.T) {
),
},
sequences: currentSequences{
"agg": 2,
"agg": []*instanceSequence{
{sequence: 2},
},
},
},
want: want{

@@ -1309,9 +1339,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
expectations: []mockExpectation{
expectCurrentSequenceNoRows("my_table", "my_projection"),
},
sequences: currentSequences{
"agg": 0,
},
sequences: currentSequences{},
},
},
{

@@ -1331,9 +1359,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
expectations: []mockExpectation{
expectCurrentSequenceScanErr("my_table", "my_projection"),
},
sequences: currentSequences{
"agg": 0,
},
sequences: currentSequences{},
},
},
{

@@ -1351,10 +1377,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
return errors.Is(err, nil)
},
expectations: []mockExpectation{
expectCurrentSequence("my_table", "my_projection", 5, "agg"),
expectCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
},
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
},
},
},

@@ -1404,9 +1435,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
}
for _, aggregateType := range tt.fields.aggregates {
if seq[aggregateType] != tt.want.sequences[aggregateType] {
t.Errorf("unexpected sequence in aggregate type %s: want %d got %d", aggregateType, tt.want.sequences[aggregateType], seq[aggregateType])
}
assert.Equal(t, tt.want.sequences[aggregateType], seq[aggregateType])
}
})
}

@@ -1440,7 +1469,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
args: args{
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
},
},
want: want{

@@ -1448,7 +1482,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
return errors.Is(err, sql.ErrConnDone)
},
expectations: []mockExpectation{
expectUpdateCurrentSequenceErr("my_table", "my_projection", 5, sql.ErrConnDone, "agg"),
expectUpdateCurrentSequenceErr("my_table", "my_projection", 5, sql.ErrConnDone, "agg", "instanceID"),
},
},
},

@@ -1461,7 +1495,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
args: args{
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
},
},
want: want{

@@ -1469,7 +1508,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
return errors.As(err, &errSeqNotUpdated)
},
expectations: []mockExpectation{
expectUpdateCurrentSequenceNoRows("my_table", "my_projection", 5, "agg"),
expectUpdateCurrentSequenceNoRows("my_table", "my_projection", 5, "agg", "instanceID"),
},
},
},

@@ -1482,7 +1521,12 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
args: args{
sequences: currentSequences{
"agg": 5,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
},
},
want: want{

@@ -1490,7 +1534,7 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
return err == nil
},
expectations: []mockExpectation{
expectUpdateCurrentSequence("my_table", "my_projection", 5, "agg"),
expectUpdateCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
},
},
},

@@ -1503,8 +1547,18 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
args: args{
sequences: currentSequences{
"agg": 5,
"agg2": 6,
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
"agg2": []*instanceSequence{
{
sequence: 6,
instanceID: "instanceID",
},
},
},
},
want: want{

@@ -1513,9 +1567,19 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
},
expectations: []mockExpectation{
expectUpdateTwoCurrentSequence("my_table", "my_projection", currentSequences{
"agg": 5,
"agg2": 6},
),
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID",
},
},
"agg2": []*instanceSequence{
{
sequence: 6,
instanceID: "instanceID",
},
},
}),
},
},
},
@@ -15,15 +15,15 @@ import (
const (
lockStmtFormat = "INSERT INTO %[1]s" +
" (locker_id, locked_until, projection_name) VALUES ($1, now()+$2::INTERVAL, $3)" +
" ON CONFLICT (projection_name)" +
" (locker_id, locked_until, projection_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
" ON CONFLICT (projection_name, instance_id)" +
" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
" WHERE %[1]s.projection_name = $3 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = $4 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
)
type Locker interface {
Lock(ctx context.Context, lockDuration time.Duration) <-chan error
Unlock() error
Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error
Unlock(instanceID string) error
}
type locker struct {

@@ -47,18 +47,18 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
}
}
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration) <-chan error {
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error {
errs := make(chan error)
go h.handleLock(ctx, errs, lockDuration)
go h.handleLock(ctx, errs, lockDuration, instanceID)
return errs
}
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration) {
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceID string) {
renewLock := time.NewTimer(0)
for {
select {
case <-renewLock.C:
errs <- h.renewLock(ctx, lockDuration)
errs <- h.renewLock(ctx, lockDuration, instanceID)
//refresh the lock 500ms before it times out. 500ms should be enough for one transaction
renewLock.Reset(lockDuration - (500 * time.Millisecond))
case <-ctx.Done():

@@ -69,9 +69,9 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t
}
}
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) error {
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error {
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName)
res, err := h.client.Exec(h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
if err != nil {
return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
}

@@ -83,8 +83,8 @@ func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration) erro
return nil
}
func (h *locker) Unlock() error {
_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName)
func (h *locker) Unlock(instanceID string) error {
_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName, instanceID)
if err != nil {
return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed")
}
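With the lock row now keyed on (projection_name, instance_id), callers pass the instance they want to lock. A short usage sketch of the changed Locker interface (the import path and the projections.locks table name are assumptions; "system" mirrors the systemID constant the projection handler passes further down):

```go
package example

import (
	"context"
	"database/sql"
	"time"

	"github.com/caos/zitadel/internal/eventstore/handler/crdb"
)

func runLocked(ctx context.Context, client *sql.DB) error {
	locker := crdb.NewLocker(client, "projections.locks", "my_projection")
	// Lock keeps renewing the per-instance lock row in the background.
	errs := locker.Lock(ctx, 10*time.Second, "system")
	if err, ok := <-errs; err != nil || !ok {
		return err
	}
	defer locker.Unlock("system")
	// ... run the projection while the lock is held ...
	return nil
}
```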
@@ -32,6 +32,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
lockDuration time.Duration
ctx context.Context
errMock *errsMock
instanceID string
}
tests := []struct {
name string

@@ -42,9 +43,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
name: "lock fails",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 2),
expectLock(lockTable, workerName, 2),
expectLockErr(lockTable, workerName, 2, errLock),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLockErr(lockTable, workerName, 2, "instanceID", errLock),
},
},
args: args{

@@ -55,14 +56,15 @@ func TestStatementHandler_handleLock(t *testing.T) {
successfulIters: 2,
shouldErr: true,
},
instanceID: "instanceID",
},
},
{
name: "success",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 2),
expectLock(lockTable, workerName, 2),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2, "instanceID"),
},
},
args: args{

@@ -72,6 +74,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
errs: make(chan error),
successfulIters: 2,
},
instanceID: "instanceID",
},
},
}

@@ -96,7 +99,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
go tt.args.errMock.handleErrs(t, cancel)
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration)
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceID)
<-ctx.Done()

@@ -115,6 +118,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
}
type args struct {
lockDuration time.Duration
instanceID string
}
tests := []struct {
name string

@@ -125,7 +129,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
name: "lock fails",
want: want{
expectations: []mockExpectation{
expectLockErr(lockTable, workerName, 1, sql.ErrTxDone),
expectLockErr(lockTable, workerName, 1, "instanceID", sql.ErrTxDone),
},
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)

@@ -133,13 +137,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 1 * time.Second,
instanceID: "instanceID",
},
},
{
name: "lock no rows",
want: want{
expectations: []mockExpectation{
expectLockNoRows(lockTable, workerName, 2),
expectLockNoRows(lockTable, workerName, 2, "instanceID"),
},
isErr: func(err error) bool {
return errors.As(err, &renewNoRowsAffectedErr)

@@ -147,13 +152,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 2 * time.Second,
instanceID: "instanceID",
},
},
{
name: "success",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 3),
expectLock(lockTable, workerName, 3, "instanceID"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)

@@ -161,6 +167,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 3 * time.Second,
instanceID: "instanceID",
},
},
}

@@ -181,7 +188,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
expectation(mock)
}
err = h.renewLock(context.Background(), tt.args.lockDuration)
err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceID)
if !tt.want.isErr(err) {
t.Errorf("unexpected error = %v", err)
}

@@ -199,15 +206,22 @@ func TestStatementHandler_Unlock(t *testing.T) {
expectations []mockExpectation
isErr func(err error) bool
}
type args struct {
instanceID string
}
tests := []struct {
name string
args args
want want
}{
{
name: "unlock fails",
args: args{
instanceID: "instanceID",
},
want: want{
expectations: []mockExpectation{
expectLockErr(lockTable, workerName, 0, sql.ErrTxDone),
expectLockErr(lockTable, workerName, 0, "instanceID", sql.ErrTxDone),
},
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)

@@ -216,9 +230,12 @@ func TestStatementHandler_Unlock(t *testing.T) {
},
{
name: "success",
args: args{
instanceID: "instanceID",
},
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 0),
expectLock(lockTable, workerName, 0, "instanceID"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)

@@ -243,7 +260,7 @@ func TestStatementHandler_Unlock(t *testing.T) {
expectation(mock)
}
err = h.Unlock()
err = h.Unlock(tt.args.instanceID)
if !tt.want.isErr(err) {
t.Errorf("unexpected error = %v", err)
}
@@ -12,6 +12,8 @@ import (
"github.com/caos/zitadel/internal/eventstore"
)
const systemID = "system"
type ProjectionHandlerConfig struct {
HandlerConfig
ProjectionName string

@@ -27,10 +29,10 @@ type Update func(context.Context, []*Statement, Reduce) (unexecutedStmts []*Stat
type Reduce func(eventstore.Event) (*Statement, error)
//Lock is used for mutex handling if needed on the projection
type Lock func(context.Context, time.Duration) <-chan error
type Lock func(context.Context, time.Duration, string) <-chan error
//Unlock releases the mutex of the projection
type Unlock func() error
type Unlock func(string) error
//SearchQuery generates the search query to lookup for events
type SearchQuery func() (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)

@@ -183,7 +185,7 @@ func (h *ProjectionHandler) bulk(
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errs := lock(ctx, h.requeueAfter)
errs := lock(ctx, h.requeueAfter, systemID)
//wait until projection is locked
if err, ok := <-errs; err != nil || !ok {
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("initial lock failed")

@@ -194,7 +196,7 @@ func (h *ProjectionHandler) bulk(
execErr := executeBulk(ctx)
logging.WithFields("projection", h.ProjectionName).OnError(execErr).Warn("unable to execute")
unlockErr := unlock()
unlockErr := unlock(systemID)
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")
if execErr != nil {
@@ -912,7 +912,7 @@ type lockMock struct {
}
func (m *lockMock) lock() Lock {
return func(ctx context.Context, _ time.Duration) <-chan error {
return func(ctx context.Context, _ time.Duration, _ string) <-chan error {
m.callCount++
errs := make(chan error)
go func() {

@@ -955,7 +955,7 @@ type unlockMock struct {
}
func (m *unlockMock) unlock() Unlock {
return func() error {
return func(instanceID string) error {
m.callCount++
return m.err
}
@@ -50,6 +50,8 @@ const (
OperationIn
//OperationJSONContains checks if a stored value matches the given json
OperationJSONContains
//OperationNotIn checks if a stored value does not match one of the passed value list
OperationNotIn
operationCount
)
@@ -288,8 +288,11 @@ func (db *CRDB) columnName(col repository.Field) string {
}
func (db *CRDB) conditionFormat(operation repository.Operation) string {
if operation == repository.OperationIn {
switch operation {
case repository.OperationIn:
return "%s %s ANY(?)"
case repository.OperationNotIn:
return "%s %s ALL(?)"
}
return "%s %s ?"
}

@@ -304,6 +307,8 @@ func (db *CRDB) operation(operation repository.Operation) string {
return "<"
case repository.OperationJSONContains:
return "@>"
case repository.OperationNotIn:
return "<>"
}
return ""
}
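The new OperationNotIn renders as `<>` combined with `ALL` because `x <> ALL(values)` is how Postgres/CockroachDB express "x is not one of values" (just as OperationIn uses `= ANY`). A tiny standalone sketch of that formatting path (the two helpers are simplified copies of the functions above):

```go
package main

import "fmt"

type operation int

const (
	operationIn operation = iota
	operationNotIn
)

// conditionFormat mirrors (*CRDB).conditionFormat above.
func conditionFormat(op operation) string {
	switch op {
	case operationIn:
		return "%s %s ANY(?)"
	case operationNotIn:
		return "%s %s ALL(?)"
	}
	return "%s %s ?"
}

// operator mirrors the relevant cases of (*CRDB).operation above.
func operator(op operation) string {
	if op == operationNotIn {
		return "<>"
	}
	return "="
}

func main() {
	// Prints: instance_id <> ALL(?)
	fmt.Printf(conditionFormat(operationNotIn)+"\n", "instance_id", operator(operationNotIn))
}
```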
@@ -20,6 +20,8 @@ type SearchQuery struct {
builder *SearchQueryBuilder
aggregateTypes []AggregateType
aggregateIDs []string
instanceID string
excludedInstanceIDs []string
eventSequenceGreater uint64
eventSequenceLess uint64
eventTypes []EventType

@@ -91,9 +93,9 @@ func (builder *SearchQueryBuilder) ResourceOwner(resourceOwner string) *SearchQu
}
//InstanceID defines the instanceID (system) of the events
func (factory *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder {
factory.instanceID = instanceID
return factory
func (builder *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder {
builder.instanceID = instanceID
return builder
}
//OrderDesc changes the sorting order of the returned events to descending

@@ -149,6 +151,18 @@ func (query *SearchQuery) AggregateIDs(ids ...string) *SearchQuery {
return query
}
//InstanceID filters for events with the given instanceID
func (query *SearchQuery) InstanceID(instanceID string) *SearchQuery {
query.instanceID = instanceID
return query
}
//ExcludedInstanceID filters for events not having the given instanceIDs
func (query *SearchQuery) ExcludedInstanceID(instanceIDs ...string) *SearchQuery {
query.excludedInstanceIDs = instanceIDs
return query
}
//EventTypes filters for events with the given event types
func (query *SearchQuery) EventTypes(types ...EventType) *SearchQuery {
query.eventTypes = types

@@ -180,6 +194,9 @@ func (query *SearchQuery) matches(event Event) bool {
if ok := isAggregateIDs(event.Aggregate(), query.aggregateIDs...); len(query.aggregateIDs) > 0 && !ok {
return false
}
if event.Aggregate().InstanceID != "" && query.instanceID != "" && event.Aggregate().InstanceID != query.instanceID {
return false
}
if ok := isEventTypes(event, query.eventTypes...); len(query.eventTypes) > 0 && !ok {
return false
}

@@ -203,6 +220,8 @@ func (builder *SearchQueryBuilder) build(instanceID string) (*repository.SearchQ
query.eventDataFilter,
query.eventSequenceGreaterFilter,
query.eventSequenceLessFilter,
query.instanceIDFilter,
query.excludedInstanceIDFilter,
query.builder.resourceOwnerFilter,
query.builder.instanceIDFilter,
} {

@@ -281,6 +300,20 @@ func (query *SearchQuery) eventSequenceLessFilter() *repository.Filter {
return repository.NewFilter(repository.FieldSequence, query.eventSequenceLess, sortOrder)
}
func (query *SearchQuery) instanceIDFilter() *repository.Filter {
if query.instanceID == "" {
return nil
}
return repository.NewFilter(repository.FieldInstanceID, query.instanceID, repository.OperationEquals)
}
func (query *SearchQuery) excludedInstanceIDFilter() *repository.Filter {
if len(query.excludedInstanceIDs) == 0 {
return nil
}
return repository.NewFilter(repository.FieldInstanceID, query.excludedInstanceIDs, repository.OperationNotIn)
}
func (builder *SearchQueryBuilder) resourceOwnerFilter() *repository.Filter {
if builder.resourceOwner == "" {
return nil
@@ -12,7 +12,6 @@ import (
type Eventstore interface {
Health(ctx context.Context) error
FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (events []*models.Event, err error)
LatestSequence(ctx context.Context, searchQuery *models.SearchQueryFactory) (uint64, error)
Subscribe(aggregates ...models.AggregateType) *Subscription
}

@@ -35,13 +34,6 @@ func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.Sear
return es.repo.Filter(ctx, models.FactoryFromSearchQuery(searchQuery))
}
func (es *eventstore) LatestSequence(ctx context.Context, queryFactory *models.SearchQueryFactory) (uint64, error) {
sequenceFactory := *queryFactory
sequenceFactory = *(&sequenceFactory).Columns(models.Columns_Max_Sequence)
sequenceFactory = *(&sequenceFactory).SequenceGreater(0)
return es.repo.LatestSequence(ctx, &sequenceFactory)
}
func (es *eventstore) Health(ctx context.Context) error {
return es.repo.Health(ctx)
}
@@ -12,16 +12,16 @@ import (
)
const (
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE aggregate_type = \$1`
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE \( aggregate_type = \$1`
)
var (
eventColumns = []string{"creation_date", "event_type", "event_sequence", "previous_aggregate_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "instance_id", "aggregate_type", "aggregate_id", "aggregate_version"}
expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence LIMIT \$2`).String()
expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence DESC`).String()
expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String()
expectedFilterEventsAggregateIDTypeLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String()
expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence`).String()
expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence LIMIT \$2`).String()
expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence DESC`).String()
expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 \) ORDER BY event_sequence LIMIT \$3`).String()
expectedFilterEventsAggregateIDTypeLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 \) ORDER BY event_sequence LIMIT \$3`).String()
expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence`).String()
expectedInsertStatement = regexp.MustCompile(`INSERT INTO eventstore\.events ` +
`\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, instance_id, previous_aggregate_sequence, previous_aggregate_type_sequence\) ` +

@@ -172,14 +172,14 @@ func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock {
}
func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock {
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE aggregate_type = \$1`).
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE \( aggregate_type = \$1 \)`).
WithArgs(aggregateType).
WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence))
return db
}
func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock {
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE aggregate_type = \$1`).
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE \( aggregate_type = \$1 \)`).
WithArgs(aggregateType).WillReturnError(err)
return db
}
@@ -41,7 +41,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user").Limit(34),
searchQuery: es_models.NewSearchQueryFactory().Limit(34).AddQuery().AggregateTypes("user").Factory(),
},
res: res{
eventsLen: 3,

@@ -55,7 +55,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user").OrderDesc(),
searchQuery: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
eventsLen: 34,

@@ -69,7 +69,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("nonAggregate"),
searchQuery: es_models.NewSearchQueryFactory().AddQuery().AggregateTypes("nonAggregate").Factory(),
},
res: res{
wantErr: true,

@@ -83,7 +83,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user"),
searchQuery: es_models.NewSearchQueryFactory().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
wantErr: true,

@@ -97,7 +97,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user").Limit(5).AggregateIDs("hop"),
searchQuery: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").AggregateIDs("hop").Factory(),
},
res: res{
wantErr: false,

@@ -111,7 +111,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user").Limit(5).AggregateIDs("hop"),
searchQuery: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").AggregateIDs("hop").Factory(),
},
res: res{
wantErr: false,

@@ -176,7 +176,7 @@ func TestSQL_LatestSequence(t *testing.T) {
{
name: "no events for aggregate",
args: args{
searchQuery: es_models.NewSearchQueryFactory("idiot").Columns(es_models.Columns_Max_Sequence),
searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("idiot").Factory(),
},
fields: fields{
client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrNoRows),

@@ -189,7 +189,7 @@ func TestSQL_LatestSequence(t *testing.T) {
{
name: "sql query error",
args: args{
searchQuery: es_models.NewSearchQueryFactory("idiot").Columns(es_models.Columns_Max_Sequence),
searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("idiot").Factory(),
},
fields: fields{
client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrConnDone),

@@ -203,7 +203,7 @@ func TestSQL_LatestSequence(t *testing.T) {
{
name: "events for aggregate found",
args: args{
searchQuery: es_models.NewSearchQueryFactory("user").Columns(es_models.Columns_Max_Sequence),
searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("user").Factory(),
},
fields: fields{
client: mockDB(t).expectLatestSequenceFilter("user", math.MaxUint64),
@@ -61,27 +61,31 @@ func buildQuery(queryFactory *es_models.SearchQueryFactory) (query string, limit
return query, searchQuery.Limit, values, rowScanner
}

func prepareCondition(filters []*es_models.Filter) (clause string, values []interface{}) {
values = make([]interface{}, len(filters))
func prepareCondition(filters [][]*es_models.Filter) (clause string, values []interface{}) {
values = make([]interface{}, 0, len(filters))
clauses := make([]string, len(filters))

if len(filters) == 0 {
return clause, values
}
for i, filter := range filters {
value := filter.GetValue()
switch value.(type) {
case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType:
value = pq.Array(value)
}
subClauses := make([]string, 0, len(filter))
for _, f := range filter {
value := f.GetValue()
switch value.(type) {
case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType:
value = pq.Array(value)
}

clauses[i] = getCondition(filter)
if clauses[i] == "" {
return "", nil
subClauses = append(subClauses, getCondition(f))
if subClauses[len(subClauses)-1] == "" {
return "", nil
}
values = append(values, value)
}
values[i] = value
clauses[i] = "( " + strings.Join(subClauses, " AND ") + " )"
}
return " WHERE " + strings.Join(clauses, " AND "), values
return " WHERE " + strings.Join(clauses, " OR "), values
}

type scan func(dest ...interface{}) error
@@ -162,8 +166,11 @@ func getCondition(filter *es_models.Filter) (condition string) {
}

func getConditionFormat(operation es_models.Operation) string {
if operation == es_models.Operation_In {
switch operation {
case es_models.Operation_In:
return "%s %s ANY(?)"
case es_models.Operation_NotIn:
return "%s %s ALL(?)"
}
return "%s %s ?"
}
@@ -200,6 +207,8 @@ func getOperation(operation es_models.Operation) string {
return ">"
case es_models.Operation_Less:
return "<"
case es_models.Operation_NotIn:
return "<>"
}
return ""
}

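The reshaped prepareCondition above joins the filters of one sub-query with AND and the sub-queries themselves with OR, and the new Operation_NotIn renders as "<> ALL(?)". A minimal illustrative sketch of the resulting clause shape (not part of the diff; the column conditions are placeholders):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Each group mirrors one sub-query: its conditions are ANDed and wrapped
	// in parentheses; the groups themselves are ORed, as in prepareCondition.
	groups := []string{
		"( aggregate_type = ANY(?) AND aggregate_id = ? )",
		"( instance_id <> ALL(?) )",
	}
	fmt.Println(" WHERE " + strings.Join(groups, " OR "))
	// Output:
	//  WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? ) OR ( instance_id <> ALL(?) )
}
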
@@ -309,7 +309,7 @@ func prepareTestScan(err error, res []interface{}) scan {

func Test_prepareCondition(t *testing.T) {
type args struct {
filters []*es_models.Filter
filters [][]*es_models.Filter
}
type res struct {
clause string
@@ -333,7 +333,7 @@ func Test_prepareCondition(t *testing.T) {
{
name: "empty filters",
args: args{
filters: []*es_models.Filter{},
filters: [][]*es_models.Filter{},
},
res: res{
clause: "",
@@ -343,8 +343,10 @@ func Test_prepareCondition(t *testing.T) {
{
name: "invalid condition",
args: args{
filters: []*es_models.Filter{
es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
filters: [][]*es_models.Filter{
{
es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
},
},
},
res: res{
@@ -355,26 +357,30 @@ func Test_prepareCondition(t *testing.T) {
{
name: "array as condition value",
args: args{
filters: []*es_models.Filter{
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
filters: [][]*es_models.Filter{
{
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
},
},
},
res: res{
clause: " WHERE aggregate_type = ANY(?)",
clause: " WHERE ( aggregate_type = ANY(?) )",
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})},
},
},
{
name: "multiple filters",
args: args{
filters: []*es_models.Filter{
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
filters: [][]*es_models.Filter{
{
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
},
},
},
res: res{
clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?)",
clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )",
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})},
},
},
@@ -428,10 +434,10 @@ func Test_buildQuery(t *testing.T) {
{
name: "with order by desc",
args: args{
queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(),
queryFactory: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user")},
},
@@ -439,10 +445,10 @@ func Test_buildQuery(t *testing.T) {
{
name: "with limit",
args: args{
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5),
queryFactory: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence LIMIT $2",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5,
@@ -451,10 +457,10 @@ func Test_buildQuery(t *testing.T) {
{
name: "with limit and order by desc",
args: args{
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
queryFactory: es_models.NewSearchQueryFactory().Limit(5).OrderDesc().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC LIMIT $2",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5,

@@ -7,16 +7,17 @@ import (
"time"

"github.com/caos/logging"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/cockroachdb/cockroach-go/v2/crdb"

caos_errs "github.com/caos/zitadel/internal/errors"
)

const (
insertStmtFormat = "INSERT INTO %s" +
" (locker_id, locked_until, view_name) VALUES ($1, now()+$2::INTERVAL, $3)" +
" ON CONFLICT (view_name)" +
" DO UPDATE SET locker_id = $4, locked_until = now()+$5::INTERVAL" +
" WHERE locks.view_name = $6 AND (locks.locker_id = $7 OR locks.locked_until < now())"
" (locker_id, locked_until, view_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
" ON CONFLICT (view_name, instance_id)" +
" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
" WHERE locks.view_name = $3 AND locks.instance_id = $4 AND (locks.locker_id = $1 OR locks.locked_until < now())"
millisecondsAsSeconds = int64(time.Second / time.Millisecond)
)

@@ -26,13 +27,11 @@ type lock struct {
ViewName string `gorm:"column:view_name;primary_key"`
}

func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel string, waitTime time.Duration) error {
func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel, instanceID string, waitTime time.Duration) error {
return crdb.ExecuteTx(context.Background(), dbClient, nil, func(tx *sql.Tx) error {
insert := fmt.Sprintf(insertStmtFormat, lockTable)
result, err := tx.Exec(insert,
lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel,
lockerID, waitTime.Milliseconds()/millisecondsAsSeconds,
viewModel, lockerID)
lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel, instanceID)

if err != nil {
tx.Rollback()

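A hedged usage sketch of the extended Renew signature shown above, written as if it lived in the same package (the package name, lock table and identifiers are placeholders, not values from the repository):

package locker // assumption: the package that declares Renew

import (
	"database/sql"
	"log"
	"time"
)

func renewExample(db *sql.DB) {
	// Locks are now scoped per (view_name, instance_id), so the same view can
	// be locked independently for every instance.
	if err := Renew(db, "projections.locks", "locker-1", "projections.users", "instance-1", 5*time.Second); err != nil {
		log.Println("lock is still held by another locker:", err)
	}
}
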
@@ -55,10 +55,10 @@ func (db *dbMock) expectReleaseSavepoint() *dbMock {
return db
}

func (db *dbMock) expectRenew(lockerID, view string, affectedRows int64) *dbMock {
func (db *dbMock) expectRenew(lockerID, view, instanceID string, affectedRows int64) *dbMock {
query := db.mock.
ExpectExec(`INSERT INTO table\.locks \(locker_id, locked_until, view_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\) ON CONFLICT \(view_name\) DO UPDATE SET locker_id = \$4, locked_until = now\(\)\+\$5::INTERVAL WHERE locks\.view_name = \$6 AND \(locks\.locker_id = \$7 OR locks\.locked_until < now\(\)\)`).
WithArgs(lockerID, sqlmock.AnyArg(), view, lockerID, sqlmock.AnyArg(), view, lockerID).
ExpectExec(`INSERT INTO table\.locks \(locker_id, locked_until, view_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\) ON CONFLICT \(view_name, instance_id\) DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL WHERE locks\.view_name = \$3 AND locks\.instance_id = \$4 AND \(locks\.locker_id = \$1 OR locks\.locked_until < now\(\)\)`).
WithArgs(lockerID, sqlmock.AnyArg(), view, instanceID).
WillReturnResult(sqlmock.NewResult(1, 1))

if affectedRows == 0 {
@@ -75,10 +75,11 @@ func Test_locker_Renew(t *testing.T) {
db *dbMock
}
type args struct {
tableName string
lockerID string
viewModel string
waitTime time.Duration
tableName string
lockerID string
viewModel string
instanceID string
waitTime time.Duration
}
tests := []struct {
name string
@@ -92,11 +93,11 @@ func Test_locker_Renew(t *testing.T) {
db: mockDB(t).
expectBegin().
expectSavepoint().
expectRenew("locker", "view", 1).
expectRenew("locker", "view", "instanceID", 1).
expectReleaseSavepoint().
expectCommit(),
},
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second},
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", instanceID: "instanceID", waitTime: 1 * time.Second},
wantErr: false,
},
{
@@ -105,16 +106,16 @@ func Test_locker_Renew(t *testing.T) {
db: mockDB(t).
expectBegin().
expectSavepoint().
expectRenew("locker", "view", 0).
expectRenew("locker", "view", "instanceID", 0).
expectRollback(),
},
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second},
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", instanceID: "instanceID", waitTime: 1 * time.Second},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := Renew(tt.fields.db.db, tt.args.tableName, tt.args.lockerID, tt.args.viewModel, tt.args.waitTime); (err != nil) != tt.wantErr {
if err := Renew(tt.fields.db.db, tt.args.tableName, tt.args.lockerID, tt.args.viewModel, tt.args.instanceID, tt.args.waitTime); (err != nil) != tt.wantErr {
t.Errorf("locker.Renew() error = %v, wantErr %v", err, tt.wantErr)
}
if err := tt.fields.db.mock.ExpectationsWereMet(); err != nil {

@@ -1,56 +1,42 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/zitadel/internal/eventstore (interfaces: Eventstore)
// Source: github.com/caos/zitadel/internal/eventstore/v1 (interfaces: Eventstore)

// Package mock is a generated GoMock package.
package mock

import (
context "context"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/v1"
reflect "reflect"

v1 "github.com/caos/zitadel/internal/eventstore/v1"
models "github.com/caos/zitadel/internal/eventstore/v1/models"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)

// MockEventstore is a mock of Eventstore interface
// MockEventstore is a mock of Eventstore interface.
type MockEventstore struct {
ctrl *gomock.Controller
recorder *MockEventstoreMockRecorder
}

// MockEventstoreMockRecorder is the mock recorder for MockEventstore
// MockEventstoreMockRecorder is the mock recorder for MockEventstore.
type MockEventstoreMockRecorder struct {
mock *MockEventstore
}

// NewMockEventstore creates a new mock instance
// NewMockEventstore creates a new mock instance.
func NewMockEventstore(ctrl *gomock.Controller) *MockEventstore {
mock := &MockEventstore{ctrl: ctrl}
mock.recorder = &MockEventstoreMockRecorder{mock}
return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEventstore) EXPECT() *MockEventstoreMockRecorder {
return m.recorder
}

// AggregateCreator mocks base method
func (m *MockEventstore) AggregateCreator() *models.AggregateCreator {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AggregateCreator")
ret0, _ := ret[0].(*models.AggregateCreator)
return ret0
}

// AggregateCreator indicates an expected call of AggregateCreator
func (mr *MockEventstoreMockRecorder) AggregateCreator() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AggregateCreator", reflect.TypeOf((*MockEventstore)(nil).AggregateCreator))
}

// FilterEvents mocks base method
// FilterEvents mocks base method.
func (m *MockEventstore) FilterEvents(arg0 context.Context, arg1 *models.SearchQuery) ([]*models.Event, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterEvents", arg0, arg1)
@@ -59,13 +45,13 @@ func (m *MockEventstore) FilterEvents(arg0 context.Context, arg1 *models.SearchQ
return ret0, ret1
}

// FilterEvents indicates an expected call of FilterEvents
// FilterEvents indicates an expected call of FilterEvents.
func (mr *MockEventstoreMockRecorder) FilterEvents(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterEvents", reflect.TypeOf((*MockEventstore)(nil).FilterEvents), arg0, arg1)
}

// Health mocks base method
// Health mocks base method.
func (m *MockEventstore) Health(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0)
@@ -73,47 +59,13 @@ func (m *MockEventstore) Health(arg0 context.Context) error {
return ret0
}

// Health indicates an expected call of Health
// Health indicates an expected call of Health.
func (mr *MockEventstoreMockRecorder) Health(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockEventstore)(nil).Health), arg0)
}

// LatestSequence mocks base method
func (m *MockEventstore) LatestSequence(arg0 context.Context, arg1 *models.SearchQueryFactory) (uint64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LatestSequence", arg0, arg1)
ret0, _ := ret[0].(uint64)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// LatestSequence indicates an expected call of LatestSequence
func (mr *MockEventstoreMockRecorder) LatestSequence(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LatestSequence", reflect.TypeOf((*MockEventstore)(nil).LatestSequence), arg0, arg1)
}

// PushAggregates mocks base method
func (m *MockEventstore) PushAggregates(arg0 context.Context, arg1 ...*models.Aggregate) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "PushAggregates", varargs...)
ret0, _ := ret[0].(error)
return ret0
}

// PushAggregates indicates an expected call of PushAggregates
func (mr *MockEventstoreMockRecorder) PushAggregates(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushAggregates", reflect.TypeOf((*MockEventstore)(nil).PushAggregates), varargs...)
}

// Subscribe mocks base method
// Subscribe mocks base method.
func (m *MockEventstore) Subscribe(arg0 ...models.AggregateType) *v1.Subscription {
m.ctrl.T.Helper()
varargs := []interface{}{}
@@ -125,22 +77,8 @@ func (m *MockEventstore) Subscribe(arg0 ...models.AggregateType) *v1.Subscriptio
return ret0
}

// Subscribe indicates an expected call of Subscribe
// Subscribe indicates an expected call of Subscribe.
func (mr *MockEventstoreMockRecorder) Subscribe(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockEventstore)(nil).Subscribe), arg0...)
}

// V2 mocks base method
func (m *MockEventstore) V2() *eventstore.Eventstore {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "V2")
ret0, _ := ret[0].(*eventstore.Eventstore)
return ret0
}

// V2 indicates an expected call of V2
func (mr *MockEventstoreMockRecorder) V2() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "V2", reflect.TypeOf((*MockEventstore)(nil).V2))
}

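The regenerated mock header now points at the v1 eventstore package. A hedged sketch of a go:generate directive that would produce such a file with MockGen (the destination path and package flag are assumptions, not taken from the repository):

//go:generate mockgen -package mock -destination eventstore.mock.go github.com/caos/zitadel/internal/eventstore/v1 Eventstore
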
@@ -190,7 +190,7 @@ func TestAggregate_Validate(t *testing.T) {
resourceOwner: "org",
PreviousSequence: 5,
Precondition: &precondition{
Query: NewSearchQuery().AggregateIDFilter("hodor"),
Query: NewSearchQuery().AddQuery().AggregateIDFilter("hodor").SearchQuery(),
},
Events: []*Event{
{
@@ -240,7 +240,7 @@ func TestAggregate_Validate(t *testing.T) {
PreviousSequence: 5,
Precondition: &precondition{
Validation: func(...*Event) error { return nil },
Query: NewSearchQuery().AddQuery().AggregateIDFilter("hodor").SearchQuery(),
Query: NewSearchQuery().AggregateIDFilter("hodor"),
},
Events: []*Event{
{

@@ -7,4 +7,5 @@ const (
Operation_Greater
Operation_Less
Operation_In
Operation_NotIn
)

@@ -9,24 +9,31 @@ import (
)

type SearchQueryFactory struct {
columns Columns
limit uint64
desc bool
aggregateTypes []AggregateType
aggregateIDs []string
sequenceFrom uint64
sequenceTo uint64
eventTypes []EventType
resourceOwner string
instanceID string
creationDate time.Time
columns Columns
limit uint64
desc bool
queries []*query
}

type query struct {
desc bool
aggregateTypes []AggregateType
aggregateIDs []string
sequenceFrom uint64
sequenceTo uint64
eventTypes []EventType
resourceOwner string
instanceID string
ignoredInstanceIDs []string
creationDate time.Time
factory *SearchQueryFactory
}

type searchQuery struct {
Columns Columns
Limit uint64
Desc bool
Filters []*Filter
Filters [][]*Filter
}

type Columns int32
@@ -39,49 +46,55 @@ const (
)

//FactoryFromSearchQuery is deprecated because it's for migration purposes. use NewSearchQueryFactory
func FactoryFromSearchQuery(query *SearchQuery) *SearchQueryFactory {
func FactoryFromSearchQuery(q *SearchQuery) *SearchQueryFactory {
factory := &SearchQueryFactory{
columns: Columns_Event,
desc: query.Desc,
limit: query.Limit,
desc: q.Desc,
limit: q.Limit,
queries: make([]*query, len(q.Queries)),
}

for _, filter := range query.Filters {
switch filter.field {
case Field_AggregateType:
factory = factory.aggregateTypesMig(filter.value.([]AggregateType)...)
case Field_AggregateID:
if aggregateID, ok := filter.value.(string); ok {
factory = factory.AggregateIDs(aggregateID)
} else if aggregateIDs, ok := filter.value.([]string); ok {
factory = factory.AggregateIDs(aggregateIDs...)
for i, qq := range q.Queries {
factory.queries[i] = &query{factory: factory}
for _, filter := range qq.Filters {
switch filter.field {
case Field_AggregateType:
factory.queries[i] = factory.queries[i].aggregateTypesMig(filter.value.([]AggregateType)...)
case Field_AggregateID:
if aggregateID, ok := filter.value.(string); ok {
factory.queries[i] = factory.queries[i].AggregateIDs(aggregateID)
} else if aggregateIDs, ok := filter.value.([]string); ok {
factory.queries[i] = factory.queries[i].AggregateIDs(aggregateIDs...)
}
case Field_LatestSequence:
if filter.operation == Operation_Greater {
factory.queries[i] = factory.queries[i].SequenceGreater(filter.value.(uint64))
} else {
factory.queries[i] = factory.queries[i].SequenceLess(filter.value.(uint64))
}
case Field_ResourceOwner:
factory.queries[i] = factory.queries[i].ResourceOwner(filter.value.(string))
case Field_InstanceID:
if filter.operation == Operation_Equals {
factory.queries[i] = factory.queries[i].InstanceID(filter.value.(string))
} else if filter.operation == Operation_NotIn {
factory.queries[i] = factory.queries[i].IgnoredInstanceIDs(filter.value.([]string)...)
}
case Field_EventType:
factory.queries[i] = factory.queries[i].EventTypes(filter.value.([]EventType)...)
case Field_EditorService, Field_EditorUser:
logging.WithFields("value", filter.value).Panic("field not converted to factory")
case Field_CreationDate:
factory.queries[i] = factory.queries[i].CreationDateNewer(filter.value.(time.Time))
}
case Field_LatestSequence:
if filter.operation == Operation_Greater {
factory = factory.SequenceGreater(filter.value.(uint64))
} else {
factory = factory.SequenceLess(filter.value.(uint64))
}
case Field_ResourceOwner:
factory = factory.ResourceOwner(filter.value.(string))
case Field_InstanceID:
factory = factory.InstanceID(filter.value.(string))
case Field_EventType:
factory = factory.EventTypes(filter.value.([]EventType)...)
case Field_EditorService, Field_EditorUser:
logging.Log("MODEL-Mr0VN").WithField("value", filter.value).Panic("field not converted to factory")
case Field_CreationDate:
factory = factory.CreationDateNewer(filter.value.(time.Time))
}
}

return factory
}

func NewSearchQueryFactory(aggregateTypes ...AggregateType) *SearchQueryFactory {
return &SearchQueryFactory{
aggregateTypes: aggregateTypes,
}
func NewSearchQueryFactory() *SearchQueryFactory {
return &SearchQueryFactory{}
}

func (factory *SearchQueryFactory) Columns(columns Columns) *SearchQueryFactory {
@@ -94,46 +107,6 @@ func (factory *SearchQueryFactory) Limit(limit uint64) *SearchQueryFactory {
return factory
}

func (factory *SearchQueryFactory) SequenceGreater(sequence uint64) *SearchQueryFactory {
factory.sequenceFrom = sequence
return factory
}

func (factory *SearchQueryFactory) SequenceLess(sequence uint64) *SearchQueryFactory {
factory.sequenceTo = sequence
return factory
}

func (factory *SearchQueryFactory) AggregateIDs(ids ...string) *SearchQueryFactory {
factory.aggregateIDs = ids
return factory
}

func (factory *SearchQueryFactory) aggregateTypesMig(types ...AggregateType) *SearchQueryFactory {
factory.aggregateTypes = types
return factory
}

func (factory *SearchQueryFactory) EventTypes(types ...EventType) *SearchQueryFactory {
factory.eventTypes = types
return factory
}

func (factory *SearchQueryFactory) ResourceOwner(resourceOwner string) *SearchQueryFactory {
factory.resourceOwner = resourceOwner
return factory
}

func (factory *SearchQueryFactory) InstanceID(instanceID string) *SearchQueryFactory {
factory.instanceID = instanceID
return factory
}

func (factory *SearchQueryFactory) CreationDateNewer(time time.Time) *SearchQueryFactory {
factory.creationDate = time
return factory
}

func (factory *SearchQueryFactory) OrderDesc() *SearchQueryFactory {
factory.desc = true
return factory
@@ -144,27 +117,89 @@ func (factory *SearchQueryFactory) OrderAsc() *SearchQueryFactory {
return factory
}

func (factory *SearchQueryFactory) AddQuery() *query {
q := &query{factory: factory}
factory.queries = append(factory.queries, q)
return q
}

func (q *query) Factory() *SearchQueryFactory {
return q.factory
}

func (q *query) SequenceGreater(sequence uint64) *query {
q.sequenceFrom = sequence
return q
}

func (q *query) SequenceLess(sequence uint64) *query {
q.sequenceTo = sequence
return q
}

func (q *query) AggregateTypes(types ...AggregateType) *query {
q.aggregateTypes = types
return q
}

func (q *query) AggregateIDs(ids ...string) *query {
q.aggregateIDs = ids
return q
}

func (q *query) aggregateTypesMig(types ...AggregateType) *query {
q.aggregateTypes = types
return q
}

func (q *query) EventTypes(types ...EventType) *query {
q.eventTypes = types
return q
}

func (q *query) ResourceOwner(resourceOwner string) *query {
q.resourceOwner = resourceOwner
return q
}

func (q *query) InstanceID(instanceID string) *query {
q.instanceID = instanceID
return q
}

func (q *query) IgnoredInstanceIDs(instanceIDs ...string) *query {
q.ignoredInstanceIDs = instanceIDs
return q
}

func (q *query) CreationDateNewer(time time.Time) *query {
q.creationDate = time
return q
}

func (factory *SearchQueryFactory) Build() (*searchQuery, error) {
if factory == nil ||
len(factory.aggregateTypes) < 1 ||
len(factory.queries) < 1 ||
(factory.columns < 0 || factory.columns >= columnsCount) {
return nil, errors.ThrowPreconditionFailed(nil, "MODEL-tGAD3", "factory invalid")
}
filters := []*Filter{
factory.aggregateTypeFilter(),
}
filters := make([][]*Filter, len(factory.queries))

for _, f := range []func() *Filter{
factory.aggregateIDFilter,
factory.sequenceFromFilter,
factory.sequenceToFilter,
factory.eventTypeFilter,
factory.resourceOwnerFilter,
factory.instanceIDFilter,
factory.creationDateNewerFilter,
} {
if filter := f(); filter != nil {
filters = append(filters, filter)
for i, query := range factory.queries {
for _, f := range []func() *Filter{
query.aggregateTypeFilter,
query.aggregateIDFilter,
query.sequenceFromFilter,
query.sequenceToFilter,
query.eventTypeFilter,
query.resourceOwnerFilter,
query.instanceIDFilter,
query.ignoredInstanceIDsFilter,
query.creationDateNewerFilter,
} {
if filter := f(); filter != nil {
filters[i] = append(filters[i], filter)
}
}
}

@@ -176,72 +211,79 @@ func (factory *SearchQueryFactory) Build() (*searchQuery, error) {
}, nil
}

func (factory *SearchQueryFactory) aggregateIDFilter() *Filter {
if len(factory.aggregateIDs) < 1 {
func (q *query) aggregateIDFilter() *Filter {
if len(q.aggregateIDs) < 1 {
return nil
}
if len(factory.aggregateIDs) == 1 {
return NewFilter(Field_AggregateID, factory.aggregateIDs[0], Operation_Equals)
if len(q.aggregateIDs) == 1 {
return NewFilter(Field_AggregateID, q.aggregateIDs[0], Operation_Equals)
}
return NewFilter(Field_AggregateID, factory.aggregateIDs, Operation_In)
return NewFilter(Field_AggregateID, q.aggregateIDs, Operation_In)
}

func (factory *SearchQueryFactory) eventTypeFilter() *Filter {
if len(factory.eventTypes) < 1 {
func (q *query) eventTypeFilter() *Filter {
if len(q.eventTypes) < 1 {
return nil
}
if len(factory.eventTypes) == 1 {
return NewFilter(Field_EventType, factory.eventTypes[0], Operation_Equals)
if len(q.eventTypes) == 1 {
return NewFilter(Field_EventType, q.eventTypes[0], Operation_Equals)
}
return NewFilter(Field_EventType, factory.eventTypes, Operation_In)
return NewFilter(Field_EventType, q.eventTypes, Operation_In)
}

func (factory *SearchQueryFactory) aggregateTypeFilter() *Filter {
if len(factory.aggregateTypes) == 1 {
return NewFilter(Field_AggregateType, factory.aggregateTypes[0], Operation_Equals)
func (q *query) aggregateTypeFilter() *Filter {
if len(q.aggregateTypes) == 1 {
return NewFilter(Field_AggregateType, q.aggregateTypes[0], Operation_Equals)
}
return NewFilter(Field_AggregateType, factory.aggregateTypes, Operation_In)
return NewFilter(Field_AggregateType, q.aggregateTypes, Operation_In)
}

func (factory *SearchQueryFactory) sequenceFromFilter() *Filter {
if factory.sequenceFrom == 0 {
func (q *query) sequenceFromFilter() *Filter {
if q.sequenceFrom == 0 {
return nil
}
sortOrder := Operation_Greater
if factory.desc {
if q.factory.desc {
sortOrder = Operation_Less
}
return NewFilter(Field_LatestSequence, factory.sequenceFrom, sortOrder)
return NewFilter(Field_LatestSequence, q.sequenceFrom, sortOrder)
}

func (factory *SearchQueryFactory) sequenceToFilter() *Filter {
if factory.sequenceTo == 0 {
func (q *query) sequenceToFilter() *Filter {
if q.sequenceTo == 0 {
return nil
}
sortOrder := Operation_Less
if factory.desc {
if q.factory.desc {
sortOrder = Operation_Greater
}
return NewFilter(Field_LatestSequence, factory.sequenceTo, sortOrder)
return NewFilter(Field_LatestSequence, q.sequenceTo, sortOrder)
}

func (factory *SearchQueryFactory) resourceOwnerFilter() *Filter {
if factory.resourceOwner == "" {
func (q *query) resourceOwnerFilter() *Filter {
if q.resourceOwner == "" {
return nil
}
return NewFilter(Field_ResourceOwner, factory.resourceOwner, Operation_Equals)
return NewFilter(Field_ResourceOwner, q.resourceOwner, Operation_Equals)
}

func (factory *SearchQueryFactory) instanceIDFilter() *Filter {
if factory.instanceID == "" {
func (q *query) instanceIDFilter() *Filter {
if q.instanceID == "" {
return nil
}
return NewFilter(Field_InstanceID, factory.instanceID, Operation_Equals)
return NewFilter(Field_InstanceID, q.instanceID, Operation_Equals)
}

func (factory *SearchQueryFactory) creationDateNewerFilter() *Filter {
if factory.creationDate.IsZero() {
func (q *query) ignoredInstanceIDsFilter() *Filter {
if len(q.ignoredInstanceIDs) == 0 {
return nil
}
return NewFilter(Field_CreationDate, factory.creationDate, Operation_Greater)
return NewFilter(Field_InstanceID, q.ignoredInstanceIDs, Operation_NotIn)
}

func (q *query) creationDateNewerFilter() *Filter {
if q.creationDate.IsZero() {
return nil
}
return NewFilter(Field_CreationDate, q.creationDate, Operation_Greater)
}

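For orientation, a hedged sketch of the new builder-style API defined above, written as if it lived in the same models package (all identifiers and values below are placeholders, not taken from the commit):

package models // assumption: the es_models package declaring SearchQueryFactory

func exampleFactoryUsage() (*searchQuery, error) {
	// Two ORed sub-queries: "user" events of one instance above a sequence
	// floor, plus "org" events that exclude a set of instance IDs.
	factory := NewSearchQueryFactory().
		Columns(Columns_Event).
		Limit(100).
		OrderDesc().
		AddQuery().
		AggregateTypes("user").
		InstanceID("instance-1").
		SequenceGreater(42).
		Factory().
		AddQuery().
		AggregateTypes("org").
		IgnoredInstanceIDs("instance-2", "instance-3").
		Factory()

	return factory.Build()
}
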
@@ -11,15 +11,46 @@ type SearchQuery struct {
Limit uint64
Desc bool
Filters []*Filter
Queries []*Query
}

type Query struct {
searchQuery *SearchQuery
Filters []*Filter
}

//NewSearchQuery is deprecated. Use SearchQueryFactory
func NewSearchQuery() *SearchQuery {
return &SearchQuery{
Filters: make([]*Filter, 0, 4),
Queries: make([]*Query, 0),
}
}

func (q *SearchQuery) AddQuery() *Query {
query := &Query{
searchQuery: q,
}
q.Queries = append(q.Queries, query)

return query
}

//SearchQuery returns the SearchQuery of the sub query
func (q *Query) SearchQuery() *SearchQuery {
return q.searchQuery
}
func (q *Query) setFilter(filter *Filter) *Query {
for i, f := range q.Filters {
if f.field == filter.field && f.field != Field_LatestSequence {
q.Filters[i] = filter
return q
}
}
q.Filters = append(q.Filters, filter)
return q
}

func (q *SearchQuery) SetLimit(limit uint64) *SearchQuery {
q.Limit = limit
return q
@@ -35,23 +66,23 @@ func (q *SearchQuery) OrderAsc() *SearchQuery {
return q
}

func (q *SearchQuery) AggregateIDFilter(id string) *SearchQuery {
func (q *Query) AggregateIDFilter(id string) *Query {
return q.setFilter(NewFilter(Field_AggregateID, id, Operation_Equals))
}

func (q *SearchQuery) AggregateIDsFilter(ids ...string) *SearchQuery {
func (q *Query) AggregateIDsFilter(ids ...string) *Query {
return q.setFilter(NewFilter(Field_AggregateID, ids, Operation_In))
}

func (q *SearchQuery) AggregateTypeFilter(types ...AggregateType) *SearchQuery {
func (q *Query) AggregateTypeFilter(types ...AggregateType) *Query {
return q.setFilter(NewFilter(Field_AggregateType, types, Operation_In))
}

func (q *SearchQuery) EventTypesFilter(types ...EventType) *SearchQuery {
func (q *Query) EventTypesFilter(types ...EventType) *Query {
return q.setFilter(NewFilter(Field_EventType, types, Operation_In))
}

func (q *SearchQuery) LatestSequenceFilter(sequence uint64) *SearchQuery {
func (q *Query) LatestSequenceFilter(sequence uint64) *Query {
if sequence == 0 {
return q
}
@@ -59,21 +90,25 @@ func (q *SearchQuery) LatestSequenceFilter(sequence uint64) *SearchQuery {
return q.setFilter(NewFilter(Field_LatestSequence, sequence, sortOrder))
}

func (q *SearchQuery) SequenceBetween(from, to uint64) *SearchQuery {
func (q *Query) SequenceBetween(from, to uint64) *Query {
q.setFilter(NewFilter(Field_LatestSequence, from, Operation_Greater))
q.setFilter(NewFilter(Field_LatestSequence, to, Operation_Less))
return q
}

func (q *SearchQuery) ResourceOwnerFilter(resourceOwner string) *SearchQuery {
func (q *Query) ResourceOwnerFilter(resourceOwner string) *Query {
return q.setFilter(NewFilter(Field_ResourceOwner, resourceOwner, Operation_Equals))
}

func (q *SearchQuery) InstanceIDFilter(instanceID string) *SearchQuery {
func (q *Query) InstanceIDFilter(instanceID string) *Query {
return q.setFilter(NewFilter(Field_InstanceID, instanceID, Operation_Equals))
}

func (q *SearchQuery) CreationDateNewerFilter(time time.Time) *SearchQuery {
func (q *Query) ExcludedInstanceIDsFilter(instanceIDs ...string) *Query {
return q.setFilter(NewFilter(Field_InstanceID, instanceIDs, Operation_NotIn))
}

func (q *Query) CreationDateNewerFilter(time time.Time) *Query {
return q.setFilter(NewFilter(Field_CreationDate, time, Operation_Greater))
}

@@ -92,12 +127,14 @@ func (q *SearchQuery) Validate() error {
if q == nil {
return errors.ThrowPreconditionFailed(nil, "MODEL-J5xQi", "search query is nil")
}
if len(q.Filters) == 0 {
if len(q.Queries) == 0 {
return errors.ThrowPreconditionFailed(nil, "MODEL-pF3DR", "no filters set")
}
for _, filter := range q.Filters {
if err := filter.Validate(); err != nil {
return err
for _, query := range q.Queries {
for _, filter := range query.Filters {
if err := filter.Validate(); err != nil {
return err
}
}
}

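Likewise, a minimal sketch of the deprecated SearchQuery path with the new sub-query structure (again assuming the same package; values are placeholders):

package models // assumption: same package as SearchQuery

func exampleDeprecatedSearchQuery() error {
	// Filters now live on sub-queries; SearchQuery() would return the enclosing query.
	q := NewSearchQuery().SetLimit(10)
	q.AddQuery().
		AggregateTypeFilter("user").
		AggregateIDFilter("1234").
		ExcludedInstanceIDsFilter("instance-2")

	return q.Validate()
}
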
@@ -21,31 +21,48 @@ func testSetLimit(limit uint64) func(factory *SearchQueryFactory) *SearchQueryFa
|
||||
}
|
||||
}
|
||||
|
||||
func testSetSequence(sequence uint64) func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
factory = factory.SequenceGreater(sequence)
|
||||
return factory
|
||||
func testAddQuery(queryFuncs ...func(*query) *query) func(*SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(builder *SearchQueryFactory) *SearchQueryFactory {
|
||||
query := builder.AddQuery()
|
||||
for _, queryFunc := range queryFuncs {
|
||||
queryFunc(query)
|
||||
}
|
||||
return query.Factory()
|
||||
}
|
||||
}
|
||||
|
||||
func testSetAggregateIDs(aggregateIDs ...string) func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
factory = factory.AggregateIDs(aggregateIDs...)
|
||||
return factory
|
||||
func testSetSequence(sequence uint64) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.SequenceGreater(sequence)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetEventTypes(eventTypes ...EventType) func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
factory = factory.EventTypes(eventTypes...)
|
||||
return factory
|
||||
func testSetAggregateIDs(aggregateIDs ...string) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.AggregateIDs(aggregateIDs...)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetResourceOwner(resourceOwner string) func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
factory = factory.ResourceOwner(resourceOwner)
|
||||
return factory
|
||||
func testSetAggregateTypes(aggregateTypes ...AggregateType) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.AggregateTypes(aggregateTypes...)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetEventTypes(eventTypes ...EventType) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.EventTypes(eventTypes...)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetResourceOwner(resourceOwner string) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.ResourceOwner(resourceOwner)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,10 +77,50 @@ func testSetSortOrder(asc bool) func(factory *SearchQueryFactory) *SearchQueryFa
|
||||
}
|
||||
}
|
||||
|
||||
func assertFactory(t *testing.T, want, got *SearchQueryFactory) {
|
||||
t.Helper()
|
||||
|
||||
if got.columns != want.columns {
|
||||
t.Errorf("wrong column: got: %v want: %v", got.columns, want.columns)
|
||||
}
|
||||
if got.desc != want.desc {
|
||||
t.Errorf("wrong desc: got: %v want: %v", got.desc, want.desc)
|
||||
}
|
||||
if got.limit != want.limit {
|
||||
t.Errorf("wrong limit: got: %v want: %v", got.limit, want.limit)
|
||||
}
|
||||
if len(got.queries) != len(want.queries) {
|
||||
t.Errorf("wrong length of queries: got: %v want: %v", len(got.queries), len(want.queries))
|
||||
}
|
||||
|
||||
for i, query := range got.queries {
|
||||
assertQuery(t, i, want.queries[i], query)
|
||||
}
|
||||
}
|
||||
|
||||
func assertQuery(t *testing.T, i int, want, got *query) {
|
||||
t.Helper()
|
||||
|
||||
if !reflect.DeepEqual(got.aggregateIDs, want.aggregateIDs) {
|
||||
t.Errorf("wrong aggregateIDs in query %d : got: %v want: %v", i, got.aggregateIDs, want.aggregateIDs)
|
||||
}
|
||||
if !reflect.DeepEqual(got.aggregateTypes, want.aggregateTypes) {
|
||||
t.Errorf("wrong aggregateTypes in query %d : got: %v want: %v", i, got.aggregateTypes, want.aggregateTypes)
|
||||
}
|
||||
if got.sequenceFrom != want.sequenceFrom {
|
||||
t.Errorf("wrong sequenceFrom in query %d : got: %v want: %v", i, got.sequenceFrom, want.sequenceFrom)
|
||||
}
|
||||
if got.sequenceTo != want.sequenceTo {
|
||||
t.Errorf("wrong sequenceTo in query %d : got: %v want: %v", i, got.sequenceTo, want.sequenceTo)
|
||||
}
|
||||
if !reflect.DeepEqual(got.eventTypes, want.eventTypes) {
|
||||
t.Errorf("wrong eventTypes in query %d : got: %v want: %v", i, got.eventTypes, want.eventTypes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchQueryFactorySetters(t *testing.T) {
|
||||
type args struct {
|
||||
aggregateTypes []AggregateType
|
||||
setters []func(*SearchQueryFactory) *SearchQueryFactory
|
||||
setters []func(*SearchQueryFactory) *SearchQueryFactory
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -73,11 +130,9 @@ func TestSearchQueryFactorySetters(t *testing.T) {
|
||||
{
|
||||
name: "New factory",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{"user", "org"},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
aggregateTypes: []AggregateType{"user", "org"},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
|
||||
},
|
||||
res: &SearchQueryFactory{},
|
||||
},
|
||||
{
|
||||
name: "set columns",
|
||||
@@ -100,69 +155,98 @@ func TestSearchQueryFactorySetters(t *testing.T) {
|
||||
{
|
||||
name: "set sequence",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetSequence(90)},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetSequence(90))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
sequenceFrom: 90,
|
||||
queries: []*query{
|
||||
{
|
||||
sequenceFrom: 90,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set aggregateTypes",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateTypes("user", "org"))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
queries: []*query{
|
||||
{
|
||||
aggregateTypes: []AggregateType{"user", "org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set aggregateIDs",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetAggregateIDs("1235", "09824")},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateIDs("1235", "09824"))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
aggregateIDs: []string{"1235", "09824"},
|
||||
queries: []*query{
|
||||
{
|
||||
aggregateIDs: []string{"1235", "09824"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set eventTypes",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetEventTypes("user.created", "user.updated")},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetEventTypes("user.created", "user.updated"))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
eventTypes: []EventType{"user.created", "user.updated"},
|
||||
queries: []*query{
|
||||
{
|
||||
eventTypes: []EventType{"user.created", "user.updated"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set resource owner",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetResourceOwner("hodor")},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetResourceOwner("hodor"))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
resourceOwner: "hodor",
|
||||
queries: []*query{
|
||||
{
|
||||
resourceOwner: "hodor",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default search query",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetAggregateIDs("1235", "024"), testSetSortOrder(false)},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateTypes("user"), testSetAggregateIDs("1235", "024")), testSetSortOrder(false)},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
aggregateIDs: []string{"1235", "024"},
|
||||
desc: true,
|
||||
desc: true,
|
||||
queries: []*query{
|
||||
{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
aggregateIDs: []string{"1235", "024"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
factory := NewSearchQueryFactory(tt.args.aggregateTypes...)
|
||||
factory := NewSearchQueryFactory()
|
||||
for _, setter := range tt.args.setters {
|
||||
factory = setter(factory)
|
||||
}
|
||||
if !reflect.DeepEqual(factory, tt.res) {
|
||||
t.Errorf("NewSearchQueryFactory() = %v, want %v", factory, tt.res)
|
||||
}
|
||||
assertFactory(t, tt.res, factory)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
type args struct {
|
||||
aggregateTypes []AggregateType
|
||||
setters []func(*SearchQueryFactory) *SearchQueryFactory
|
||||
setters []func(*SearchQueryFactory) *SearchQueryFactory
|
||||
}
|
||||
type res struct {
|
||||
isErr func(err error) bool
|
||||
@@ -176,8 +260,7 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
{
|
||||
name: "no aggregate types",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
|
||||
},
|
||||
res: res{
|
||||
isErr: errors.IsPreconditionFailed,
|
||||
@@ -187,9 +270,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
{
|
||||
name: "invalid column (too low)",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetColumns(Columns(-1)),
|
||||
testAddQuery(testSetAggregateTypes("user")),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@@ -199,9 +282,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
{
|
||||
name: "invalid column (too high)",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetColumns(columnsCount),
|
||||
testAddQuery(testSetAggregateTypes("user")),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@@ -211,8 +294,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
{
|
||||
name: "filter aggregate type",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(testSetAggregateTypes("user")),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
@@ -220,8 +304,10 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: []*Filter{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -229,8 +315,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
{
|
||||
name: "filter aggregate types",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{"user", "org"},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(testSetAggregateTypes("user", "org")),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
@@ -238,8 +325,10 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: []*Filter{
|
||||
NewFilter(Field_AggregateType, []AggregateType{"user", "org"}, Operation_In),
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, []AggregateType{"user", "org"}, Operation_In),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -247,11 +336,13 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
{
|
||||
name: "filter aggregate type, limit, desc",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetLimit(5),
|
||||
testSetSortOrder(false),
|
||||
testSetSequence(100),
|
||||
testAddQuery(
|
||||
testSetAggregateTypes("user"),
|
||||
testSetSequence(100),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@@ -260,9 +351,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
Columns: 0,
|
||||
Desc: true,
|
||||
Limit: 5,
|
||||
Filters: []*Filter{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -270,11 +363,13 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
{
|
||||
name: "filter aggregate type, limit, asc",
|
||||
args: args{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetLimit(5),
|
||||
testSetSortOrder(true),
|
||||
testSetSequence(100),
|
||||
testAddQuery(
|
||||
testSetSequence(100),
|
||||
testSetAggregateTypes("user"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@@ -283,9 +378,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 5,
|
||||
Filters: []*Filter{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_LatestSequence, uint64(100), Operation_Greater),
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_LatestSequence, uint64(100), Operation_Greater),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -293,12 +390,14 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type, limit, desc, max event sequence cols",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetLimit(5),
testSetSortOrder(false),
testSetSequence(100),
testSetColumns(Columns_Max_Sequence),
testAddQuery(
testSetSequence(100),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -307,9 +406,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: Columns_Max_Sequence,
Desc: true,
Limit: 5,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
},
},
},
},
@@ -317,9 +418,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and aggregate id",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetAggregateIDs("1234"),
testAddQuery(
testSetAggregateIDs("1234"),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -328,9 +431,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_AggregateID, "1234", Operation_Equals),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_AggregateID, "1234", Operation_Equals),
},
},
},
},
@@ -338,9 +443,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and aggregate ids",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetAggregateIDs("1234", "0815"),
testAddQuery(
testSetAggregateIDs("1234", "0815"),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -349,9 +456,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In),
},
},
},
},
@@ -359,9 +468,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and sequence greater",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetSequence(8),
testAddQuery(
testSetSequence(8),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -370,9 +481,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(8), Operation_Greater),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(8), Operation_Greater),
},
},
},
},
@@ -380,9 +493,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and event type",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetEventTypes("user.created"),
testAddQuery(
testSetAggregateTypes("user"),
testSetEventTypes("user.created"),
),
},
},
res: res{
@@ -391,9 +506,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_EventType, EventType("user.created"), Operation_Equals),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_EventType, EventType("user.created"), Operation_Equals),
},
},
},
},
@@ -401,9 +518,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and event types",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetEventTypes("user.created", "user.changed"),
testAddQuery(
testSetAggregateTypes("user"),
testSetEventTypes("user.created", "user.changed"),
),
},
},
res: res{
@@ -412,9 +531,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_EventType, []EventType{"user.created", "user.changed"}, Operation_In),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_EventType, []EventType{"user.created", "user.changed"}, Operation_In),
},
},
},
},
@@ -422,9 +543,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type resource owner",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetResourceOwner("hodor"),
testAddQuery(
testSetAggregateTypes("user"),
testSetResourceOwner("hodor"),
),
},
},
res: res{
@@ -433,9 +556,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_ResourceOwner, "hodor", Operation_Equals),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_ResourceOwner, "hodor", Operation_Equals),
},
},
},
},
@@ -443,7 +568,7 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
factory := NewSearchQueryFactory(tt.args.aggregateTypes...)
factory := NewSearchQueryFactory()
for _, f := range tt.args.setters {
factory = f(factory)
}
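
For clarity on the expected values above: with the reworked factory, each testAddQuery(...) call contributes its own filter group, so the asserted Filters field changes from a flat []*Filter to a nested [][]*Filter with one inner slice per sub-query. A minimal sketch of the new shape, assuming it sits in the same package as these tests (illustration only, not part of this change):

// exampleExpectedFilters shows the nested shape asserted by the updated test cases:
// one inner slice of filters for each sub-query added via testAddQuery.
func exampleExpectedFilters() [][]*Filter {
	return [][]*Filter{
		{
			NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
			NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
		},
	}
}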
@@ -26,7 +26,7 @@ type Handler interface {
QueryLimit() uint64

AggregateTypes() []models.AggregateType
CurrentSequence() (uint64, error)
CurrentSequence(instanceID string) (uint64, error)
Eventstore() v1.Eventstore

Subscription() *v1.Subscription
@@ -41,15 +41,18 @@ func ReduceEvent(handler Handler, event *models.Event) {
handler.Subscription().Unsubscribe()
}
}()
currentSequence, err := handler.CurrentSequence()
currentSequence, err := handler.CurrentSequence(event.InstanceID)
if err != nil {
logging.New().WithError(err).Warn("unable to get current sequence")
return
}

searchQuery := models.NewSearchQuery().
AddQuery().
AggregateTypeFilter(handler.AggregateTypes()...).
SequenceBetween(currentSequence, event.Sequence).
InstanceIDFilter(event.InstanceID).
SearchQuery().
SetLimit(eventLimit)

unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery)
@@ -59,7 +62,7 @@ func ReduceEvent(handler Handler, event *models.Event) {
}

for _, unprocessedEvent := range unprocessedEvents {
currentSequence, err := handler.CurrentSequence()
currentSequence, err := handler.CurrentSequence(unprocessedEvent.InstanceID)
if err != nil {
logging.Log("HANDL-BmpkC").WithError(err).Warn("unable to get current sequence")
return
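
Since CurrentSequence is now resolved per instance, a handler has to track one sequence per instance ID instead of a single number. A minimal in-memory sketch of what such an implementation could look like (hypothetical, only the standard sync package is needed; the real handlers read this value from their view tables):

// handlerSequences is an illustrative, concurrency-safe map of
// instanceID -> last processed sequence backing CurrentSequence(instanceID).
type handlerSequences struct {
	mu   sync.RWMutex
	data map[string]uint64
}

func (s *handlerSequences) CurrentSequence(instanceID string) (uint64, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	// a missing entry returns 0, meaning nothing has been processed for this instance yet
	return s.data[instanceID], nil
}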
@@ -5,44 +5,45 @@
package mock

import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"

gomock "github.com/golang/mock/gomock"
)

// MockLocker is a mock of Locker interface
// MockLocker is a mock of Locker interface.
type MockLocker struct {
ctrl *gomock.Controller
recorder *MockLockerMockRecorder
}

// MockLockerMockRecorder is the mock recorder for MockLocker
// MockLockerMockRecorder is the mock recorder for MockLocker.
type MockLockerMockRecorder struct {
mock *MockLocker
}

// NewMockLocker creates a new mock instance
// NewMockLocker creates a new mock instance.
func NewMockLocker(ctrl *gomock.Controller) *MockLocker {
mock := &MockLocker{ctrl: ctrl}
mock.recorder = &MockLockerMockRecorder{mock}
return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLocker) EXPECT() *MockLockerMockRecorder {
return m.recorder
}

// Renew mocks base method
func (m *MockLocker) Renew(lockerID, viewModel string, waitTime time.Duration) error {
// Renew mocks base method.
func (m *MockLocker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Renew", lockerID, viewModel, waitTime)
ret := m.ctrl.Call(m, "Renew", lockerID, viewModel, instanceID, waitTime)
ret0, _ := ret[0].(error)
return ret0
}

// Renew indicates an expected call of Renew
func (mr *MockLockerMockRecorder) Renew(lockerID, viewModel, waitTime interface{}) *gomock.Call {
// Renew indicates an expected call of Renew.
func (mr *MockLockerMockRecorder) Renew(lockerID, viewModel, instanceID, waitTime interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Renew", reflect.TypeOf((*MockLocker)(nil).Renew), lockerID, viewModel, waitTime)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Renew", reflect.TypeOf((*MockLocker)(nil).Renew), lockerID, viewModel, instanceID, waitTime)
}
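
Because the regenerated mock now records four arguments, tests that set expectations on Renew also need a matcher for instanceID. A short usage sketch in test code, assuming ctrl := gomock.NewController(t); the view name here is a made-up example:

// expect one Renew call for any locker ID, the given view, any instance ID and any wait time
locker := mock.NewMockLocker(ctrl)
locker.EXPECT().
	Renew(gomock.Any(), "example.view", gomock.Any(), gomock.Any()).
	Return(nil)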
@@ -16,6 +16,8 @@ import (
"github.com/caos/zitadel/internal/view/repository"
)

const systemID = "system"

type Spooler struct {
handlers []query.Handler
locker Locker
@@ -26,7 +28,7 @@ type Spooler struct {
}

type Locker interface {
Renew(lockerID, viewModel string, waitTime time.Duration) error
Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error
}

type spooledHandler struct {
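
Every Locker implementation now receives the instance ID as part of the lock scope; the spooler itself passes the constant systemID (see the lock hunk below). A minimal sketch of an implementation satisfying the extended interface (illustration only, not the production database locker):

// noopLocker satisfies the new Locker interface; a real implementation would
// key its lock row on (lockerID, viewModel, instanceID) and honor waitTime.
type noopLocker struct{}

func (noopLocker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error {
	return nil
}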
@@ -138,19 +140,6 @@ func (s *spooledHandler) query(ctx context.Context) ([]*models.Event, error) {
if err != nil {
return nil, err
}
factory := models.FactoryFromSearchQuery(query)
sequence, err := s.eventstore.LatestSequence(ctx, factory)
logging.OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Debug("unable to query latest sequence")
var processedSequence uint64
for _, filter := range query.Filters {
if filter.GetField() == models.Field_LatestSequence {
processedSequence = filter.GetValue().(uint64)
}
}
if sequence != 0 && processedSequence == sequence {
return nil, nil
}

query.Limit = s.QueryLimit()
return s.eventstore.FilterEvents(ctx, query)
}
@@ -169,7 +158,7 @@ func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID s
case <-ctx.Done():
return
case <-renewTimer:
err := s.locker.Renew(workerID, s.ViewModel(), s.LockDuration())
err := s.locker.Renew(workerID, s.ViewModel(), systemID, s.LockDuration())
firstLock.Do(func() {
locked <- err == nil
})
@@ -190,16 +179,17 @@ func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID s
}

func HandleError(event *models.Event, failedErr error,
latestFailedEvent func(sequence uint64) (*repository.FailedEvent, error),
latestFailedEvent func(sequence uint64, instanceID string) (*repository.FailedEvent, error),
processFailedEvent func(*repository.FailedEvent) error,
processSequence func(*models.Event) error,
errorCountUntilSkip uint64) error {
failedEvent, err := latestFailedEvent(event.Sequence)
failedEvent, err := latestFailedEvent(event.Sequence, event.InstanceID)
if err != nil {
return err
}
failedEvent.FailureCount++
failedEvent.ErrMsg = failedErr.Error()
failedEvent.InstanceID = event.InstanceID
err = processFailedEvent(failedEvent)
if err != nil {
return err
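
The failed-event lookup passed into HandleError is now scoped by instance as well as by sequence. A hedged sketch of such a callback backed by a plain map, purely illustrative and assuming fmt is imported; real views query their failed-events storage instead:

// latestFailedEvent keyed by sequence and instance; unknown keys yield a fresh entry.
failed := map[string]*repository.FailedEvent{}
latestFailedEvent := func(sequence uint64, instanceID string) (*repository.FailedEvent, error) {
	key := fmt.Sprintf("%d-%s", sequence, instanceID)
	if fe, ok := failed[key]; ok {
		return fe, nil
	}
	return &repository.FailedEvent{FailedSequence: sequence, InstanceID: instanceID}, nil
}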
@@ -3,17 +3,18 @@ package spooler
import (
"context"
"fmt"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/v1"
"testing"
"time"

"github.com/golang/mock/gomock"

"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler/mock"
"github.com/caos/zitadel/internal/view/repository"
"github.com/golang/mock/gomock"
)

type testHandler struct {
@@ -30,7 +31,7 @@ func (h *testHandler) AggregateTypes() []models.AggregateType {
return nil
}

func (h *testHandler) CurrentSequence() (uint64, error) {
func (h *testHandler) CurrentSequence(instanceID string) (uint64, error) {
return 0, nil
}

@@ -376,8 +377,8 @@ func newTestLocker(t *testing.T, lockerID, viewName string) *testLocker {

func (l *testLocker) expectRenew(t *testing.T, err error, waitTime time.Duration) *testLocker {
t.Helper()
l.mock.EXPECT().Renew(gomock.Any(), l.viewName, gomock.Any()).DoAndReturn(
func(_, _ string, gotten time.Duration) error {
l.mock.EXPECT().Renew(gomock.Any(), l.viewName, gomock.Any(), gomock.Any()).DoAndReturn(
func(_, _, _ string, gotten time.Duration) error {
t.Helper()
if waitTime-gotten != 0 {
t.Errorf("expected waittime %v got %v", waitTime, gotten)
@@ -396,7 +397,7 @@ func TestHandleError(t *testing.T) {
type args struct {
event *models.Event
failedErr error
latestFailedEvent func(sequence uint64) (*repository.FailedEvent, error)
latestFailedEvent func(sequence uint64, instanceID string) (*repository.FailedEvent, error)
errorCountUntilSkip uint64
}
type res struct {
@@ -413,12 +414,13 @@ func TestHandleError(t *testing.T) {
args: args{
event: &models.Event{Sequence: 30000000},
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) {
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
return &repository.FailedEvent{
ErrMsg: "blub",
FailedSequence: s - 1,
FailureCount: 6,
ViewName: "super.table",
InstanceID: instanceID,
}, nil
},
errorCountUntilSkip: 5,
@@ -432,12 +434,13 @@ func TestHandleError(t *testing.T) {
args: args{
event: &models.Event{Sequence: 30000000},
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) {
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
return &repository.FailedEvent{
ErrMsg: "blub",
FailedSequence: s - 1,
FailureCount: 5,
ViewName: "super.table",
InstanceID: instanceID,
}, nil
},
errorCountUntilSkip: 6,
@@ -451,12 +454,13 @@ func TestHandleError(t *testing.T) {
args: args{
event: &models.Event{Sequence: 30000000},
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) {
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
return &repository.FailedEvent{
ErrMsg: "blub",
FailedSequence: s - 1,
FailureCount: 3,
ViewName: "super.table",
InstanceID: instanceID,
}, nil
},
errorCountUntilSkip: 5,