feat: handle instanceID in projections (#3442)

* feat: handle instanceID in projections

* rename functions

* fix key lock

* fix import
Author: Livio Amstutz
Date: 2022-04-19 08:26:12 +02:00
Committed by: GitHub
Parent: c25d853820
Commit: 1305c14e49
120 changed files with 2078 additions and 1209 deletions

View File

@@ -12,7 +12,6 @@ import (
type Eventstore interface {
Health(ctx context.Context) error
FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (events []*models.Event, err error)
LatestSequence(ctx context.Context, searchQuery *models.SearchQueryFactory) (uint64, error)
Subscribe(aggregates ...models.AggregateType) *Subscription
}
@@ -35,13 +34,6 @@ func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.Sear
return es.repo.Filter(ctx, models.FactoryFromSearchQuery(searchQuery))
}
func (es *eventstore) LatestSequence(ctx context.Context, queryFactory *models.SearchQueryFactory) (uint64, error) {
sequenceFactory := *queryFactory
sequenceFactory = *(&sequenceFactory).Columns(models.Columns_Max_Sequence)
sequenceFactory = *(&sequenceFactory).SequenceGreater(0)
return es.repo.LatestSequence(ctx, &sequenceFactory)
}
func (es *eventstore) Health(ctx context.Context) error {
return es.repo.Health(ctx)
}
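
LatestSequence is removed from the v1 Eventstore interface in this commit; callers that still need the latest processed sequence now assemble the max-sequence query themselves with the reworked SearchQueryFactory and hand it to the repository, as the updated tests further down do. A minimal, hedged sketch (the es_models alias and the "user" aggregate type are illustrative):

sequenceQuery := es_models.NewSearchQueryFactory().
	Columns(es_models.Columns_Max_Sequence).
	AddQuery().
	AggregateTypes("user").
	Factory()
// sequenceQuery can then be passed on to the repository's LatestSequence implementation.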

View File

@@ -12,16 +12,16 @@ import (
)
const (
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE aggregate_type = \$1`
selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE \( aggregate_type = \$1`
)
var (
eventColumns = []string{"creation_date", "event_type", "event_sequence", "previous_aggregate_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "instance_id", "aggregate_type", "aggregate_id", "aggregate_version"}
expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence LIMIT \$2`).String()
expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence DESC`).String()
expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String()
expectedFilterEventsAggregateIDTypeLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String()
expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence`).String()
expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence LIMIT \$2`).String()
expectedFilterEventsDescFormat = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence DESC`).String()
expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 \) ORDER BY event_sequence LIMIT \$3`).String()
expectedFilterEventsAggregateIDTypeLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 \) ORDER BY event_sequence LIMIT \$3`).String()
expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` \) ORDER BY event_sequence`).String()
expectedInsertStatement = regexp.MustCompile(`INSERT INTO eventstore\.events ` +
`\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, instance_id, previous_aggregate_sequence, previous_aggregate_type_sequence\) ` +
@@ -172,14 +172,14 @@ func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock {
}
func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock {
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE aggregate_type = \$1`).
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE \( aggregate_type = \$1 \)`).
WithArgs(aggregateType).
WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence))
return db
}
func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock {
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE aggregate_type = \$1`).
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE \( aggregate_type = \$1 \)`).
WithArgs(aggregateType).WillReturnError(err)
return db
}

View File

@@ -41,7 +41,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user").Limit(34),
searchQuery: es_models.NewSearchQueryFactory().Limit(34).AddQuery().AggregateTypes("user").Factory(),
},
res: res{
eventsLen: 3,
@@ -55,7 +55,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user").OrderDesc(),
searchQuery: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
eventsLen: 34,
@@ -69,7 +69,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("nonAggregate"),
searchQuery: es_models.NewSearchQueryFactory().AddQuery().AggregateTypes("nonAggregate").Factory(),
},
res: res{
wantErr: true,
@@ -83,7 +83,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user"),
searchQuery: es_models.NewSearchQueryFactory().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
wantErr: true,
@@ -97,7 +97,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user").Limit(5).AggregateIDs("hop"),
searchQuery: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").AggregateIDs("hop").Factory(),
},
res: res{
wantErr: false,
@@ -111,7 +111,7 @@ func TestSQL_Filter(t *testing.T) {
},
args: args{
events: &mockEvents{t: t},
searchQuery: es_models.NewSearchQueryFactory("user").Limit(5).AggregateIDs("hop"),
searchQuery: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").AggregateIDs("hop").Factory(),
},
res: res{
wantErr: false,
@@ -176,7 +176,7 @@ func TestSQL_LatestSequence(t *testing.T) {
{
name: "no events for aggregate",
args: args{
searchQuery: es_models.NewSearchQueryFactory("idiot").Columns(es_models.Columns_Max_Sequence),
searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("idiot").Factory(),
},
fields: fields{
client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrNoRows),
@@ -189,7 +189,7 @@ func TestSQL_LatestSequence(t *testing.T) {
{
name: "sql query error",
args: args{
searchQuery: es_models.NewSearchQueryFactory("idiot").Columns(es_models.Columns_Max_Sequence),
searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("idiot").Factory(),
},
fields: fields{
client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrConnDone),
@@ -203,7 +203,7 @@ func TestSQL_LatestSequence(t *testing.T) {
{
name: "events for aggregate found",
args: args{
searchQuery: es_models.NewSearchQueryFactory("user").Columns(es_models.Columns_Max_Sequence),
searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("user").Factory(),
},
fields: fields{
client: mockDB(t).expectLatestSequenceFilter("user", math.MaxUint64),

View File

@@ -61,27 +61,31 @@ func buildQuery(queryFactory *es_models.SearchQueryFactory) (query string, limit
return query, searchQuery.Limit, values, rowScanner
}
func prepareCondition(filters []*es_models.Filter) (clause string, values []interface{}) {
values = make([]interface{}, len(filters))
func prepareCondition(filters [][]*es_models.Filter) (clause string, values []interface{}) {
values = make([]interface{}, 0, len(filters))
clauses := make([]string, len(filters))
if len(filters) == 0 {
return clause, values
}
for i, filter := range filters {
value := filter.GetValue()
switch value.(type) {
case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType:
value = pq.Array(value)
}
subClauses := make([]string, 0, len(filter))
for _, f := range filter {
value := f.GetValue()
switch value.(type) {
case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType:
value = pq.Array(value)
}
clauses[i] = getCondition(filter)
if clauses[i] == "" {
return "", nil
subClauses = append(subClauses, getCondition(f))
if subClauses[len(subClauses)-1] == "" {
return "", nil
}
values = append(values, value)
}
values[i] = value
clauses[i] = "( " + strings.Join(subClauses, " AND ") + " )"
}
return " WHERE " + strings.Join(clauses, " AND "), values
return " WHERE " + strings.Join(clauses, " OR "), values
}
type scan func(dest ...interface{}) error
@@ -162,8 +166,11 @@ func getCondition(filter *es_models.Filter) (condition string) {
}
func getConditionFormat(operation es_models.Operation) string {
if operation == es_models.Operation_In {
switch operation {
case es_models.Operation_In:
return "%s %s ANY(?)"
case es_models.Operation_NotIn:
return "%s %s ALL(?)"
}
return "%s %s ?"
}
@@ -200,6 +207,8 @@ func getOperation(operation es_models.Operation) string {
return ">"
case es_models.Operation_Less:
return "<"
case es_models.Operation_NotIn:
return "<>"
}
return ""
}
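
A hedged, package-internal sketch of what the reworked prepareCondition produces: filters inside a sub-query are joined with AND, and the sub-queries themselves with OR. The fields and values below are illustrative; the column names follow the tests in this commit.

filters := [][]*es_models.Filter{
	{
		es_models.NewFilter(es_models.Field_AggregateType, es_models.AggregateType("user"), es_models.Operation_Equals),
		es_models.NewFilter(es_models.Field_InstanceID, "instance-1", es_models.Operation_Equals),
	},
	{
		es_models.NewFilter(es_models.Field_AggregateType, es_models.AggregateType("org"), es_models.Operation_Equals),
	},
}
clause, values := prepareCondition(filters)
// clause: " WHERE ( aggregate_type = ? AND instance_id = ? ) OR ( aggregate_type = ? )"
// values: AggregateType("user"), "instance-1", AggregateType("org")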

View File

@@ -309,7 +309,7 @@ func prepareTestScan(err error, res []interface{}) scan {
func Test_prepareCondition(t *testing.T) {
type args struct {
filters []*es_models.Filter
filters [][]*es_models.Filter
}
type res struct {
clause string
@@ -333,7 +333,7 @@ func Test_prepareCondition(t *testing.T) {
{
name: "empty filters",
args: args{
filters: []*es_models.Filter{},
filters: [][]*es_models.Filter{},
},
res: res{
clause: "",
@@ -343,8 +343,10 @@ func Test_prepareCondition(t *testing.T) {
{
name: "invalid condition",
args: args{
filters: []*es_models.Filter{
es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
filters: [][]*es_models.Filter{
{
es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
},
},
},
res: res{
@@ -355,26 +357,30 @@ func Test_prepareCondition(t *testing.T) {
{
name: "array as condition value",
args: args{
filters: []*es_models.Filter{
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
filters: [][]*es_models.Filter{
{
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
},
},
},
res: res{
clause: " WHERE aggregate_type = ANY(?)",
clause: " WHERE ( aggregate_type = ANY(?) )",
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})},
},
},
{
name: "multiple filters",
args: args{
filters: []*es_models.Filter{
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
filters: [][]*es_models.Filter{
{
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
},
},
},
res: res{
clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?)",
clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )",
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})},
},
},
@@ -428,10 +434,10 @@ func Test_buildQuery(t *testing.T) {
{
name: "with order by desc",
args: args{
queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(),
queryFactory: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user")},
},
@@ -439,10 +445,10 @@ func Test_buildQuery(t *testing.T) {
{
name: "with limit",
args: args{
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5),
queryFactory: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence LIMIT $2",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5,
@@ -451,10 +457,10 @@ func Test_buildQuery(t *testing.T) {
{
name: "with limit and order by desc",
args: args{
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
queryFactory: es_models.NewSearchQueryFactory().Limit(5).OrderDesc().AddQuery().AggregateTypes("user").Factory(),
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE ( aggregate_type = $1 ) ORDER BY event_sequence DESC LIMIT $2",
rowScanner: true,
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
limit: 5,

View File

@@ -7,16 +7,17 @@ import (
"time"
"github.com/caos/logging"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/cockroachdb/cockroach-go/v2/crdb"
caos_errs "github.com/caos/zitadel/internal/errors"
)
const (
insertStmtFormat = "INSERT INTO %s" +
" (locker_id, locked_until, view_name) VALUES ($1, now()+$2::INTERVAL, $3)" +
" ON CONFLICT (view_name)" +
" DO UPDATE SET locker_id = $4, locked_until = now()+$5::INTERVAL" +
" WHERE locks.view_name = $6 AND (locks.locker_id = $7 OR locks.locked_until < now())"
" (locker_id, locked_until, view_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
" ON CONFLICT (view_name, instance_id)" +
" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
" WHERE locks.view_name = $3 AND locks.instance_id = $4 AND (locks.locker_id = $1 OR locks.locked_until < now())"
millisecondsAsSeconds = int64(time.Second / time.Millisecond)
)
@@ -26,13 +27,11 @@ type lock struct {
ViewName string `gorm:"column:view_name;primary_key"`
}
func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel string, waitTime time.Duration) error {
func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel, instanceID string, waitTime time.Duration) error {
return crdb.ExecuteTx(context.Background(), dbClient, nil, func(tx *sql.Tx) error {
insert := fmt.Sprintf(insertStmtFormat, lockTable)
result, err := tx.Exec(insert,
lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel,
lockerID, waitTime.Milliseconds()/millisecondsAsSeconds,
viewModel, lockerID)
lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel, instanceID)
if err != nil {
tx.Rollback()
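
A hedged usage sketch of the extended Renew signature: the lock row is now keyed by view_name and instance_id together, so projections of different instances renew their locks independently. The table, view and instance names are illustrative, and the enclosing package and imports are assumed.

if err := Renew(dbClient, "table.locks", "locker-1", "management.users", "instance-1", 5*time.Second); err != nil {
	// no row was updated: another locker currently holds the lock for this view/instance pair
	return err
}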

View File

@@ -55,10 +55,10 @@ func (db *dbMock) expectReleaseSavepoint() *dbMock {
return db
}
func (db *dbMock) expectRenew(lockerID, view string, affectedRows int64) *dbMock {
func (db *dbMock) expectRenew(lockerID, view, instanceID string, affectedRows int64) *dbMock {
query := db.mock.
ExpectExec(`INSERT INTO table\.locks \(locker_id, locked_until, view_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\) ON CONFLICT \(view_name\) DO UPDATE SET locker_id = \$4, locked_until = now\(\)\+\$5::INTERVAL WHERE locks\.view_name = \$6 AND \(locks\.locker_id = \$7 OR locks\.locked_until < now\(\)\)`).
WithArgs(lockerID, sqlmock.AnyArg(), view, lockerID, sqlmock.AnyArg(), view, lockerID).
ExpectExec(`INSERT INTO table\.locks \(locker_id, locked_until, view_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\) ON CONFLICT \(view_name, instance_id\) DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL WHERE locks\.view_name = \$3 AND locks\.instance_id = \$4 AND \(locks\.locker_id = \$1 OR locks\.locked_until < now\(\)\)`).
WithArgs(lockerID, sqlmock.AnyArg(), view, instanceID).
WillReturnResult(sqlmock.NewResult(1, 1))
if affectedRows == 0 {
@@ -75,10 +75,11 @@ func Test_locker_Renew(t *testing.T) {
db *dbMock
}
type args struct {
tableName string
lockerID string
viewModel string
waitTime time.Duration
tableName string
lockerID string
viewModel string
instanceID string
waitTime time.Duration
}
tests := []struct {
name string
@@ -92,11 +93,11 @@ func Test_locker_Renew(t *testing.T) {
db: mockDB(t).
expectBegin().
expectSavepoint().
expectRenew("locker", "view", 1).
expectRenew("locker", "view", "instanceID", 1).
expectReleaseSavepoint().
expectCommit(),
},
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second},
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", instanceID: "instanceID", waitTime: 1 * time.Second},
wantErr: false,
},
{
@@ -105,16 +106,16 @@ func Test_locker_Renew(t *testing.T) {
db: mockDB(t).
expectBegin().
expectSavepoint().
expectRenew("locker", "view", 0).
expectRenew("locker", "view", "instanceID", 0).
expectRollback(),
},
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second},
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", instanceID: "instanceID", waitTime: 1 * time.Second},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := Renew(tt.fields.db.db, tt.args.tableName, tt.args.lockerID, tt.args.viewModel, tt.args.waitTime); (err != nil) != tt.wantErr {
if err := Renew(tt.fields.db.db, tt.args.tableName, tt.args.lockerID, tt.args.viewModel, tt.args.instanceID, tt.args.waitTime); (err != nil) != tt.wantErr {
t.Errorf("locker.Renew() error = %v, wantErr %v", err, tt.wantErr)
}
if err := tt.fields.db.mock.ExpectationsWereMet(); err != nil {

View File

@@ -1,56 +1,42 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/zitadel/internal/eventstore (interfaces: Eventstore)
// Source: github.com/caos/zitadel/internal/eventstore/v1 (interfaces: Eventstore)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/v1"
reflect "reflect"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
models "github.com/caos/zitadel/internal/eventstore/v1/models"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockEventstore is a mock of Eventstore interface
// MockEventstore is a mock of Eventstore interface.
type MockEventstore struct {
ctrl *gomock.Controller
recorder *MockEventstoreMockRecorder
}
// MockEventstoreMockRecorder is the mock recorder for MockEventstore
// MockEventstoreMockRecorder is the mock recorder for MockEventstore.
type MockEventstoreMockRecorder struct {
mock *MockEventstore
}
// NewMockEventstore creates a new mock instance
// NewMockEventstore creates a new mock instance.
func NewMockEventstore(ctrl *gomock.Controller) *MockEventstore {
mock := &MockEventstore{ctrl: ctrl}
mock.recorder = &MockEventstoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEventstore) EXPECT() *MockEventstoreMockRecorder {
return m.recorder
}
// AggregateCreator mocks base method
func (m *MockEventstore) AggregateCreator() *models.AggregateCreator {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AggregateCreator")
ret0, _ := ret[0].(*models.AggregateCreator)
return ret0
}
// AggregateCreator indicates an expected call of AggregateCreator
func (mr *MockEventstoreMockRecorder) AggregateCreator() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AggregateCreator", reflect.TypeOf((*MockEventstore)(nil).AggregateCreator))
}
// FilterEvents mocks base method
// FilterEvents mocks base method.
func (m *MockEventstore) FilterEvents(arg0 context.Context, arg1 *models.SearchQuery) ([]*models.Event, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterEvents", arg0, arg1)
@@ -59,13 +45,13 @@ func (m *MockEventstore) FilterEvents(arg0 context.Context, arg1 *models.SearchQ
return ret0, ret1
}
// FilterEvents indicates an expected call of FilterEvents
// FilterEvents indicates an expected call of FilterEvents.
func (mr *MockEventstoreMockRecorder) FilterEvents(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterEvents", reflect.TypeOf((*MockEventstore)(nil).FilterEvents), arg0, arg1)
}
// Health mocks base method
// Health mocks base method.
func (m *MockEventstore) Health(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0)
@@ -73,47 +59,13 @@ func (m *MockEventstore) Health(arg0 context.Context) error {
return ret0
}
// Health indicates an expected call of Health
// Health indicates an expected call of Health.
func (mr *MockEventstoreMockRecorder) Health(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockEventstore)(nil).Health), arg0)
}
// LatestSequence mocks base method
func (m *MockEventstore) LatestSequence(arg0 context.Context, arg1 *models.SearchQueryFactory) (uint64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LatestSequence", arg0, arg1)
ret0, _ := ret[0].(uint64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LatestSequence indicates an expected call of LatestSequence
func (mr *MockEventstoreMockRecorder) LatestSequence(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LatestSequence", reflect.TypeOf((*MockEventstore)(nil).LatestSequence), arg0, arg1)
}
// PushAggregates mocks base method
func (m *MockEventstore) PushAggregates(arg0 context.Context, arg1 ...*models.Aggregate) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "PushAggregates", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// PushAggregates indicates an expected call of PushAggregates
func (mr *MockEventstoreMockRecorder) PushAggregates(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushAggregates", reflect.TypeOf((*MockEventstore)(nil).PushAggregates), varargs...)
}
// Subscribe mocks base method
// Subscribe mocks base method.
func (m *MockEventstore) Subscribe(arg0 ...models.AggregateType) *v1.Subscription {
m.ctrl.T.Helper()
varargs := []interface{}{}
@@ -125,22 +77,8 @@ func (m *MockEventstore) Subscribe(arg0 ...models.AggregateType) *v1.Subscriptio
return ret0
}
// Subscribe indicates an expected call of Subscribe
// Subscribe indicates an expected call of Subscribe.
func (mr *MockEventstoreMockRecorder) Subscribe(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockEventstore)(nil).Subscribe), arg0...)
}
// V2 mocks base method
func (m *MockEventstore) V2() *eventstore.Eventstore {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "V2")
ret0, _ := ret[0].(*eventstore.Eventstore)
return ret0
}
// V2 indicates an expected call of V2
func (mr *MockEventstoreMockRecorder) V2() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "V2", reflect.TypeOf((*MockEventstore)(nil).V2))
}
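
A hedged sketch of stubbing the regenerated mock in a test: only FilterEvents, Health and Subscribe remain on the v1 Eventstore interface, while AggregateCreator, LatestSequence, PushAggregates and V2 are gone. The mock/models/gomock import aliases and the *testing.T value are assumed from a surrounding test file.

ctrl := gomock.NewController(t)
defer ctrl.Finish()

es := mock.NewMockEventstore(ctrl)
es.EXPECT().
	FilterEvents(gomock.Any(), gomock.Any()).
	Return([]*models.Event{}, nil)

if _, err := es.FilterEvents(context.Background(), new(models.SearchQuery)); err != nil {
	t.Fatal(err)
}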

View File

@@ -190,7 +190,7 @@ func TestAggregate_Validate(t *testing.T) {
resourceOwner: "org",
PreviousSequence: 5,
Precondition: &precondition{
Query: NewSearchQuery().AggregateIDFilter("hodor"),
Query: NewSearchQuery().AddQuery().AggregateIDFilter("hodor").SearchQuery(),
},
Events: []*Event{
{
@@ -240,7 +240,7 @@ func TestAggregate_Validate(t *testing.T) {
PreviousSequence: 5,
Precondition: &precondition{
Validation: func(...*Event) error { return nil },
Query: NewSearchQuery().AggregateIDFilter("hodor"),
Query: NewSearchQuery().AddQuery().AggregateIDFilter("hodor").SearchQuery(),
},
Events: []*Event{
{

View File

@@ -7,4 +7,5 @@ const (
Operation_Greater
Operation_Less
Operation_In
Operation_NotIn
)
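
The new Operation_NotIn backs the instance-ID exclusion filters: IgnoredInstanceIDs on the factory (and ExcludedInstanceIDsFilter on the deprecated SearchQuery) create a Field_InstanceID filter with Operation_NotIn, which the SQL layer renders as instance_id <> ALL(?). A hedged sketch with illustrative values:

factory := es_models.NewSearchQueryFactory().
	AddQuery().
	AggregateTypes("user").
	IgnoredInstanceIDs("instance-1", "instance-2").
	Factory()
// Build() adds NewFilter(Field_InstanceID, []string{...}, Operation_NotIn) to the sub-query's filter group.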

View File

@@ -9,24 +9,31 @@ import (
)
type SearchQueryFactory struct {
columns Columns
limit uint64
desc bool
aggregateTypes []AggregateType
aggregateIDs []string
sequenceFrom uint64
sequenceTo uint64
eventTypes []EventType
resourceOwner string
instanceID string
creationDate time.Time
columns Columns
limit uint64
desc bool
queries []*query
}
type query struct {
desc bool
aggregateTypes []AggregateType
aggregateIDs []string
sequenceFrom uint64
sequenceTo uint64
eventTypes []EventType
resourceOwner string
instanceID string
ignoredInstanceIDs []string
creationDate time.Time
factory *SearchQueryFactory
}
type searchQuery struct {
Columns Columns
Limit uint64
Desc bool
Filters []*Filter
Filters [][]*Filter
}
type Columns int32
@@ -39,49 +46,55 @@ const (
)
//FactoryFromSearchQuery is deprecated because it's for migration purposes. use NewSearchQueryFactory
func FactoryFromSearchQuery(query *SearchQuery) *SearchQueryFactory {
func FactoryFromSearchQuery(q *SearchQuery) *SearchQueryFactory {
factory := &SearchQueryFactory{
columns: Columns_Event,
desc: query.Desc,
limit: query.Limit,
desc: q.Desc,
limit: q.Limit,
queries: make([]*query, len(q.Queries)),
}
for _, filter := range query.Filters {
switch filter.field {
case Field_AggregateType:
factory = factory.aggregateTypesMig(filter.value.([]AggregateType)...)
case Field_AggregateID:
if aggregateID, ok := filter.value.(string); ok {
factory = factory.AggregateIDs(aggregateID)
} else if aggregateIDs, ok := filter.value.([]string); ok {
factory = factory.AggregateIDs(aggregateIDs...)
for i, qq := range q.Queries {
factory.queries[i] = &query{factory: factory}
for _, filter := range qq.Filters {
switch filter.field {
case Field_AggregateType:
factory.queries[i] = factory.queries[i].aggregateTypesMig(filter.value.([]AggregateType)...)
case Field_AggregateID:
if aggregateID, ok := filter.value.(string); ok {
factory.queries[i] = factory.queries[i].AggregateIDs(aggregateID)
} else if aggregateIDs, ok := filter.value.([]string); ok {
factory.queries[i] = factory.queries[i].AggregateIDs(aggregateIDs...)
}
case Field_LatestSequence:
if filter.operation == Operation_Greater {
factory.queries[i] = factory.queries[i].SequenceGreater(filter.value.(uint64))
} else {
factory.queries[i] = factory.queries[i].SequenceLess(filter.value.(uint64))
}
case Field_ResourceOwner:
factory.queries[i] = factory.queries[i].ResourceOwner(filter.value.(string))
case Field_InstanceID:
if filter.operation == Operation_Equals {
factory.queries[i] = factory.queries[i].InstanceID(filter.value.(string))
} else if filter.operation == Operation_NotIn {
factory.queries[i] = factory.queries[i].IgnoredInstanceIDs(filter.value.([]string)...)
}
case Field_EventType:
factory.queries[i] = factory.queries[i].EventTypes(filter.value.([]EventType)...)
case Field_EditorService, Field_EditorUser:
logging.WithFields("value", filter.value).Panic("field not converted to factory")
case Field_CreationDate:
factory.queries[i] = factory.queries[i].CreationDateNewer(filter.value.(time.Time))
}
case Field_LatestSequence:
if filter.operation == Operation_Greater {
factory = factory.SequenceGreater(filter.value.(uint64))
} else {
factory = factory.SequenceLess(filter.value.(uint64))
}
case Field_ResourceOwner:
factory = factory.ResourceOwner(filter.value.(string))
case Field_InstanceID:
factory = factory.InstanceID(filter.value.(string))
case Field_EventType:
factory = factory.EventTypes(filter.value.([]EventType)...)
case Field_EditorService, Field_EditorUser:
logging.Log("MODEL-Mr0VN").WithField("value", filter.value).Panic("field not converted to factory")
case Field_CreationDate:
factory = factory.CreationDateNewer(filter.value.(time.Time))
}
}
return factory
}
func NewSearchQueryFactory(aggregateTypes ...AggregateType) *SearchQueryFactory {
return &SearchQueryFactory{
aggregateTypes: aggregateTypes,
}
func NewSearchQueryFactory() *SearchQueryFactory {
return &SearchQueryFactory{}
}
func (factory *SearchQueryFactory) Columns(columns Columns) *SearchQueryFactory {
@@ -94,46 +107,6 @@ func (factory *SearchQueryFactory) Limit(limit uint64) *SearchQueryFactory {
return factory
}
func (factory *SearchQueryFactory) SequenceGreater(sequence uint64) *SearchQueryFactory {
factory.sequenceFrom = sequence
return factory
}
func (factory *SearchQueryFactory) SequenceLess(sequence uint64) *SearchQueryFactory {
factory.sequenceTo = sequence
return factory
}
func (factory *SearchQueryFactory) AggregateIDs(ids ...string) *SearchQueryFactory {
factory.aggregateIDs = ids
return factory
}
func (factory *SearchQueryFactory) aggregateTypesMig(types ...AggregateType) *SearchQueryFactory {
factory.aggregateTypes = types
return factory
}
func (factory *SearchQueryFactory) EventTypes(types ...EventType) *SearchQueryFactory {
factory.eventTypes = types
return factory
}
func (factory *SearchQueryFactory) ResourceOwner(resourceOwner string) *SearchQueryFactory {
factory.resourceOwner = resourceOwner
return factory
}
func (factory *SearchQueryFactory) InstanceID(instanceID string) *SearchQueryFactory {
factory.instanceID = instanceID
return factory
}
func (factory *SearchQueryFactory) CreationDateNewer(time time.Time) *SearchQueryFactory {
factory.creationDate = time
return factory
}
func (factory *SearchQueryFactory) OrderDesc() *SearchQueryFactory {
factory.desc = true
return factory
@@ -144,27 +117,89 @@ func (factory *SearchQueryFactory) OrderAsc() *SearchQueryFactory {
return factory
}
func (factory *SearchQueryFactory) AddQuery() *query {
q := &query{factory: factory}
factory.queries = append(factory.queries, q)
return q
}
func (q *query) Factory() *SearchQueryFactory {
return q.factory
}
func (q *query) SequenceGreater(sequence uint64) *query {
q.sequenceFrom = sequence
return q
}
func (q *query) SequenceLess(sequence uint64) *query {
q.sequenceTo = sequence
return q
}
func (q *query) AggregateTypes(types ...AggregateType) *query {
q.aggregateTypes = types
return q
}
func (q *query) AggregateIDs(ids ...string) *query {
q.aggregateIDs = ids
return q
}
func (q *query) aggregateTypesMig(types ...AggregateType) *query {
q.aggregateTypes = types
return q
}
func (q *query) EventTypes(types ...EventType) *query {
q.eventTypes = types
return q
}
func (q *query) ResourceOwner(resourceOwner string) *query {
q.resourceOwner = resourceOwner
return q
}
func (q *query) InstanceID(instanceID string) *query {
q.instanceID = instanceID
return q
}
func (q *query) IgnoredInstanceIDs(instanceIDs ...string) *query {
q.ignoredInstanceIDs = instanceIDs
return q
}
func (q *query) CreationDateNewer(time time.Time) *query {
q.creationDate = time
return q
}
func (factory *SearchQueryFactory) Build() (*searchQuery, error) {
if factory == nil ||
len(factory.aggregateTypes) < 1 ||
len(factory.queries) < 1 ||
(factory.columns < 0 || factory.columns >= columnsCount) {
return nil, errors.ThrowPreconditionFailed(nil, "MODEL-tGAD3", "factory invalid")
}
filters := []*Filter{
factory.aggregateTypeFilter(),
}
filters := make([][]*Filter, len(factory.queries))
for _, f := range []func() *Filter{
factory.aggregateIDFilter,
factory.sequenceFromFilter,
factory.sequenceToFilter,
factory.eventTypeFilter,
factory.resourceOwnerFilter,
factory.instanceIDFilter,
factory.creationDateNewerFilter,
} {
if filter := f(); filter != nil {
filters = append(filters, filter)
for i, query := range factory.queries {
for _, f := range []func() *Filter{
query.aggregateTypeFilter,
query.aggregateIDFilter,
query.sequenceFromFilter,
query.sequenceToFilter,
query.eventTypeFilter,
query.resourceOwnerFilter,
query.instanceIDFilter,
query.ignoredInstanceIDsFilter,
query.creationDateNewerFilter,
} {
if filter := f(); filter != nil {
filters[i] = append(filters[i], filter)
}
}
}
@@ -176,72 +211,79 @@ func (factory *SearchQueryFactory) Build() (*searchQuery, error) {
}, nil
}
func (factory *SearchQueryFactory) aggregateIDFilter() *Filter {
if len(factory.aggregateIDs) < 1 {
func (q *query) aggregateIDFilter() *Filter {
if len(q.aggregateIDs) < 1 {
return nil
}
if len(factory.aggregateIDs) == 1 {
return NewFilter(Field_AggregateID, factory.aggregateIDs[0], Operation_Equals)
if len(q.aggregateIDs) == 1 {
return NewFilter(Field_AggregateID, q.aggregateIDs[0], Operation_Equals)
}
return NewFilter(Field_AggregateID, factory.aggregateIDs, Operation_In)
return NewFilter(Field_AggregateID, q.aggregateIDs, Operation_In)
}
func (factory *SearchQueryFactory) eventTypeFilter() *Filter {
if len(factory.eventTypes) < 1 {
func (q *query) eventTypeFilter() *Filter {
if len(q.eventTypes) < 1 {
return nil
}
if len(factory.eventTypes) == 1 {
return NewFilter(Field_EventType, factory.eventTypes[0], Operation_Equals)
if len(q.eventTypes) == 1 {
return NewFilter(Field_EventType, q.eventTypes[0], Operation_Equals)
}
return NewFilter(Field_EventType, factory.eventTypes, Operation_In)
return NewFilter(Field_EventType, q.eventTypes, Operation_In)
}
func (factory *SearchQueryFactory) aggregateTypeFilter() *Filter {
if len(factory.aggregateTypes) == 1 {
return NewFilter(Field_AggregateType, factory.aggregateTypes[0], Operation_Equals)
func (q *query) aggregateTypeFilter() *Filter {
if len(q.aggregateTypes) == 1 {
return NewFilter(Field_AggregateType, q.aggregateTypes[0], Operation_Equals)
}
return NewFilter(Field_AggregateType, factory.aggregateTypes, Operation_In)
return NewFilter(Field_AggregateType, q.aggregateTypes, Operation_In)
}
func (factory *SearchQueryFactory) sequenceFromFilter() *Filter {
if factory.sequenceFrom == 0 {
func (q *query) sequenceFromFilter() *Filter {
if q.sequenceFrom == 0 {
return nil
}
sortOrder := Operation_Greater
if factory.desc {
if q.factory.desc {
sortOrder = Operation_Less
}
return NewFilter(Field_LatestSequence, factory.sequenceFrom, sortOrder)
return NewFilter(Field_LatestSequence, q.sequenceFrom, sortOrder)
}
func (factory *SearchQueryFactory) sequenceToFilter() *Filter {
if factory.sequenceTo == 0 {
func (q *query) sequenceToFilter() *Filter {
if q.sequenceTo == 0 {
return nil
}
sortOrder := Operation_Less
if factory.desc {
if q.factory.desc {
sortOrder = Operation_Greater
}
return NewFilter(Field_LatestSequence, factory.sequenceTo, sortOrder)
return NewFilter(Field_LatestSequence, q.sequenceTo, sortOrder)
}
func (factory *SearchQueryFactory) resourceOwnerFilter() *Filter {
if factory.resourceOwner == "" {
func (q *query) resourceOwnerFilter() *Filter {
if q.resourceOwner == "" {
return nil
}
return NewFilter(Field_ResourceOwner, factory.resourceOwner, Operation_Equals)
return NewFilter(Field_ResourceOwner, q.resourceOwner, Operation_Equals)
}
func (factory *SearchQueryFactory) instanceIDFilter() *Filter {
if factory.instanceID == "" {
func (q *query) instanceIDFilter() *Filter {
if q.instanceID == "" {
return nil
}
return NewFilter(Field_InstanceID, factory.instanceID, Operation_Equals)
return NewFilter(Field_InstanceID, q.instanceID, Operation_Equals)
}
func (factory *SearchQueryFactory) creationDateNewerFilter() *Filter {
if factory.creationDate.IsZero() {
func (q *query) ignoredInstanceIDsFilter() *Filter {
if len(q.ignoredInstanceIDs) == 0 {
return nil
}
return NewFilter(Field_CreationDate, factory.creationDate, Operation_Greater)
return NewFilter(Field_InstanceID, q.ignoredInstanceIDs, Operation_NotIn)
}
func (q *query) creationDateNewerFilter() *Filter {
if q.creationDate.IsZero() {
return nil
}
return NewFilter(Field_CreationDate, q.creationDate, Operation_Greater)
}
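
Taken together, a hedged end-to-end sketch of the reworked builder: factory-level settings (columns, limit, ordering) stay on the SearchQueryFactory, while every AddQuery call opens a sub-query whose filters end up in their own group. The import alias and all values are illustrative.

package main

import (
	"fmt"

	models "github.com/caos/zitadel/internal/eventstore/v1/models"
)

func main() {
	factory := models.NewSearchQueryFactory().
		Columns(models.Columns_Event).
		Limit(100).
		OrderDesc().
		AddQuery().
		AggregateTypes("user", "org").
		AggregateIDs("1234").
		InstanceID("instance-1").
		SequenceGreater(42).
		Factory()

	// Build returns one []*Filter group per AddQuery call; the SQL layer joins
	// the groups with OR and the filters inside each group with AND.
	query, err := factory.Build()
	fmt.Println(query, err)
}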

View File

@@ -11,15 +11,46 @@ type SearchQuery struct {
Limit uint64
Desc bool
Filters []*Filter
Queries []*Query
}
type Query struct {
searchQuery *SearchQuery
Filters []*Filter
}
//NewSearchQuery is deprecated. Use SearchQueryFactory
func NewSearchQuery() *SearchQuery {
return &SearchQuery{
Filters: make([]*Filter, 0, 4),
Queries: make([]*Query, 0),
}
}
func (q *SearchQuery) AddQuery() *Query {
query := &Query{
searchQuery: q,
}
q.Queries = append(q.Queries, query)
return query
}
//SearchQuery returns the SearchQuery of the sub query
func (q *Query) SearchQuery() *SearchQuery {
return q.searchQuery
}
func (q *Query) setFilter(filter *Filter) *Query {
for i, f := range q.Filters {
if f.field == filter.field && f.field != Field_LatestSequence {
q.Filters[i] = filter
return q
}
}
q.Filters = append(q.Filters, filter)
return q
}
func (q *SearchQuery) SetLimit(limit uint64) *SearchQuery {
q.Limit = limit
return q
@@ -35,23 +66,23 @@ func (q *SearchQuery) OrderAsc() *SearchQuery {
return q
}
func (q *SearchQuery) AggregateIDFilter(id string) *SearchQuery {
func (q *Query) AggregateIDFilter(id string) *Query {
return q.setFilter(NewFilter(Field_AggregateID, id, Operation_Equals))
}
func (q *SearchQuery) AggregateIDsFilter(ids ...string) *SearchQuery {
func (q *Query) AggregateIDsFilter(ids ...string) *Query {
return q.setFilter(NewFilter(Field_AggregateID, ids, Operation_In))
}
func (q *SearchQuery) AggregateTypeFilter(types ...AggregateType) *SearchQuery {
func (q *Query) AggregateTypeFilter(types ...AggregateType) *Query {
return q.setFilter(NewFilter(Field_AggregateType, types, Operation_In))
}
func (q *SearchQuery) EventTypesFilter(types ...EventType) *SearchQuery {
func (q *Query) EventTypesFilter(types ...EventType) *Query {
return q.setFilter(NewFilter(Field_EventType, types, Operation_In))
}
func (q *SearchQuery) LatestSequenceFilter(sequence uint64) *SearchQuery {
func (q *Query) LatestSequenceFilter(sequence uint64) *Query {
if sequence == 0 {
return q
}
@@ -59,21 +90,25 @@ func (q *SearchQuery) LatestSequenceFilter(sequence uint64) *SearchQuery {
return q.setFilter(NewFilter(Field_LatestSequence, sequence, sortOrder))
}
func (q *SearchQuery) SequenceBetween(from, to uint64) *SearchQuery {
func (q *Query) SequenceBetween(from, to uint64) *Query {
q.setFilter(NewFilter(Field_LatestSequence, from, Operation_Greater))
q.setFilter(NewFilter(Field_LatestSequence, to, Operation_Less))
return q
}
func (q *SearchQuery) ResourceOwnerFilter(resourceOwner string) *SearchQuery {
func (q *Query) ResourceOwnerFilter(resourceOwner string) *Query {
return q.setFilter(NewFilter(Field_ResourceOwner, resourceOwner, Operation_Equals))
}
func (q *SearchQuery) InstanceIDFilter(instanceID string) *SearchQuery {
func (q *Query) InstanceIDFilter(instanceID string) *Query {
return q.setFilter(NewFilter(Field_InstanceID, instanceID, Operation_Equals))
}
func (q *SearchQuery) CreationDateNewerFilter(time time.Time) *SearchQuery {
func (q *Query) ExcludedInstanceIDsFilter(instanceIDs ...string) *Query {
return q.setFilter(NewFilter(Field_InstanceID, instanceIDs, Operation_NotIn))
}
func (q *Query) CreationDateNewerFilter(time time.Time) *Query {
return q.setFilter(NewFilter(Field_CreationDate, time, Operation_Greater))
}
@@ -92,12 +127,14 @@ func (q *SearchQuery) Validate() error {
if q == nil {
return errors.ThrowPreconditionFailed(nil, "MODEL-J5xQi", "search query is nil")
}
if len(q.Filters) == 0 {
if len(q.Queries) == 0 {
return errors.ThrowPreconditionFailed(nil, "MODEL-pF3DR", "no filters set")
}
for _, filter := range q.Filters {
if err := filter.Validate(); err != nil {
return err
for _, query := range q.Queries {
for _, filter := range query.Filters {
if err := filter.Validate(); err != nil {
return err
}
}
}
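
A hedged sketch of the deprecated SearchQuery with its new sub-query API, including the added ExcludedInstanceIDsFilter. The es_models alias and the values are illustrative.

searchQuery := es_models.NewSearchQuery().
	SetLimit(100).
	AddQuery().
	AggregateTypeFilter("user").
	ExcludedInstanceIDsFilter("instance-to-skip").
	SearchQuery()

if err := searchQuery.Validate(); err != nil {
	// Validate now requires at least one sub-query whose filters are valid.
}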

View File

@@ -21,31 +21,48 @@ func testSetLimit(limit uint64) func(factory *SearchQueryFactory) *SearchQueryFa
}
}
func testSetSequence(sequence uint64) func(factory *SearchQueryFactory) *SearchQueryFactory {
return func(factory *SearchQueryFactory) *SearchQueryFactory {
factory = factory.SequenceGreater(sequence)
return factory
func testAddQuery(queryFuncs ...func(*query) *query) func(*SearchQueryFactory) *SearchQueryFactory {
return func(builder *SearchQueryFactory) *SearchQueryFactory {
query := builder.AddQuery()
for _, queryFunc := range queryFuncs {
queryFunc(query)
}
return query.Factory()
}
}
func testSetAggregateIDs(aggregateIDs ...string) func(factory *SearchQueryFactory) *SearchQueryFactory {
return func(factory *SearchQueryFactory) *SearchQueryFactory {
factory = factory.AggregateIDs(aggregateIDs...)
return factory
func testSetSequence(sequence uint64) func(*query) *query {
return func(q *query) *query {
q.SequenceGreater(sequence)
return q
}
}
func testSetEventTypes(eventTypes ...EventType) func(factory *SearchQueryFactory) *SearchQueryFactory {
return func(factory *SearchQueryFactory) *SearchQueryFactory {
factory = factory.EventTypes(eventTypes...)
return factory
func testSetAggregateIDs(aggregateIDs ...string) func(*query) *query {
return func(q *query) *query {
q.AggregateIDs(aggregateIDs...)
return q
}
}
func testSetResourceOwner(resourceOwner string) func(factory *SearchQueryFactory) *SearchQueryFactory {
return func(factory *SearchQueryFactory) *SearchQueryFactory {
factory = factory.ResourceOwner(resourceOwner)
return factory
func testSetAggregateTypes(aggregateTypes ...AggregateType) func(*query) *query {
return func(q *query) *query {
q.AggregateTypes(aggregateTypes...)
return q
}
}
func testSetEventTypes(eventTypes ...EventType) func(*query) *query {
return func(q *query) *query {
q.EventTypes(eventTypes...)
return q
}
}
func testSetResourceOwner(resourceOwner string) func(*query) *query {
return func(q *query) *query {
q.ResourceOwner(resourceOwner)
return q
}
}
@@ -60,10 +77,50 @@ func testSetSortOrder(asc bool) func(factory *SearchQueryFactory) *SearchQueryFa
}
}
func assertFactory(t *testing.T, want, got *SearchQueryFactory) {
t.Helper()
if got.columns != want.columns {
t.Errorf("wrong column: got: %v want: %v", got.columns, want.columns)
}
if got.desc != want.desc {
t.Errorf("wrong desc: got: %v want: %v", got.desc, want.desc)
}
if got.limit != want.limit {
t.Errorf("wrong limit: got: %v want: %v", got.limit, want.limit)
}
if len(got.queries) != len(want.queries) {
t.Errorf("wrong length of queries: got: %v want: %v", len(got.queries), len(want.queries))
}
for i, query := range got.queries {
assertQuery(t, i, want.queries[i], query)
}
}
func assertQuery(t *testing.T, i int, want, got *query) {
t.Helper()
if !reflect.DeepEqual(got.aggregateIDs, want.aggregateIDs) {
t.Errorf("wrong aggregateIDs in query %d : got: %v want: %v", i, got.aggregateIDs, want.aggregateIDs)
}
if !reflect.DeepEqual(got.aggregateTypes, want.aggregateTypes) {
t.Errorf("wrong aggregateTypes in query %d : got: %v want: %v", i, got.aggregateTypes, want.aggregateTypes)
}
if got.sequenceFrom != want.sequenceFrom {
t.Errorf("wrong sequenceFrom in query %d : got: %v want: %v", i, got.sequenceFrom, want.sequenceFrom)
}
if got.sequenceTo != want.sequenceTo {
t.Errorf("wrong sequenceTo in query %d : got: %v want: %v", i, got.sequenceTo, want.sequenceTo)
}
if !reflect.DeepEqual(got.eventTypes, want.eventTypes) {
t.Errorf("wrong eventTypes in query %d : got: %v want: %v", i, got.eventTypes, want.eventTypes)
}
}
func TestSearchQueryFactorySetters(t *testing.T) {
type args struct {
aggregateTypes []AggregateType
setters []func(*SearchQueryFactory) *SearchQueryFactory
setters []func(*SearchQueryFactory) *SearchQueryFactory
}
tests := []struct {
name string
@@ -73,11 +130,9 @@ func TestSearchQueryFactorySetters(t *testing.T) {
{
name: "New factory",
args: args{
aggregateTypes: []AggregateType{"user", "org"},
},
res: &SearchQueryFactory{
aggregateTypes: []AggregateType{"user", "org"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
},
res: &SearchQueryFactory{},
},
{
name: "set columns",
@@ -100,69 +155,98 @@ func TestSearchQueryFactorySetters(t *testing.T) {
{
name: "set sequence",
args: args{
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetSequence(90)},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetSequence(90))},
},
res: &SearchQueryFactory{
sequenceFrom: 90,
queries: []*query{
{
sequenceFrom: 90,
},
},
},
},
{
name: "set aggregateTypes",
args: args{
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateTypes("user", "org"))},
},
res: &SearchQueryFactory{
queries: []*query{
{
aggregateTypes: []AggregateType{"user", "org"},
},
},
},
},
{
name: "set aggregateIDs",
args: args{
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetAggregateIDs("1235", "09824")},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateIDs("1235", "09824"))},
},
res: &SearchQueryFactory{
aggregateIDs: []string{"1235", "09824"},
queries: []*query{
{
aggregateIDs: []string{"1235", "09824"},
},
},
},
},
{
name: "set eventTypes",
args: args{
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetEventTypes("user.created", "user.updated")},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetEventTypes("user.created", "user.updated"))},
},
res: &SearchQueryFactory{
eventTypes: []EventType{"user.created", "user.updated"},
queries: []*query{
{
eventTypes: []EventType{"user.created", "user.updated"},
},
},
},
},
{
name: "set resource owner",
args: args{
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetResourceOwner("hodor")},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetResourceOwner("hodor"))},
},
res: &SearchQueryFactory{
resourceOwner: "hodor",
queries: []*query{
{
resourceOwner: "hodor",
},
},
},
},
{
name: "default search query",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetAggregateIDs("1235", "024"), testSetSortOrder(false)},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateTypes("user"), testSetAggregateIDs("1235", "024")), testSetSortOrder(false)},
},
res: &SearchQueryFactory{
aggregateTypes: []AggregateType{"user"},
aggregateIDs: []string{"1235", "024"},
desc: true,
desc: true,
queries: []*query{
{
aggregateTypes: []AggregateType{"user"},
aggregateIDs: []string{"1235", "024"},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
factory := NewSearchQueryFactory(tt.args.aggregateTypes...)
factory := NewSearchQueryFactory()
for _, setter := range tt.args.setters {
factory = setter(factory)
}
if !reflect.DeepEqual(factory, tt.res) {
t.Errorf("NewSearchQueryFactory() = %v, want %v", factory, tt.res)
}
assertFactory(t, tt.res, factory)
})
}
}
func TestSearchQueryFactoryBuild(t *testing.T) {
type args struct {
aggregateTypes []AggregateType
setters []func(*SearchQueryFactory) *SearchQueryFactory
setters []func(*SearchQueryFactory) *SearchQueryFactory
}
type res struct {
isErr func(err error) bool
@@ -176,8 +260,7 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "no aggregate types",
args: args{
aggregateTypes: []AggregateType{},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
},
res: res{
isErr: errors.IsPreconditionFailed,
@@ -187,9 +270,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "invalid column (too low)",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetColumns(Columns(-1)),
testAddQuery(testSetAggregateTypes("user")),
},
},
res: res{
@@ -199,9 +282,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "invalid column (too high)",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetColumns(columnsCount),
testAddQuery(testSetAggregateTypes("user")),
},
},
res: res{
@@ -211,8 +294,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testAddQuery(testSetAggregateTypes("user")),
},
},
res: res{
isErr: nil,
@@ -220,8 +304,10 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
},
},
},
},
@@ -229,8 +315,9 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate types",
args: args{
aggregateTypes: []AggregateType{"user", "org"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testAddQuery(testSetAggregateTypes("user", "org")),
},
},
res: res{
isErr: nil,
@@ -238,8 +325,10 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, []AggregateType{"user", "org"}, Operation_In),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, []AggregateType{"user", "org"}, Operation_In),
},
},
},
},
@@ -247,11 +336,13 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type, limit, desc",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetLimit(5),
testSetSortOrder(false),
testSetSequence(100),
testAddQuery(
testSetAggregateTypes("user"),
testSetSequence(100),
),
},
},
res: res{
@@ -260,9 +351,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: true,
Limit: 5,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
},
},
},
},
@@ -270,11 +363,13 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type, limit, asc",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetLimit(5),
testSetSortOrder(true),
testSetSequence(100),
testAddQuery(
testSetSequence(100),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -283,9 +378,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 5,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(100), Operation_Greater),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(100), Operation_Greater),
},
},
},
},
@@ -293,12 +390,14 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type, limit, desc, max event sequence cols",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetLimit(5),
testSetSortOrder(false),
testSetSequence(100),
testSetColumns(Columns_Max_Sequence),
testAddQuery(
testSetSequence(100),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -307,9 +406,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: Columns_Max_Sequence,
Desc: true,
Limit: 5,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
},
},
},
},
@@ -317,9 +418,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and aggregate id",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetAggregateIDs("1234"),
testAddQuery(
testSetAggregateIDs("1234"),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -328,9 +431,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_AggregateID, "1234", Operation_Equals),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_AggregateID, "1234", Operation_Equals),
},
},
},
},
@@ -338,9 +443,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and aggregate ids",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetAggregateIDs("1234", "0815"),
testAddQuery(
testSetAggregateIDs("1234", "0815"),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -349,9 +456,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In),
},
},
},
},
@@ -359,9 +468,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and sequence greater",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetSequence(8),
testAddQuery(
testSetSequence(8),
testSetAggregateTypes("user"),
),
},
},
res: res{
@@ -370,9 +481,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(8), Operation_Greater),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_LatestSequence, uint64(8), Operation_Greater),
},
},
},
},
@@ -380,9 +493,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and event type",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetEventTypes("user.created"),
testAddQuery(
testSetAggregateTypes("user"),
testSetEventTypes("user.created"),
),
},
},
res: res{
@@ -391,9 +506,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_EventType, EventType("user.created"), Operation_Equals),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_EventType, EventType("user.created"), Operation_Equals),
},
},
},
},
@@ -401,9 +518,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type and event types",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetEventTypes("user.created", "user.changed"),
testAddQuery(
testSetAggregateTypes("user"),
testSetEventTypes("user.created", "user.changed"),
),
},
},
res: res{
@@ -412,9 +531,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_EventType, []EventType{"user.created", "user.changed"}, Operation_In),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_EventType, []EventType{"user.created", "user.changed"}, Operation_In),
},
},
},
},
@@ -422,9 +543,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
{
name: "filter aggregate type resource owner",
args: args{
aggregateTypes: []AggregateType{"user"},
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
testSetResourceOwner("hodor"),
testAddQuery(
testSetAggregateTypes("user"),
testSetResourceOwner("hodor"),
),
},
},
res: res{
@@ -433,9 +556,11 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
Columns: 0,
Desc: false,
Limit: 0,
Filters: []*Filter{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_ResourceOwner, "hodor", Operation_Equals),
Filters: [][]*Filter{
{
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
NewFilter(Field_ResourceOwner, "hodor", Operation_Equals),
},
},
},
},
@@ -443,7 +568,7 @@ func TestSearchQueryFactoryBuild(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
factory := NewSearchQueryFactory(tt.args.aggregateTypes...)
factory := NewSearchQueryFactory()
for _, f := range tt.args.setters {
factory = f(factory)
}
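Note on the reworked expectations above: Filters changes shape from a flat []*Filter to [][]*Filter, with one inner slice per sub-query added through testAddQuery. As a rough sketch built only from constructors that already appear in this test file, the combined aggregate-type/aggregate-ids expectation now reads as a single group; a second testAddQuery call would presumably contribute a second inner slice as an independent sub-query:

// illustrative only: one inner slice per sub-query; filters inside a group
// belong together, separate groups describe separate sub-queries
expectedFilters := [][]*Filter{
    {
        NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
        NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In),
    },
}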

View File

@@ -26,7 +26,7 @@ type Handler interface {
QueryLimit() uint64
AggregateTypes() []models.AggregateType
CurrentSequence() (uint64, error)
CurrentSequence(instanceID string) (uint64, error)
Eventstore() v1.Eventstore
Subscription() *v1.Subscription
@@ -41,15 +41,18 @@ func ReduceEvent(handler Handler, event *models.Event) {
handler.Subscription().Unsubscribe()
}
}()
currentSequence, err := handler.CurrentSequence()
currentSequence, err := handler.CurrentSequence(event.InstanceID)
if err != nil {
logging.New().WithError(err).Warn("unable to get current sequence")
return
}
searchQuery := models.NewSearchQuery().
AddQuery().
AggregateTypeFilter(handler.AggregateTypes()...).
SequenceBetween(currentSequence, event.Sequence).
InstanceIDFilter(event.InstanceID).
SearchQuery().
SetLimit(eventLimit)
unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery)
@@ -59,7 +62,7 @@ func ReduceEvent(handler Handler, event *models.Event) {
}
for _, unprocessedEvent := range unprocessedEvents {
currentSequence, err := handler.CurrentSequence()
currentSequence, err := handler.CurrentSequence(unprocessedEvent.InstanceID)
if err != nil {
logging.Log("HANDL-BmpkC").WithError(err).Warn("unable to get current sequence")
return
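For handler implementations, the interface change above means CurrentSequence now receives the instance it is asked about. A minimal sketch, assuming a hypothetical handler that keeps its processed sequences in an in-memory map (real handlers read them from their view storage, keyed by instance):

package query_test // hypothetical file, for illustration only

import "sync"

// demoHandler is hypothetical and exists only to illustrate the new signature
type demoHandler struct {
    mu        sync.Mutex
    sequences map[string]uint64 // instanceID -> last processed sequence
}

func (h *demoHandler) CurrentSequence(instanceID string) (uint64, error) {
    h.mu.Lock()
    defer h.mu.Unlock()
    return h.sequences[instanceID], nil
}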

View File

@@ -5,44 +5,45 @@
package mock
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockLocker is a mock of Locker interface
// MockLocker is a mock of Locker interface.
type MockLocker struct {
ctrl *gomock.Controller
recorder *MockLockerMockRecorder
}
// MockLockerMockRecorder is the mock recorder for MockLocker
// MockLockerMockRecorder is the mock recorder for MockLocker.
type MockLockerMockRecorder struct {
mock *MockLocker
}
// NewMockLocker creates a new mock instance
// NewMockLocker creates a new mock instance.
func NewMockLocker(ctrl *gomock.Controller) *MockLocker {
mock := &MockLocker{ctrl: ctrl}
mock.recorder = &MockLockerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLocker) EXPECT() *MockLockerMockRecorder {
return m.recorder
}
// Renew mocks base method
func (m *MockLocker) Renew(lockerID, viewModel string, waitTime time.Duration) error {
// Renew mocks base method.
func (m *MockLocker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Renew", lockerID, viewModel, waitTime)
ret := m.ctrl.Call(m, "Renew", lockerID, viewModel, instanceID, waitTime)
ret0, _ := ret[0].(error)
return ret0
}
// Renew indicates an expected call of Renew
func (mr *MockLockerMockRecorder) Renew(lockerID, viewModel, waitTime interface{}) *gomock.Call {
// Renew indicates an expected call of Renew.
func (mr *MockLockerMockRecorder) Renew(lockerID, viewModel, instanceID, waitTime interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Renew", reflect.TypeOf((*MockLocker)(nil).Renew), lockerID, viewModel, waitTime)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Renew", reflect.TypeOf((*MockLocker)(nil).Renew), lockerID, viewModel, instanceID, waitTime)
}
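Tests that set expectations on the regenerated MockLocker must pass a fourth matcher for the new instanceID argument. A hedged usage sketch (the test name, view name and instance value are placeholders, not taken from this commit; it assumes the gomock, mock, testing and time imports already used by the spooler tests below):

func TestRenewExpectation(t *testing.T) { // hypothetical test
    ctrl := gomock.NewController(t)
    defer ctrl.Finish()

    locker := mock.NewMockLocker(ctrl)
    // the third matcher covers the newly added instanceID parameter
    locker.EXPECT().
        Renew(gomock.Any(), "my.view", gomock.Any(), gomock.Any()).
        Return(nil)

    _ = locker.Renew("worker-1", "my.view", "instance-1", time.Second)
}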

View File

@@ -16,6 +16,8 @@ import (
"github.com/caos/zitadel/internal/view/repository"
)
const systemID = "system"
type Spooler struct {
handlers []query.Handler
locker Locker
@@ -26,7 +28,7 @@ type Spooler struct {
}
type Locker interface {
Renew(lockerID, viewModel string, waitTime time.Duration) error
Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error
}
type spooledHandler struct {
@@ -138,19 +140,6 @@ func (s *spooledHandler) query(ctx context.Context) ([]*models.Event, error) {
if err != nil {
return nil, err
}
factory := models.FactoryFromSearchQuery(query)
sequence, err := s.eventstore.LatestSequence(ctx, factory)
logging.OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Debug("unable to query latest sequence")
var processedSequence uint64
for _, filter := range query.Filters {
if filter.GetField() == models.Field_LatestSequence {
processedSequence = filter.GetValue().(uint64)
}
}
if sequence != 0 && processedSequence == sequence {
return nil, nil
}
query.Limit = s.QueryLimit()
return s.eventstore.FilterEvents(ctx, query)
}
@@ -169,7 +158,7 @@ func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID s
case <-ctx.Done():
return
case <-renewTimer:
err := s.locker.Renew(workerID, s.ViewModel(), s.LockDuration())
err := s.locker.Renew(workerID, s.ViewModel(), systemID, s.LockDuration())
firstLock.Do(func() {
locked <- err == nil
})
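Because the spooler itself is not bound to a single instance, it renews its lock with the systemID constant; concrete Locker implementations simply gain the extra parameter. A hypothetical in-memory locker, shown only to illustrate the extended signature (the real implementation presumably keys its lock rows by view and instance):

package spooler_test // hypothetical, for illustration only

import (
    "sync"
    "time"
)

// memLocker is a hypothetical illustration, not part of this commit
type memLocker struct {
    mu    sync.Mutex
    locks map[string]time.Time // key: viewModel + "/" + instanceID
}

func (l *memLocker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error {
    l.mu.Lock()
    defer l.mu.Unlock()
    l.locks[viewModel+"/"+instanceID] = time.Now().Add(waitTime)
    return nil
}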
@@ -190,16 +179,17 @@ func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID s
}
func HandleError(event *models.Event, failedErr error,
latestFailedEvent func(sequence uint64) (*repository.FailedEvent, error),
latestFailedEvent func(sequence uint64, instanceID string) (*repository.FailedEvent, error),
processFailedEvent func(*repository.FailedEvent) error,
processSequence func(*models.Event) error,
errorCountUntilSkip uint64) error {
failedEvent, err := latestFailedEvent(event.Sequence)
failedEvent, err := latestFailedEvent(event.Sequence, event.InstanceID)
if err != nil {
return err
}
failedEvent.FailureCount++
failedEvent.ErrMsg = failedErr.Error()
failedEvent.InstanceID = event.InstanceID
err = processFailedEvent(failedEvent)
if err != nil {
return err
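Call sites of HandleError only need to adapt the latestFailedEvent callback to the extended signature. A sketch of the wiring, where latestFailedEventForInstance, processFailedEvent and processSequence are assumed helpers rather than names from this commit:

// sketch only: the helper names below are assumptions
err := HandleError(event, reduceErr,
    func(sequence uint64, instanceID string) (*repository.FailedEvent, error) {
        // look up the failed event for this view, instance and sequence
        return latestFailedEventForInstance("super.table", instanceID, sequence)
    },
    processFailedEvent, // func(*repository.FailedEvent) error
    processSequence,    // func(*models.Event) error
    5,                  // errorCountUntilSkip
)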

View File

@@ -3,17 +3,18 @@ package spooler
import (
"context"
"fmt"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/v1"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler/mock"
"github.com/caos/zitadel/internal/view/repository"
"github.com/golang/mock/gomock"
)
type testHandler struct {
@@ -30,7 +31,7 @@ func (h *testHandler) AggregateTypes() []models.AggregateType {
return nil
}
func (h *testHandler) CurrentSequence() (uint64, error) {
func (h *testHandler) CurrentSequence(instanceID string) (uint64, error) {
return 0, nil
}
@@ -376,8 +377,8 @@ func newTestLocker(t *testing.T, lockerID, viewName string) *testLocker {
func (l *testLocker) expectRenew(t *testing.T, err error, waitTime time.Duration) *testLocker {
t.Helper()
l.mock.EXPECT().Renew(gomock.Any(), l.viewName, gomock.Any()).DoAndReturn(
func(_, _ string, gotten time.Duration) error {
l.mock.EXPECT().Renew(gomock.Any(), l.viewName, gomock.Any(), gomock.Any()).DoAndReturn(
func(_, _, _ string, gotten time.Duration) error {
t.Helper()
if waitTime-gotten != 0 {
t.Errorf("expected waittime %v got %v", waitTime, gotten)
@@ -396,7 +397,7 @@ func TestHandleError(t *testing.T) {
type args struct {
event *models.Event
failedErr error
latestFailedEvent func(sequence uint64) (*repository.FailedEvent, error)
latestFailedEvent func(sequence uint64, instanceID string) (*repository.FailedEvent, error)
errorCountUntilSkip uint64
}
type res struct {
@@ -413,12 +414,13 @@ func TestHandleError(t *testing.T) {
args: args{
event: &models.Event{Sequence: 30000000},
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) {
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
return &repository.FailedEvent{
ErrMsg: "blub",
FailedSequence: s - 1,
FailureCount: 6,
ViewName: "super.table",
InstanceID: instanceID,
}, nil
},
errorCountUntilSkip: 5,
@@ -432,12 +434,13 @@ func TestHandleError(t *testing.T) {
args: args{
event: &models.Event{Sequence: 30000000},
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) {
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
return &repository.FailedEvent{
ErrMsg: "blub",
FailedSequence: s - 1,
FailureCount: 5,
ViewName: "super.table",
InstanceID: instanceID,
}, nil
},
errorCountUntilSkip: 6,
@@ -451,12 +454,13 @@ func TestHandleError(t *testing.T) {
args: args{
event: &models.Event{Sequence: 30000000},
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
latestFailedEvent: func(s uint64) (*repository.FailedEvent, error) {
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
return &repository.FailedEvent{
ErrMsg: "blub",
FailedSequence: s - 1,
FailureCount: 3,
ViewName: "super.table",
InstanceID: instanceID,
}, nil
},
errorCountUntilSkip: 5,