fix: scheduling (#3978)
* fix: improve scheduling
* build pre-release
* fix: locker
* fix: user handler and print stack in case of panic in reducer
* chore: remove sentry
* fix: improve handler projection and implement tests
* more tests
* fix: race condition in tests
* Update internal/eventstore/repository/sql/query.go
* fix: implemented suggested changes
* fix: lock statement

Co-authored-by: Silvan <silvan.reusser@gmail.com>
@@ -186,6 +186,15 @@ func (es *Eventstore) LatestSequence(ctx context.Context, queryFactory *SearchQu
 	return es.repo.LatestSequence(ctx, query)
 }
 
+//InstanceIDs returns the instance ids found by the search query
+func (es *Eventstore) InstanceIDs(ctx context.Context, queryFactory *SearchQueryBuilder) ([]string, error) {
+	query, err := queryFactory.build(authz.GetInstance(ctx).InstanceID())
+	if err != nil {
+		return nil, err
+	}
+	return es.repo.InstanceIDs(ctx, query)
+}
+
 type QueryReducer interface {
 	reducer
 	//Query returns the SearchQueryFactory for the events needed in reducer
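Note: the scheduler introduced later in this diff is the consumer of this new API; a condensed sketch of the call, lifted from the schedule() hunk further down (error handling elided):

	// list every instance that has events, excluding the empty instance ID
	ids, err := h.Eventstore.InstanceIDs(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).AddQuery().ExcludedInstanceID("").Builder())
	if err != nil {
		// log and retry on the next scheduler tick
	}
	_ = ids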
@@ -688,10 +688,11 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
 }
 
 type testRepo struct {
-	events   []*repository.Event
-	sequence uint64
-	err      error
-	t        *testing.T
+	events    []*repository.Event
+	sequence  uint64
+	instances []string
+	err       error
+	t         *testing.T
 }
 
 func (repo *testRepo) Health(ctx context.Context) error {
@@ -735,6 +736,13 @@ func (repo *testRepo) LatestSequence(ctx context.Context, queryFactory *reposito
 	return repo.sequence, nil
 }
 
+func (repo *testRepo) InstanceIDs(ctx context.Context, queryFactory *repository.SearchQuery) ([]string, error) {
+	if repo.err != nil {
+		return nil, repo.err
+	}
+	return repo.instances, nil
+}
+
 func TestEventstore_Push(t *testing.T) {
 	type args struct {
 		events []Command
@@ -6,12 +6,14 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/lib/pq"
+
 	"github.com/zitadel/zitadel/internal/errors"
 	"github.com/zitadel/zitadel/internal/eventstore"
 )
 
 const (
-	currentSequenceStmtFormat        = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 FOR UPDATE`
+	currentSequenceStmtFormat        = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
 	updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
 )
 
@@ -22,8 +24,8 @@ type instanceSequence struct {
 	sequence   uint64
 }
 
-func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
-	rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName)
+func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs []string) (currentSequences, error) {
+	rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, pq.StringArray(instanceIDs))
 	if err != nil {
 		return nil, err
 	}
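Note: the new `instance_id = ANY ($2)` predicate works because pq.StringArray implements driver.Valuer, so a Go []string binds as a single Postgres array parameter instead of one placeholder per ID. A minimal standalone sketch of the pattern (the package, function, and table names are illustrative; assumes a DB opened with the lib/pq driver):

	package example

	import (
		"database/sql"

		"github.com/lib/pq"
	)

	// sequencesForInstances locks and reads the current sequences of the given
	// instances with a single array-valued bind parameter.
	func sequencesForInstances(db *sql.DB, projection string, instanceIDs []string) (*sql.Rows, error) {
		return db.Query(
			`SELECT current_sequence, aggregate_type, instance_id FROM projections.current_sequences WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`,
			projection,
			pq.StringArray(instanceIDs),
		)
	}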
@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/DATA-DOG/go-sqlmock"
+	"github.com/lib/pq"
 
 	"github.com/zitadel/zitadel/internal/eventstore"
 )
@@ -123,34 +124,40 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) {
 	}
 }
 
-func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
+func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
+	rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"})
+	for _, instanceID := range instanceIDs {
+		rows.AddRow(seq, aggregateType, instanceID)
+	}
 	return func(m sqlmock.Sqlmock) {
-		m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
+		m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
 			WithArgs(
 				projection,
+				pq.StringArray(instanceIDs),
 			).
 			WillReturnRows(
-				sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
-					AddRow(seq, aggregateType, instanceID),
+				rows,
 			)
 	}
 }
 
-func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlmock.Sqlmock) {
+func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
 	return func(m sqlmock.Sqlmock) {
-		m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
+		m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
 			WithArgs(
 				projection,
+				pq.StringArray(instanceIDs),
 			).
 			WillReturnError(err)
 	}
 }
 
-func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlmock) {
+func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
 	return func(m sqlmock.Sqlmock) {
-		m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
+		m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
			WithArgs(
 				projection,
+				pq.StringArray(instanceIDs),
 			).
 			WillReturnRows(
 				sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
@@ -158,11 +165,12 @@ func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlm
 	}
 }
 
-func expectCurrentSequenceScanErr(tableName, projection string) func(sqlmock.Sqlmock) {
+func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
 	return func(m sqlmock.Sqlmock) {
-		m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
+		m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
 			WithArgs(
 				projection,
+				pq.StringArray(instanceIDs),
 			).
 			WillReturnRows(
 				sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
@@ -286,12 +294,34 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
 			` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
 			` ON CONFLICT \(projection_name, instance_id\)`+
 			` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
-			` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
+			` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
 			WithArgs(
 				workerName,
 				float64(d),
 				projectionName,
 				instanceID,
+				pq.StringArray{instanceID},
 			).
 			WillReturnResult(
 				sqlmock.NewResult(1, 1),
 			)
 	}
 }
 
+func expectLockMultipleInstances(lockTable, workerName string, d time.Duration, instanceID1, instanceID2 string) func(sqlmock.Sqlmock) {
+	return func(m sqlmock.Sqlmock) {
+		m.ExpectExec(`INSERT INTO `+lockTable+
+			` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\), \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$5\)`+
+			` ON CONFLICT \(projection_name, instance_id\)`+
+			` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
+			` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$6\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
+			WithArgs(
+				workerName,
+				float64(d),
+				projectionName,
+				instanceID1,
+				instanceID2,
+				pq.StringArray{instanceID1, instanceID2},
+			).
+			WillReturnResult(
+				sqlmock.NewResult(1, 1),
@@ -305,12 +335,13 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
 			` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
 			` ON CONFLICT \(projection_name, instance_id\)`+
 			` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
-			` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
+			` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
 			WithArgs(
 				workerName,
 				float64(d),
 				projectionName,
 				instanceID,
+				pq.StringArray{instanceID},
 			).
 			WillReturnResult(driver.ResultNoRows)
 	}
@@ -322,12 +353,13 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
 			` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
 			` ON CONFLICT \(projection_name, instance_id\)`+
 			` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
-			` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
+			` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
 			WithArgs(
 				workerName,
 				float64(d),
 				projectionName,
 				instanceID,
+				pq.StringArray{instanceID},
 			).
 			WillReturnError(err)
 	}
@@ -10,7 +10,6 @@ import (
 	"github.com/zitadel/zitadel/internal/errors"
 	"github.com/zitadel/zitadel/internal/eventstore"
 	"github.com/zitadel/zitadel/internal/eventstore/handler"
-	"github.com/zitadel/zitadel/internal/telemetry/tracing"
 )
 
 var (
@@ -75,84 +74,62 @@ func NewStatementHandler(
 		bulkLimit: config.BulkLimit,
 		Locker:    NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
 	}
-	h.ProjectionHandler = handler.NewProjectionHandler(config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery)
+	h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery, h.Lock, h.Unlock)
 
 	err := h.Init(ctx, config.InitCheck)
 	logging.OnError(err).Fatal("unable to initialize projections")
 
-	go h.Process(
-		ctx,
-		h.reduce,
-		h.Update,
-		h.Lock,
-		h.Unlock,
-		h.SearchQuery,
-	)
-
 	h.Subscribe(h.aggregates...)
 
 	return h
 }
 
-func (h *StatementHandler) TriggerBulk(ctx context.Context) {
-	ctx, span := tracing.NewSpan(ctx)
-	var err error
-	defer span.EndWithError(err)
-
-	err = h.ProjectionHandler.TriggerBulk(ctx, h.Lock, h.Unlock)
-	logging.OnError(err).WithField("projection", h.ProjectionName).Warn("unable to trigger bulk")
-}
-
-func (h *StatementHandler) SearchQuery(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) {
-	sequences, err := h.currentSequences(ctx, h.client.QueryContext)
+func (h *StatementHandler) SearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
+	sequences, err := h.currentSequences(ctx, h.client.QueryContext, instanceIDs)
 	if err != nil {
 		return nil, 0, err
 	}
 
 	queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
 
 	for _, aggregateType := range h.aggregates {
-		instances := make([]string, 0)
-		for _, sequence := range sequences[aggregateType] {
-			instances = appendToIgnoredInstances(instances, sequence.instanceID)
+		for _, instanceID := range instanceIDs {
+			var seq uint64
+			for _, sequence := range sequences[aggregateType] {
+				if sequence.instanceID == instanceID {
+					seq = sequence.sequence
+					break
+				}
+			}
 			queryBuilder.
 				AddQuery().
 				AggregateTypes(aggregateType).
-				SequenceGreater(sequence.sequence).
-				InstanceID(sequence.instanceID)
+				SequenceGreater(seq).
+				InstanceID(instanceID)
 		}
-		queryBuilder.
-			AddQuery().
-			AggregateTypes(aggregateType).
-			SequenceGreater(0).
-			ExcludedInstanceID(instances...)
 	}
 
 	return queryBuilder, h.bulkLimit, nil
 }
 
-func appendToIgnoredInstances(instances []string, id string) []string {
-	for _, instance := range instances {
-		if instance == id {
-			return instances
-		}
-	}
-	return append(instances, id)
-}
-
 //Update implements handler.Update
-func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
+func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
 	if len(stmts) == 0 {
-		return nil, nil
+		return -1, nil
 	}
+	instanceIDs := make([]string, 0, len(stmts))
+	for _, stmt := range stmts {
+		instanceIDs = appendToInstanceIDs(instanceIDs, stmt.InstanceID)
+	}
 	tx, err := h.client.BeginTx(ctx, nil)
 	if err != nil {
-		return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
+		return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
 	}
 
-	sequences, err := h.currentSequences(ctx, tx.QueryContext)
+	sequences, err := h.currentSequences(ctx, tx.QueryContext, instanceIDs)
 	if err != nil {
 		tx.Rollback()
-		return stmts, err
+		return -1, err
 	}
 
 	//checks for events between create statement and current sequence
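Note: Update's contract changes here from returning the unexecuted tail of statements to returning the index of the last statement that executed successfully (-1 if none). Callers can therefore resume after a partial failure; the retry loop is condensed from the Process implementation later in this diff:

	index = -1
	for retry := 0; retry <= h.retries; retry++ {
		index, err = h.update(ctx, statements[index+1:], h.reduce)
		if err != nil && !errors.Is(err, ErrSomeStmtsFailed) {
			return index, err
		}
		if err == nil {
			return index, nil
		}
		time.Sleep(h.retryFailedAfter)
	}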
@@ -162,7 +139,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
 		previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
 		if err != nil {
 			tx.Rollback()
-			return stmts, err
+			return -1, err
 		}
 		stmts = append(previousStmts, stmts...)
 	}
@@ -173,27 +150,19 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
 		err = h.updateCurrentSequences(tx, sequences)
 		if err != nil {
 			tx.Rollback()
-			return stmts, err
+			return -1, err
 		}
 	}
 
 	if err = tx.Commit(); err != nil {
-		return stmts, err
+		return -1, err
 	}
 
-	if lastSuccessfulIdx == -1 && len(stmts) > 0 {
-		return stmts, handler.ErrSomeStmtsFailed
+	if lastSuccessfulIdx < len(stmts)-1 {
+		return lastSuccessfulIdx, handler.ErrSomeStmtsFailed
 	}
 
-	unexecutedStmts = make([]*handler.Statement, len(stmts)-(lastSuccessfulIdx+1))
-	copy(unexecutedStmts, stmts[lastSuccessfulIdx+1:])
-	stmts = nil
-
-	if len(unexecutedStmts) > 0 {
-		return unexecutedStmts, handler.ErrSomeStmtsFailed
-	}
-
-	return unexecutedStmts, nil
+	return lastSuccessfulIdx, nil
 }
 
 func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
@@ -316,3 +285,12 @@ func updateSequences(sequences currentSequences, stmt *handler.Statement) {
 		sequence:   stmt.Sequence,
 	})
 }
+
+func appendToInstanceIDs(instances []string, id string) []string {
+	for _, instance := range instances {
+		if instance == id {
+			return instances
+		}
+	}
+	return append(instances, id)
+}
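Note: appendToInstanceIDs, added in the hunk above, is a linear-scan dedupe; for the handful of distinct instance IDs in one statement batch this is cheaper than allocating a map. Illustrative use:

	ids := make([]string, 0, len(stmts))
	for _, stmt := range stmts {
		// "inst1", "inst1", "inst2" collapses to ["inst1", "inst2"]
		ids = appendToInstanceIDs(ids, stmt.InstanceID)
	}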
@@ -4,6 +4,7 @@ import (
 	"context"
 	"database/sql"
 	"errors"
+	"fmt"
 	"reflect"
 	"testing"
 	"time"
@@ -61,9 +62,13 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
 		aggregates []eventstore.AggregateType
 		bulkLimit  uint64
 	}
+	type args struct {
+		instanceIDs []string
+	}
 	tests := []struct {
 		name   string
 		fields fields
+		args   args
 		want   want
 	}{
 		{
@@ -74,13 +79,16 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
 				aggregates:     []eventstore.AggregateType{"testAgg"},
 				bulkLimit:      5,
 			},
+			args: args{
+				instanceIDs: []string{"instanceID1"},
+			},
 			want: want{
 				limit: 0,
 				isErr: func(err error) bool {
 					return errors.Is(err, sql.ErrTxDone)
 				},
 				expectations: []mockExpectation{
-					expectCurrentSequenceErr("my_sequences", "my_projection", sql.ErrTxDone),
+					expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone),
 				},
 				SearchQueryBuilder: nil,
 			},
@@ -93,24 +101,56 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
 				aggregates:     []eventstore.AggregateType{"testAgg"},
 				bulkLimit:      5,
 			},
+			args: args{
+				instanceIDs: []string{"instanceID1"},
+			},
 			want: want{
 				limit: 5,
 				isErr: func(err error) bool {
 					return err == nil
 				},
 				expectations: []mockExpectation{
-					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
+					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}),
 				},
 				SearchQueryBuilder: eventstore.
 					NewSearchQueryBuilder(eventstore.ColumnsEvent).
 					AddQuery().
 					AggregateTypes("testAgg").
 					SequenceGreater(5).
-					InstanceID("instanceID").
+					InstanceID("instanceID1").
+					Builder().
+					Limit(5),
+			},
+		},
+		{
+			name: "multiple instances",
+			fields: fields{
+				sequenceTable:  "my_sequences",
+				projectionName: "my_projection",
+				aggregates:     []eventstore.AggregateType{"testAgg"},
+				bulkLimit:      5,
+			},
+			args: args{
+				instanceIDs: []string{"instanceID1", "instanceID2"},
+			},
+			want: want{
+				limit: 5,
+				isErr: func(err error) bool {
+					return err == nil
+				},
+				expectations: []mockExpectation{
+					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}),
+				},
+				SearchQueryBuilder: eventstore.
+					NewSearchQueryBuilder(eventstore.ColumnsEvent).
+					AddQuery().
+					AggregateTypes("testAgg").
+					SequenceGreater(5).
+					InstanceID("instanceID1").
 					Or().
 					AggregateTypes("testAgg").
-					SequenceGreater(0).
-					ExcludedInstanceID("instanceID").
+					SequenceGreater(5).
+					InstanceID("instanceID2").
 					Builder().
 					Limit(5),
 			},
@@ -140,7 +180,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
 				expectation(mock)
 			}
 
-			query, limit, err := h.SearchQuery(context.Background())
+			query, limit, err := h.SearchQuery(context.Background(), tt.args.instanceIDs)
 			if !tt.want.isErr(err) {
 				t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
 				return
@@ -211,13 +251,14 @@ func TestStatementHandler_Update(t *testing.T) {
 						aggregateType:    "agg",
 						sequence:         6,
 						previousSequence: 0,
+						instanceID:       "instanceID",
 					}),
 				},
 			},
 			want: want{
 				expectations: []mockExpectation{
 					expectBegin(),
-					expectCurrentSequenceErr("my_sequences", "my_projection", sql.ErrTxDone),
+					expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone),
 					expectRollback(),
 				},
 				isErr: func(err error) bool {
@@ -241,13 +282,14 @@ func TestStatementHandler_Update(t *testing.T) {
 						aggregateType:    "agg",
 						sequence:         6,
 						previousSequence: 0,
+						instanceID:       "instanceID",
 					}),
 				},
 			},
 			want: want{
 				expectations: []mockExpectation{
 					expectBegin(),
-					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
+					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
 					expectRollback(),
 				},
 				isErr: func(err error) bool {
@@ -272,6 +314,7 @@ func TestStatementHandler_Update(t *testing.T) {
 							aggregateType:    "testAgg",
 							sequence:         7,
 							previousSequence: 6,
+							instanceID:       "instanceID",
 						},
 						[]handler.Column{
 							{
@@ -284,7 +327,7 @@ func TestStatementHandler_Update(t *testing.T) {
 			want: want{
 				expectations: []mockExpectation{
 					expectBegin(),
-					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
+					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
 					expectCommit(),
 				},
 				isErr: func(err error) bool {
@@ -322,7 +365,7 @@ func TestStatementHandler_Update(t *testing.T) {
 			want: want{
 				expectations: []mockExpectation{
 					expectBegin(),
-					expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
+					expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
 					expectSavePoint(),
 					expectCreate("my_projection", []string{"col"}, []string{"$1"}),
 					expectSavePointRelease(),
@@ -364,7 +407,7 @@ func TestStatementHandler_Update(t *testing.T) {
 			want: want{
 				expectations: []mockExpectation{
 					expectBegin(),
-					expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
+					expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
 					expectSavePoint(),
 					expectCreate("my_projection", []string{"col"}, []string{"$1"}),
 					expectSavePointRelease(),
@@ -399,7 +442,7 @@ func TestStatementHandler_Update(t *testing.T) {
 			want: want{
 				expectations: []mockExpectation{
 					expectBegin(),
-					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
+					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
 					expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
 					expectCommit(),
 				},
@@ -431,7 +474,7 @@ func TestStatementHandler_Update(t *testing.T) {
 			want: want{
 				expectations: []mockExpectation{
 					expectBegin(),
-					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
+					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
 					expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
 					expectCommit(),
 				},
@@ -470,13 +513,14 @@ func TestStatementHandler_Update(t *testing.T) {
 			want: want{
 				expectations: []mockExpectation{
 					expectBegin(),
-					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
+					expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
 					expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
 					expectCommit(),
 				},
 				isErr: func(err error) bool {
 					return errors.Is(err, nil)
 				},
 				stmtsLen: 1,
 			},
 		},
 	}
@@ -488,17 +532,18 @@ func TestStatementHandler_Update(t *testing.T) {
 			}
 			defer client.Close()
 
-			h := NewStatementHandler(context.Background(), StatementHandlerConfig{
-				ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
-					ProjectionName: "my_projection",
-					HandlerConfig: handler.HandlerConfig{
+			h := &StatementHandler{
+				ProjectionHandler: &handler.ProjectionHandler{
+					Handler: handler.Handler{
 						Eventstore: tt.fields.eventstore,
 					},
 					RequeueEvery:   0,
+					ProjectionName: "my_projection",
 				},
-				SequenceTable: "my_sequences",
-				Client:        client,
-			})
+				sequenceTable:           "my_sequences",
+				currentSequenceStmt:     fmt.Sprintf(currentSequenceStmtFormat, "my_sequences"),
+				updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, "my_sequences"),
+				client:                  client,
+			}
 
 			h.aggregates = tt.fields.aggregates
 
@@ -506,12 +551,12 @@ func TestStatementHandler_Update(t *testing.T) {
 				expectation(mock)
 			}
 
-			stmts, err := h.Update(tt.args.ctx, tt.args.stmts, tt.args.reduce)
+			index, err := h.Update(tt.args.ctx, tt.args.stmts, tt.args.reduce)
 			if !tt.want.isErr(err) {
 				t.Errorf("StatementHandler.Update() error = %v", err)
 			}
-			if err == nil && tt.want.stmtsLen != len(stmts) {
-				t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, len(stmts))
+			if err == nil && tt.want.stmtsLen != index {
+				t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, index)
 			}
 
 			mock.MatchExpectationsInOrder(true)
@@ -696,17 +741,12 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
 			h := &StatementHandler{
 				aggregates: tt.fields.aggregates,
 			}
-			h.ProjectionHandler = handler.NewProjectionHandler(handler.ProjectionHandlerConfig{
-				HandlerConfig: handler.HandlerConfig{
+			h.ProjectionHandler = &handler.ProjectionHandler{
+				Handler: handler.Handler{
 					Eventstore: tt.fields.eventstore,
 				},
 				ProjectionName: "my_projection",
-				RequeueEvery:   0,
-			},
-				h.reduce,
-				h.Update,
-				h.SearchQuery,
-			)
+			}
 			stmts, err := h.fetchPreviousStmts(tt.args.ctx, nil, tt.args.stmtSeq, "", tt.args.sequences, tt.args.reduce)
 			if !tt.want.isErr(err) {
 				t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
@@ -1311,7 +1351,8 @@ func TestStatementHandler_currentSequence(t *testing.T) {
 		aggregates     []eventstore.AggregateType
 	}
 	type args struct {
-		stmt handler.Statement
+		stmt        handler.Statement
+		instanceIDs []string
 	}
 	type want struct {
 		expectations []mockExpectation
@@ -1338,7 +1379,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
 					return errors.Is(err, sql.ErrConnDone)
 				},
 				expectations: []mockExpectation{
-					expectCurrentSequenceErr("my_table", "my_projection", sql.ErrConnDone),
+					expectCurrentSequenceErr("my_table", "my_projection", nil, sql.ErrConnDone),
 				},
 			},
 		},
@@ -1350,14 +1391,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
 				aggregates:     []eventstore.AggregateType{"agg"},
 			},
 			args: args{
-				stmt: handler.Statement{},
+				stmt:        handler.Statement{},
+				instanceIDs: []string{"instanceID"},
 			},
 			want: want{
 				isErr: func(err error) bool {
 					return errors.Is(err, nil)
 				},
 				expectations: []mockExpectation{
-					expectCurrentSequenceNoRows("my_table", "my_projection"),
+					expectCurrentSequenceNoRows("my_table", "my_projection", []string{"instanceID"}),
 				},
 				sequences: currentSequences{},
 			},
@@ -1370,14 +1412,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
 				aggregates:     []eventstore.AggregateType{"agg"},
 			},
 			args: args{
-				stmt: handler.Statement{},
+				stmt:        handler.Statement{},
+				instanceIDs: []string{"instanceID"},
 			},
 			want: want{
 				isErr: func(err error) bool {
 					return errors.Is(err, sql.ErrTxDone)
 				},
 				expectations: []mockExpectation{
-					expectCurrentSequenceScanErr("my_table", "my_projection"),
+					expectCurrentSequenceScanErr("my_table", "my_projection", []string{"instanceID"}),
 				},
 				sequences: currentSequences{},
 			},
@@ -1390,14 +1433,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
 				aggregates:     []eventstore.AggregateType{"agg"},
 			},
 			args: args{
-				stmt: handler.Statement{},
+				stmt:        handler.Statement{},
+				instanceIDs: []string{"instanceID"},
 			},
 			want: want{
 				isErr: func(err error) bool {
 					return errors.Is(err, nil)
 				},
 				expectations: []mockExpectation{
-					expectCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
+					expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID"}),
 				},
 				sequences: currentSequences{
 					"agg": []*instanceSequence{
@@ -1409,15 +1453,48 @@ func TestStatementHandler_currentSequence(t *testing.T) {
 				},
 			},
 		},
+		{
+			name: "multiple found",
+			fields: fields{
+				sequenceTable:  "my_table",
+				projectionName: "my_projection",
+				aggregates:     []eventstore.AggregateType{"agg"},
+			},
+			args: args{
+				stmt:        handler.Statement{},
+				instanceIDs: []string{"instanceID1", "instanceID2"},
+			},
+			want: want{
+				isErr: func(err error) bool {
+					return errors.Is(err, nil)
+				},
+				expectations: []mockExpectation{
+					expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}),
+				},
+				sequences: currentSequences{
+					"agg": []*instanceSequence{
+						{
+							sequence:   5,
+							instanceID: "instanceID1",
+						},
+						{
+							sequence:   5,
+							instanceID: "instanceID2",
+						},
+					},
+				},
+			},
+		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			h := NewStatementHandler(context.Background(), StatementHandlerConfig{
-				ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
+			h := &StatementHandler{
+				ProjectionHandler: &handler.ProjectionHandler{
 					ProjectionName: tt.fields.projectionName,
 				},
-				SequenceTable: tt.fields.sequenceTable,
-			})
+				sequenceTable:       tt.fields.sequenceTable,
+				currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, tt.fields.sequenceTable),
+			}
 
 			h.aggregates = tt.fields.aggregates
 
@@ -1440,7 +1517,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
 				t.Fatalf("unexpected err in begin: %v", err)
 			}
 
-			seq, err := h.currentSequences(context.Background(), tx.QueryContext)
+			seq, err := h.currentSequences(context.Background(), tx.QueryContext, tt.args.instanceIDs)
 			if !tt.want.isErr(err) {
 				t.Errorf("unexpected error: %v", err)
 			}
@@ -1615,12 +1692,13 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 
-			h := NewStatementHandler(context.Background(), StatementHandlerConfig{
-				ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
+			h := &StatementHandler{
+				ProjectionHandler: &handler.ProjectionHandler{
 					ProjectionName: tt.fields.projectionName,
 				},
-				SequenceTable: tt.fields.sequenceTable,
-			})
+				sequenceTable:           tt.fields.sequenceTable,
+				updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, tt.fields.sequenceTable),
+			}
 
 			h.aggregates = tt.fields.aggregates
 
@@ -4,8 +4,11 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
+	"strconv"
+	"strings"
 	"time"
 
+	"github.com/lib/pq"
 	"github.com/zitadel/logging"
 
 	"github.com/zitadel/zitadel/internal/errors"
@@ -14,20 +17,20 @@ import (
 
 const (
 	lockStmtFormat = "INSERT INTO %[1]s" +
-		" (locker_id, locked_until, projection_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
+		" (locker_id, locked_until, projection_name, instance_id) VALUES %[2]s" +
 		" ON CONFLICT (projection_name, instance_id)" +
 		" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
-		" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = $4 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
+		" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = ANY ($%[3]d) AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
 )
 
 type Locker interface {
-	Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error
-	Unlock(instanceID string) error
+	Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error
+	Unlock(instanceIDs ...string) error
 }
 
 type locker struct {
 	client         *sql.DB
-	lockStmt       string
+	lockStmt       func(values string, instances int) string
 	workerName     string
 	projectionName string
 }
@@ -36,25 +39,27 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
 	workerName, err := id.SonyFlakeGenerator().Next()
 	logging.OnError(err).Panic("unable to generate lockID")
 	return &locker{
-		client:   client,
-		lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
+		client: client,
+		lockStmt: func(values string, instances int) string {
+			return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
+		},
 		workerName:     workerName,
 		projectionName: projectionName,
 	}
 }
 
-func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error {
+func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error {
 	errs := make(chan error)
-	go h.handleLock(ctx, errs, lockDuration, instanceID)
+	go h.handleLock(ctx, errs, lockDuration, instanceIDs...)
 	return errs
 }
 
-func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceID string) {
+func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceIDs ...string) {
 	renewLock := time.NewTimer(0)
 	for {
 		select {
 		case <-renewLock.C:
-			errs <- h.renewLock(ctx, lockDuration, instanceID)
+			errs <- h.renewLock(ctx, lockDuration, instanceIDs...)
 			//refresh the lock 500ms before it times out. 500ms should be enough for one transaction
 			renewLock.Reset(lockDuration - (500 * time.Millisecond))
 		case <-ctx.Done():
@@ -65,24 +70,38 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t
 	}
 }
 
-func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error {
-	//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
-	res, err := h.client.ExecContext(ctx, h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
+func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) error {
+	lockStmt, values := h.lockStatement(lockDuration, instanceIDs)
+	res, err := h.client.ExecContext(ctx, lockStmt, values...)
 	if err != nil {
 		return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
 	}
 
 	if rows, _ := res.RowsAffected(); rows == 0 {
 		return errors.ThrowAlreadyExists(nil, "CRDB-mmi4J", "projection already locked")
 	}
 
	return nil
 }
 
-func (h *locker) Unlock(instanceID string) error {
-	_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName, instanceID)
+func (h *locker) Unlock(instanceIDs ...string) error {
+	lockStmt, values := h.lockStatement(0, instanceIDs)
+	_, err := h.client.Exec(lockStmt, values...)
 	if err != nil {
 		return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed")
 	}
 	return nil
 }
+
+func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs []string) (string, []interface{}) {
+	valueQueries := make([]string, len(instanceIDs))
+	values := make([]interface{}, len(instanceIDs)+4)
+	values[0] = h.workerName
+	//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
+	values[1] = lockDuration.Seconds()
+	values[2] = h.projectionName
+	for i, instanceID := range instanceIDs {
+		valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
+		values[i+3] = instanceID
+	}
+	values[len(values)-1] = pq.StringArray(instanceIDs)
+	return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
+}
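Worked example: for two instance IDs and a 2s lock duration, lockStatement above assembles the following (lock table name illustrative):

	// SQL ($4/$5 come from the VALUES loop, $6 = len(values)):
	//   INSERT INTO projections.locks (locker_id, locked_until, projection_name, instance_id)
	//   VALUES ($1, now()+$2::INTERVAL, $3, $4), ($1, now()+$2::INTERVAL, $3, $5)
	//   ON CONFLICT (projection_name, instance_id)
	//   DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL
	//   WHERE projections.locks.projection_name = $3 AND projections.locks.instance_id = ANY ($6)
	//     AND (projections.locks.locker_id = $1 OR projections.locks.locked_until < now())
	//
	// values: [workerName, 2, projectionName, "instanceID1", "instanceID2",
	//          pq.StringArray{"instanceID1", "instanceID2"}]

which is exactly what expectLockMultipleInstances in the test hunks above asserts.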
@@ -32,7 +32,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
 		lockDuration time.Duration
 		ctx          context.Context
 		errMock      *errsMock
-		instanceID   string
+		instanceIDs  []string
 	}
 	tests := []struct {
 		name string
@@ -56,7 +56,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
 					successfulIters: 2,
 					shouldErr:       true,
 				},
-				instanceID: "instanceID",
+				instanceIDs: []string{"instanceID"},
 			},
 		},
 		{
@@ -74,7 +74,25 @@ func TestStatementHandler_handleLock(t *testing.T) {
 					errs:            make(chan error),
 					successfulIters: 2,
 				},
-				instanceID: "instanceID",
+				instanceIDs: []string{"instanceID"},
+			},
+		},
+		{
+			name: "success with multiple",
+			want: want{
+				expectations: []mockExpectation{
+					expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
+					expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
+				},
+			},
+			args: args{
+				lockDuration: 2 * time.Second,
+				ctx:          context.Background(),
+				errMock: &errsMock{
+					errs:            make(chan error),
+					successfulIters: 2,
+				},
+				instanceIDs: []string{"instanceID1", "instanceID2"},
 			},
 		},
 	}
@@ -88,7 +106,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
 				projectionName: projectionName,
 				client:         client,
 				workerName:     workerName,
-				lockStmt:       fmt.Sprintf(lockStmtFormat, lockTable),
+				lockStmt: func(values string, instances int) string {
+					return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
+				},
 			}
 
 			for _, expectation := range tt.want.expectations {
@@ -99,7 +119,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
 
 			go tt.args.errMock.handleErrs(t, cancel)
 
-			go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceID)
+			go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceIDs...)
 
 			<-ctx.Done()
 
@@ -118,7 +138,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
 	}
 	type args struct {
 		lockDuration time.Duration
-		instanceID   string
+		instanceIDs  []string
 	}
 	tests := []struct {
 		name string
@@ -137,7 +157,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
 			},
 			args: args{
 				lockDuration: 1 * time.Second,
-				instanceID:   "instanceID",
+				instanceIDs:  []string{"instanceID"},
 			},
 		},
 		{
@@ -152,7 +172,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
 			},
 			args: args{
 				lockDuration: 2 * time.Second,
-				instanceID:   "instanceID",
+				instanceIDs:  []string{"instanceID"},
 			},
 		},
 		{
@@ -167,7 +187,22 @@ func TestStatementHandler_renewLock(t *testing.T) {
 			},
 			args: args{
 				lockDuration: 3 * time.Second,
-				instanceID:   "instanceID",
+				instanceIDs:  []string{"instanceID"},
+			},
+		},
+		{
+			name: "success with multiple",
+			want: want{
+				expectations: []mockExpectation{
+					expectLockMultipleInstances(lockTable, workerName, 3, "instanceID1", "instanceID2"),
+				},
+				isErr: func(err error) bool {
+					return errors.Is(err, nil)
+				},
+			},
+			args: args{
+				lockDuration: 3 * time.Second,
+				instanceIDs:  []string{"instanceID1", "instanceID2"},
 			},
 		},
 	}
@@ -181,14 +216,16 @@ func TestStatementHandler_renewLock(t *testing.T) {
 				projectionName: projectionName,
 				client:         client,
 				workerName:     workerName,
-				lockStmt:       fmt.Sprintf(lockStmtFormat, lockTable),
+				lockStmt: func(values string, instances int) string {
+					return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
+				},
 			}
 
 			for _, expectation := range tt.want.expectations {
 				expectation(mock)
 			}
 
-			err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceID)
+			err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceIDs...)
 			if !tt.want.isErr(err) {
 				t.Errorf("unexpected error = %v", err)
 			}
@@ -253,7 +290,9 @@ func TestStatementHandler_Unlock(t *testing.T) {
 				projectionName: projectionName,
 				client:         client,
 				workerName:     workerName,
-				lockStmt:       fmt.Sprintf(lockStmtFormat, lockTable),
+				lockStmt: func(values string, instances int) string {
+					return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
+				},
 			}
 
 			for _, expectation := range tt.want.expectations {
@@ -27,3 +27,10 @@ func (h *Handler) Subscribe(aggregates ...eventstore.AggregateType) {
 func (h *Handler) SubscribeEvents(types map[eventstore.AggregateType][]eventstore.EventType) {
 	h.Sub = eventstore.SubscribeEventTypes(h.EventQueue, types)
 }
+
+func (h *Handler) Unsubscribe() {
+	if h.Sub == nil {
+		return
+	}
+	h.Sub.Unsubscribe()
+}
@@ -2,13 +2,13 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime/debug"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
@@ -16,241 +16,207 @@ const systemID = "system"
|
||||
|
||||
type ProjectionHandlerConfig struct {
|
||||
HandlerConfig
|
||||
ProjectionName string
|
||||
RequeueEvery time.Duration
|
||||
RetryFailedAfter time.Duration
|
||||
ProjectionName string
|
||||
RequeueEvery time.Duration
|
||||
RetryFailedAfter time.Duration
|
||||
Retries uint
|
||||
ConcurrentInstances uint
|
||||
}
|
||||
|
||||
//Update updates the projection with the given statements
|
||||
type Update func(context.Context, []*Statement, Reduce) (unexecutedStmts []*Statement, err error)
|
||||
type Update func(context.Context, []*Statement, Reduce) (index int, err error)
|
||||
|
||||
//Reduce reduces the given event to a statement
|
||||
//which is used to update the projection
|
||||
type Reduce func(eventstore.Event) (*Statement, error)
|
||||
|
||||
//SearchQuery generates the search query to lookup for events
|
||||
type SearchQuery func(ctx context.Context, instanceIDs []string) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
|
||||
|
||||
//Lock is used for mutex handling if needed on the projection
|
||||
type Lock func(context.Context, time.Duration, string) <-chan error
|
||||
type Lock func(context.Context, time.Duration, ...string) <-chan error
|
||||
|
||||
//Unlock releases the mutex of the projection
|
||||
type Unlock func(string) error
|
||||
|
||||
//SearchQuery generates the search query to lookup for events
|
||||
type SearchQuery func(ctx context.Context) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
|
||||
type Unlock func(...string) error
|
||||
|
||||
type ProjectionHandler struct {
|
||||
Handler
|
||||
|
||||
requeueAfter time.Duration
|
||||
shouldBulk *time.Timer
|
||||
bulkMu sync.Mutex
|
||||
bulkLocked bool
|
||||
execBulk executeBulk
|
||||
|
||||
retryFailedAfter time.Duration
|
||||
shouldPush *time.Timer
|
||||
pushSet bool
|
||||
|
||||
ProjectionName string
|
||||
|
||||
lockMu sync.Mutex
|
||||
stmts []*Statement
|
||||
ProjectionName string
|
||||
reduce Reduce
|
||||
update Update
|
||||
searchQuery SearchQuery
|
||||
triggerProjection *time.Timer
|
||||
lock Lock
|
||||
unlock Unlock
|
||||
requeueAfter time.Duration
|
||||
retryFailedAfter time.Duration
|
||||
retries int
|
||||
concurrentInstances int
|
||||
}
|
||||
|
||||
func NewProjectionHandler(
|
||||
ctx context.Context,
|
||||
config ProjectionHandlerConfig,
|
||||
reduce Reduce,
|
||||
update Update,
|
||||
query SearchQuery,
|
||||
lock Lock,
|
||||
unlock Unlock,
|
||||
) *ProjectionHandler {
|
||||
concurrentInstances := int(config.ConcurrentInstances)
|
||||
if concurrentInstances < 1 {
|
||||
concurrentInstances = 1
|
||||
}
|
||||
h := &ProjectionHandler{
|
||||
Handler: NewHandler(config.HandlerConfig),
|
||||
ProjectionName: config.ProjectionName,
|
||||
requeueAfter: config.RequeueEvery,
|
||||
// first bulk is instant on startup
|
||||
shouldBulk: time.NewTimer(0),
|
||||
shouldPush: time.NewTimer(0),
|
||||
retryFailedAfter: config.RetryFailedAfter,
|
||||
Handler: NewHandler(config.HandlerConfig),
|
||||
ProjectionName: config.ProjectionName,
|
||||
reduce: reduce,
|
||||
update: update,
|
||||
searchQuery: query,
|
||||
lock: lock,
|
||||
unlock: unlock,
|
||||
requeueAfter: config.RequeueEvery,
|
||||
triggerProjection: time.NewTimer(0), // first trigger is instant on startup
|
||||
retryFailedAfter: config.RetryFailedAfter,
|
||||
retries: int(config.Retries),
|
||||
concurrentInstances: concurrentInstances,
|
||||
}
|
||||
|
||||
h.execBulk = h.prepareExecuteBulk(query, reduce, update)
|
||||
go h.subscribe(ctx)
|
||||
|
||||
//unitialized timer
|
||||
//https://github.com/golang/go/issues/12721
|
||||
<-h.shouldPush.C
|
||||
go h.schedule(ctx)
|
||||
|
||||
if config.RequeueEvery <= 0 {
|
||||
if !h.shouldBulk.Stop() {
|
||||
<-h.shouldBulk.C
|
||||
}
|
||||
logging.WithFields("projection", h.ProjectionName).Info("starting handler without requeue")
|
||||
return h
|
||||
} else if config.RequeueEvery < 500*time.Millisecond {
|
||||
logging.WithFields("projection", h.ProjectionName).Fatal("requeue every must be greater 500ms or <= 0")
|
||||
}
|
||||
logging.WithFields("projection", h.ProjectionName).Info("starting handler")
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) ResetShouldBulk() {
|
||||
if h.requeueAfter > 0 {
|
||||
h.shouldBulk.Reset(h.requeueAfter)
|
||||
//Trigger handles all events for the provided instances (or current instance from context if non specified)
|
||||
//by calling FetchEvents and Process until the amount of events is smaller than the BulkLimit
|
||||
func (h *ProjectionHandler) Trigger(ctx context.Context, instances ...string) error {
|
||||
ids := []string{authz.GetInstance(ctx).InstanceID()}
|
||||
if len(instances) > 0 {
|
||||
ids = instances
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) triggerShouldPush(after time.Duration) {
|
||||
if !h.pushSet {
|
||||
h.pushSet = true
|
||||
h.shouldPush.Reset(after)
|
||||
}
|
||||
}
|
||||
|
||||
//Process waits for several conditions:
|
||||
// if context is canceled the function gracefully shuts down
|
||||
// if an event occures it reduces the event
|
||||
// if the internal timer expires the handler will check
|
||||
// for unprocessed events on eventstore
|
||||
func (h *ProjectionHandler) Process(
|
||||
ctx context.Context,
|
||||
reduce Reduce,
|
||||
update Update,
|
||||
lock Lock,
|
||||
unlock Unlock,
|
||||
query SearchQuery,
|
||||
) {
|
||||
//handle panic
|
||||
defer func() {
|
||||
cause := recover()
|
||||
logging.WithFields("projection", h.ProjectionName, "cause", cause, "stack", string(debug.Stack())).Error("projection handler paniced")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if h.pushSet {
|
||||
h.push(context.Background(), update, reduce)
|
||||
events, hasLimitExceeded, err := h.FetchEvents(ctx, ids...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(events) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err = h.Process(ctx, events...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasLimitExceeded {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Process handles multiple events by reducing them to statements and updating the projection
|
||||
func (h *ProjectionHandler) Process(ctx context.Context, events ...eventstore.Event) (index int, err error) {
|
||||
if len(events) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
index = -1
|
||||
statements := make([]*Statement, len(events))
|
||||
for i, event := range events {
|
||||
statements[i], err = h.reduce(event)
|
||||
if err != nil {
|
||||
return index, err
|
||||
}
|
||||
}
|
||||
for retry := 0; retry <= h.retries; retry++ {
|
||||
index, err = h.update(ctx, statements[index+1:], h.reduce)
|
||||
if err != nil && !errors.Is(err, ErrSomeStmtsFailed) {
|
||||
return index, err
|
||||
}
|
||||
if err == nil {
|
||||
return index, nil
|
||||
}
|
||||
time.Sleep(h.retryFailedAfter)
|
||||
}
|
||||
return index, err
|
||||
}
|
||||
|
||||
//FetchEvents checks the current sequences and filters for newer events
|
||||
func (h *ProjectionHandler) FetchEvents(ctx context.Context, instances ...string) ([]eventstore.Event, bool, error) {
|
||||
eventQuery, eventsLimit, err := h.searchQuery(ctx, instances)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
events, err := h.Eventstore.Filter(ctx, eventQuery)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return events, int(eventsLimit) == len(events), err
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) subscribe(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
h.Handler.Unsubscribe()
|
||||
logging.WithFields("projection", h.ProjectionName).Errorf("subscription panicked: %v", err)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
for firstEvent := range h.EventQueue {
|
||||
events := checkAdditionalEvents(h.EventQueue, firstEvent)
|
||||
|
||||
index, err := h.Process(ctx, events...)
|
||||
if err != nil || index < len(events)-1 {
|
||||
logging.WithFields("projection", h.ProjectionName).WithError(err).Error("unable to process all events from subscription")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) schedule(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName, "cause", err, "stack", string(debug.Stack())).Error("schedule panicked")
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
for range h.triggerProjection.C {
|
||||
ids, err := h.Eventstore.InstanceIDs(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).AddQuery().ExcludedInstanceID("").Builder())
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName).WithError(err).Error("instance ids")
|
||||
h.triggerProjection.Reset(h.requeueAfter)
|
||||
continue
|
||||
}
|
||||
for i := 0; i < len(ids); i = i + h.concurrentInstances {
|
||||
max := i + h.concurrentInstances
|
||||
if max > len(ids) {
|
||||
max = len(ids)
|
||||
}
|
||||
h.shutdown()
|
||||
return
|
||||
case event := <-h.EventQueue:
|
||||
if err := h.processEvent(ctx, event, reduce); err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("process failed")
|
||||
instances := ids[i:max]
|
||||
lockCtx, cancelLock := context.WithCancel(ctx)
|
||||
errs := h.lock(lockCtx, h.requeueAfter, instances...)
|
||||
//wait until projection is locked
|
||||
if err, ok := <-errs; err != nil || !ok {
|
||||
cancelLock()
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("initial lock failed")
|
||||
continue
|
||||
}
|
||||
h.triggerShouldPush(0)
|
||||
case <-h.shouldBulk.C:
|
||||
h.bulkMu.Lock()
|
||||
h.bulkLocked = true
|
||||
h.bulk(ctx, lock, unlock)
|
||||
h.ResetShouldBulk()
|
||||
h.bulkLocked = false
|
||||
h.bulkMu.Unlock()
|
||||
default:
|
||||
//lower prio select with push
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if h.pushSet {
|
||||
h.push(context.Background(), update, reduce)
|
||||
}
|
||||
h.shutdown()
|
||||
return
|
||||
case event := <-h.EventQueue:
|
||||
if err := h.processEvent(ctx, event, reduce); err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("process failed")
|
||||
continue
|
||||
}
|
||||
h.triggerShouldPush(0)
|
||||
case <-h.shouldBulk.C:
|
||||
h.bulkMu.Lock()
|
||||
h.bulkLocked = true
|
||||
h.bulk(ctx, lock, unlock)
|
||||
h.ResetShouldBulk()
|
||||
h.bulkLocked = false
|
||||
h.bulkMu.Unlock()
|
||||
case <-h.shouldPush.C:
|
||||
h.push(ctx, update, reduce)
|
||||
h.ResetShouldBulk()
|
||||
go h.cancelOnErr(lockCtx, errs, cancelLock)
|
||||
err = h.Trigger(lockCtx, instances...)
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName, "instanceIDs", instances).WithError(err).Error("trigger failed")
|
||||
}
|
||||
|
||||
cancelLock()
|
||||
unlockErr := h.unlock(instances...)
|
||||
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")
|
||||
}
|
||||
h.triggerProjection.Reset(h.requeueAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) processEvent(
|
||||
ctx context.Context,
|
||||
event eventstore.Event,
|
||||
reduce Reduce,
|
||||
) error {
|
||||
stmt, err := reduce(event)
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Warn("unable to process event")
|
||||
return err
|
||||
}
|
||||
|
||||
h.lockMu.Lock()
|
||||
defer h.lockMu.Unlock()
|
||||
|
||||
h.stmts = append(h.stmts, stmt)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) TriggerBulk(
|
||||
ctx context.Context,
|
||||
lock Lock,
|
||||
unlock Unlock,
|
||||
) error {
|
||||
if !h.shouldBulk.Stop() {
|
||||
//make sure to flush shouldBulk chan
|
||||
select {
|
||||
case <-h.shouldBulk.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
defer h.ResetShouldBulk()
|
||||
|
||||
h.bulkMu.Lock()
|
||||
if h.bulkLocked {
|
||||
logging.WithFields("projection", h.ProjectionName).Debugf("waiting for existing bulk to finish")
|
||||
h.bulkMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
h.bulkLocked = true
|
||||
defer func() {
|
||||
h.bulkLocked = false
|
||||
h.bulkMu.Unlock()
|
||||
}()
|
||||
|
||||
return h.bulk(ctx, lock, unlock)
|
||||
}
|
||||
|
||||
func (h *ProjectionHandler) bulk(
	ctx context.Context,
	lock Lock,
	unlock Unlock,
) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	errs := lock(ctx, h.requeueAfter, systemID)
	//wait until projection is locked
	if err, ok := <-errs; err != nil || !ok {
		logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("initial lock failed")
		return err
	}
	go h.cancelOnErr(ctx, errs, cancel)

	execErr := h.execBulk(ctx)
	logging.WithFields("projection", h.ProjectionName).OnError(execErr).Warn("unable to execute")

	unlockErr := unlock(systemID)
	logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")

	if execErr != nil {
		return execErr
	}

	return unlockErr
}

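lock returns a channel: the first value reports whether the initial acquisition succeeded, and later values surface renewal failures, which cancelOnErr turns into a context cancellation. A self-contained sketch of that contract (toy locker, not the real Lock implementation):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// lock simulates acquiring and periodically renewing a lock. The first
// value on the returned channel reports the initial acquisition; any
// later value reports a failed renewal.
func lock(ctx context.Context) <-chan error {
	errs := make(chan error, 1)
	go func() {
		defer close(errs)
		errs <- nil // initial lock succeeded
		select {
		case <-time.After(30 * time.Millisecond):
			errs <- errors.New("lock renewal failed")
		case <-ctx.Done():
		}
	}()
	return errs
}

// cancelOnErr cancels the context as soon as a renewal error arrives.
func cancelOnErr(ctx context.Context, errs <-chan error, cancel func()) {
	for {
		select {
		case err := <-errs:
			if err != nil {
				cancel()
				return
			}
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	errs := lock(ctx)
	if err, ok := <-errs; err != nil || !ok {
		fmt.Println("initial lock failed:", err)
		return
	}
	go cancelOnErr(ctx, errs, cancel)
	<-ctx.Done() // work would stop here once renewal fails
	fmt.Println("stopped:", ctx.Err())
}
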
func (h *ProjectionHandler) cancelOnErr(ctx context.Context, errs <-chan error, cancel func()) {
	for {
		select {
@@ -268,98 +234,15 @@ func (h *ProjectionHandler) cancelOnErr(ctx context.Context, errs <-chan error,
	}
}

type executeBulk func(ctx context.Context) error

func (h *ProjectionHandler) prepareExecuteBulk(
	query SearchQuery,
	reduce Reduce,
	update Update,
) executeBulk {
	return func(ctx context.Context) error {
		for {
			select {
			case <-ctx.Done():
				return nil
			default:
				hasLimitExeeded, err := h.fetchBulkStmts(ctx, query, reduce)
				if err != nil || len(h.stmts) == 0 {
					logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("unable to fetch stmts")
					return err
				}

				if err = h.push(ctx, update, reduce); err != nil {
					return err
				}

				if !hasLimitExeeded {
					return nil
				}
			}
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
	events := make([]eventstore.Event, 1)
	events[0] = event
	for {
		select {
		case event := <-eventQueue:
			events = append(events, event)
		default:
			return events
		}
	}
}

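checkAdditionalEvents opportunistically drains whatever is already queued so a single push can cover several events. The same non-blocking drain, runnable in isolation:

package main

import "fmt"

// drain collects the first value plus everything already buffered in the
// channel, returning as soon as a receive would block.
func drain(queue chan int, first int) []int {
	values := []int{first}
	for {
		select {
		case v := <-queue:
			values = append(values, v)
		default:
			return values
		}
	}
}

func main() {
	queue := make(chan int, 8)
	queue <- 2
	queue <- 3
	fmt.Println(drain(queue, 1)) // [1 2 3]
}
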
func (h *ProjectionHandler) fetchBulkStmts(
	ctx context.Context,
	query SearchQuery,
	reduce Reduce,
) (limitExeeded bool, err error) {
	eventQuery, eventsLimit, err := query(ctx)
	if err != nil {
		logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("unable to create event query")
		return false, err
	}

	events, err := h.Eventstore.Filter(ctx, eventQuery)
	if err != nil {
		logging.WithFields("projection", h.ProjectionName).WithError(err).Info("Unable to bulk fetch events")
		return false, err
	}

	for _, event := range events {
		if err = h.processEvent(ctx, event, reduce); err != nil {
			logging.WithFields("projection", h.ProjectionName, "sequence", event.Sequence(), "instanceID", event.Aggregate().InstanceID).WithError(err).Warn("unable to process event in bulk")
			return false, err
		}
	}

	return len(events) == int(eventsLimit), nil
}

func (h *ProjectionHandler) push(
	ctx context.Context,
	update Update,
	reduce Reduce,
) (err error) {
	h.lockMu.Lock()
	defer h.lockMu.Unlock()

	sort.Slice(h.stmts, func(i, j int) bool {
		return h.stmts[i].Sequence < h.stmts[j].Sequence
	})

	h.stmts, err = update(ctx, h.stmts, reduce)
	h.pushSet = len(h.stmts) > 0

	if h.pushSet {
		h.triggerShouldPush(h.retryFailedAfter)
		return nil
	}

	h.shouldPush.Stop()

	return err
}

func (h *ProjectionHandler) shutdown() {
	h.lockMu.Lock()
	defer h.lockMu.Unlock()
	h.Sub.Unsubscribe()
	if !h.shouldBulk.Stop() {
		<-h.shouldBulk.C
	}
	if !h.shouldPush.Stop() {
		<-h.shouldPush.C
	}
	logging.New().Info("stop processing")
}

File diff suppressed because it is too large
@@ -8,8 +8,8 @@ import (
	context "context"
	reflect "reflect"

	repository "github.com/zitadel/zitadel/internal/eventstore/repository"
	gomock "github.com/golang/mock/gomock"
	repository "github.com/zitadel/zitadel/internal/eventstore/repository"
)

// MockRepository is a mock of Repository interface.
@@ -78,6 +78,21 @@ func (mr *MockRepositoryMockRecorder) Health(arg0 interface{}) *gomock.Call {
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockRepository)(nil).Health), arg0)
}

// InstanceIDs mocks base method.
func (m *MockRepository) InstanceIDs(arg0 context.Context, arg1 *repository.SearchQuery) ([]string, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "InstanceIDs", arg0, arg1)
	ret0, _ := ret[0].([]string)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// InstanceIDs indicates an expected call of InstanceIDs.
func (mr *MockRepositoryMockRecorder) InstanceIDs(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceIDs", reflect.TypeOf((*MockRepository)(nil).InstanceIDs), arg0, arg1)
}

// LatestSequence mocks base method.
func (m *MockRepository) LatestSequence(arg0 context.Context, arg1 *repository.SearchQuery) (uint64, error) {
	m.ctrl.T.Helper()

@@ -29,6 +29,16 @@ func (m *MockRepository) ExpectFilterEventsError(err error) *MockRepository {
	return m
}

func (m *MockRepository) ExpectInstanceIDs(instanceIDs ...string) *MockRepository {
	m.EXPECT().InstanceIDs(gomock.Any(), gomock.Any()).Return(instanceIDs, nil)
	return m
}

func (m *MockRepository) ExpectInstanceIDsError(err error) *MockRepository {
	m.EXPECT().InstanceIDs(gomock.Any(), gomock.Any()).Return(nil, err)
	return m
}

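These helpers wrap the generated gomock expectation; a test exercising the new InstanceIDs path might stub it roughly like this (a hypothetical test, assuming it sits in the same package as the generated mock, whose constructor is the usual gomock-generated NewMockRepository):

package mock // hypothetical: alongside the generated MockRepository

import (
	"context"
	"testing"

	"github.com/golang/mock/gomock"
)

func TestInstanceIDs_sketch(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	m := NewMockRepository(ctrl) // gomock-generated constructor
	m.EXPECT().InstanceIDs(gomock.Any(), gomock.Any()).Return([]string{"instance1", "instance2"}, nil)

	ids, err := m.InstanceIDs(context.Background(), nil)
	if err != nil || len(ids) != 2 {
		t.Fatalf("unexpected result: %v, %v", ids, err)
	}
}
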
func (m *MockRepository) ExpectPush(expectedEvents []*repository.Event, expectedUniqueConstraints ...*repository.UniqueConstraint) *MockRepository {
	m.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, events []*repository.Event, uniqueConstraints ...*repository.UniqueConstraint) error {

@@ -8,14 +8,16 @@ import (
type Repository interface {
	//Health checks if the connection to the storage is available
	Health(ctx context.Context) error
	// PushEvents adds all events of the given aggregates to the eventstreams of the aggregates.
	// Push adds all events of the given aggregates to the event streams of the aggregates.
	// if unique constraints are pushed, they will be added to the unique table for checking unique constraint violations
	// This call is transaction save. The transaction will be rolled back if one event fails
	Push(ctx context.Context, events []*Event, uniqueConstraints ...*UniqueConstraint) error
	// Filter returns all events matching the given search query
	Filter(ctx context.Context, searchQuery *SearchQuery) (events []*Event, err error)
	//LatestSequence returns the latests sequence found by the the search query
	//LatestSequence returns the latest sequence found by the search query
	LatestSequence(ctx context.Context, queryFactory *SearchQuery) (uint64, error)
	//InstanceIDs returns the instance ids found by the search query
	InstanceIDs(ctx context.Context, queryFactory *SearchQuery) ([]string, error)
	//CreateInstance creates a new sequence for the given instance
	CreateInstance(ctx context.Context, instanceID string) error
}

@@ -23,6 +23,8 @@ const (
	ColumnsEvent = iota + 1
	//ColumnsMaxSequence represents the latest sequence of the filtered events
	ColumnsMaxSequence
	// ColumnsInstanceIDs represents the instance ids of the filtered events
	ColumnsInstanceIDs

	columnsCount
)

@@ -218,7 +218,7 @@ func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery)
	return events, nil
}

//LatestSequence returns the latests sequence found by the the search query
//LatestSequence returns the latest sequence found by the search query
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) {
	var seq Sequence
	err := query(ctx, db, searchQuery, &seq)
@@ -228,6 +228,16 @@ func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.Sear
	return uint64(seq), nil
}

//InstanceIDs returns the instance ids found by the search query
func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQuery) ([]string, error) {
	var ids []string
	err := query(ctx, db, searchQuery, &ids)
	if err != nil {
		return nil, err
	}
	return ids, nil
}

func (db *CRDB) db() *sql.DB {
	return db.client
}
@@ -262,6 +272,10 @@ func (db *CRDB) maxSequenceQuery() string {
	return "SELECT MAX(event_sequence) FROM eventstore.events"
}

func (db *CRDB) instanceIDsQuery() string {
	return "SELECT DISTINCT instance_id FROM eventstore.events"
}

func (db *CRDB) columnName(col repository.Field) string {
	switch col {
	case repository.FieldAggregateID:

@@ -22,6 +22,7 @@ type querier interface {
	placeholder(query string) string
	eventQuery() string
	maxSequenceQuery() string
	instanceIDsQuery() string
	db() *sql.DB
	orderByEventSequence(desc bool) string
}
@@ -36,7 +37,7 @@ func query(ctx context.Context, criteria querier, searchQuery *repository.Search
	}
	query += where

	if searchQuery.Columns != repository.ColumnsMaxSequence {
	if searchQuery.Columns == repository.ColumnsEvent {
		query += criteria.orderByEventSequence(searchQuery.Desc)
	}

@@ -76,6 +77,8 @@ func prepareColumns(criteria querier, columns repository.Columns) (string, func(
	switch columns {
	case repository.ColumnsMaxSequence:
		return criteria.maxSequenceQuery(), maxSequenceScanner
	case repository.ColumnsInstanceIDs:
		return criteria.instanceIDsQuery(), instanceIDsScanner
	case repository.ColumnsEvent:
		return criteria.eventQuery(), eventsScanner
	default:
@@ -95,6 +98,22 @@ func maxSequenceScanner(row scan, dest interface{}) (err error) {
	return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
}

func instanceIDsScanner(scanner scan, dest interface{}) (err error) {
	ids, ok := dest.(*[]string)
	if !ok {
		return z_errors.ThrowInvalidArgument(nil, "SQL-Begh2", "type must be an array of string")
	}
	var id string
	err = scanner(&id)
	if err != nil {
		logging.WithError(err).Warn("unable to scan row")
		return z_errors.ThrowInternal(err, "SQL-DEFGe", "unable to scan row")
	}
	*ids = append(*ids, id)

	return nil
}

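instanceIDsScanner is invoked once per row; the surrounding query helper (not shown in this hunk) iterates rows.Next and hands each row's Scan to the scanner. A stripped-down, self-contained version of that flow using database/sql directly (hypothetical DSN; the pq driver is already used elsewhere in this change):

package main

import (
	"context"
	"database/sql"
	"fmt"

	_ "github.com/lib/pq" // Postgres/CockroachDB driver
)

func instanceIDs(ctx context.Context, db *sql.DB) ([]string, error) {
	rows, err := db.QueryContext(ctx, "SELECT DISTINCT instance_id FROM eventstore.events")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	ids := make([]string, 0)
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		ids = append(ids, id)
	}
	return ids, rows.Err()
}

func main() {
	// Hypothetical connection string, for illustration only.
	db, err := sql.Open("postgres", "postgresql://localhost:26257/zitadel?sslmode=disable")
	if err != nil {
		panic(err)
	}
	ids, err := instanceIDs(context.Background(), db)
	fmt.Println(ids, err)
}
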
func eventsScanner(scanner scan, dest interface{}) (err error) {
	events, ok := dest.(*[]*repository.Event)
	if !ok {
@@ -157,7 +176,7 @@ func prepareCondition(criteria querier, filters [][]*repository.Filter) (clause
	var err error
	value, err = json.Marshal(value)
	if err != nil {
		logging.New().WithError(err).Warn("unable to marshal search value")
		logging.WithError(err).Warn("unable to marshal search value")
		continue
	}
}

@@ -39,6 +39,8 @@ const (
	ColumnsEvent Columns = repository.ColumnsEvent
	// ColumnsMaxSequence represents the latest sequence of the filtered events
	ColumnsMaxSequence Columns = repository.ColumnsMaxSequence
	// ColumnsInstanceIDs represents the instance ids of the filtered events
	ColumnsInstanceIDs Columns = repository.ColumnsInstanceIDs
)

// AggregateType is the object name
@@ -278,6 +280,9 @@ func (query *SearchQuery) eventTypeFilter() *repository.Filter {
}

func (query *SearchQuery) aggregateTypeFilter() *repository.Filter {
	if len(query.aggregateTypes) < 1 {
		return nil
	}
	if len(query.aggregateTypes) == 1 {
		return repository.NewFilter(repository.FieldAggregateType, repository.AggregateType(query.aggregateTypes[0]), repository.OperationEquals)
	}

@@ -13,6 +13,7 @@ type Eventstore interface {
	Health(ctx context.Context) error
	FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (events []*models.Event, err error)
	Subscribe(aggregates ...models.AggregateType) *Subscription
	InstanceIDs(ctx context.Context, searchQuery *models.SearchQuery) ([]string, error)
}

var _ Eventstore = (*eventstore)(nil)
@@ -37,3 +38,10 @@ func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.Sear
func (es *eventstore) Health(ctx context.Context) error {
	return es.repo.Health(ctx)
}

func (es *eventstore) InstanceIDs(ctx context.Context, searchQuery *models.SearchQuery) ([]string, error) {
	if err := searchQuery.Validate(); err != nil {
		return nil, err
	}
	return es.repo.InstanceIDs(ctx, models.FactoryFromSearchQuery(searchQuery))
}

@@ -5,6 +5,7 @@ import (
	"database/sql"

	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/internal/errors"
	es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
@@ -60,3 +61,31 @@ func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.Searc
	}
	return uint64(*sequence), nil
}

func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) ([]string, error) {
	query, _, values, rowScanner := buildQuery(queryFactory)
	if query == "" {
		return nil, errors.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory")
	}

	rows, err := db.client.Query(query, values...)
	if err != nil {
		logging.New().WithError(err).Info("query failed")
		return nil, errors.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids")
	}
	defer rows.Close()

	ids := make([]string, 0)

	for rows.Next() {
		var id string
		err := rowScanner(rows.Scan, &id)
		if err != nil {
			return nil, err
		}

		ids = append(ids, id)
	}

	return ids, nil
}

@@ -44,7 +44,7 @@ func buildQuery(queryFactory *es_models.SearchQueryFactory) (query string, limit
	}
	query += where

	if searchQuery.Columns != es_models.Columns_Max_Sequence {
	if searchQuery.Columns == es_models.Columns_Event {
		query += " ORDER BY event_sequence"
		if searchQuery.Desc {
			query += " DESC"
@@ -104,6 +104,19 @@ func prepareColumns(columns es_models.Columns) (string, func(s scan, dest interf
		}
		return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
	}
	case es_models.Columns_InstanceIDs:
		return "SELECT DISTINCT instance_id FROM eventstore.events", func(row scan, dest interface{}) (err error) {
			instanceID, ok := dest.(*string)
			if !ok {
				return z_errors.ThrowInvalidArgument(nil, "SQL-Fef5h", "type must be *string]")
			}
			err = row(instanceID)
			if err != nil {
				logging.New().WithError(err).Warn("unable to scan row")
				return z_errors.ThrowInternal(err, "SQL-SFef3", "unable to scan row")
			}
			return nil
		}
	case es_models.Columns_Event:
		return selectStmt, func(row scan, dest interface{}) (err error) {
			event, ok := dest.(*es_models.Event)

@@ -41,6 +41,7 @@ type Columns int32
const (
	Columns_Event = iota
	Columns_Max_Sequence
	Columns_InstanceIDs
	//insert new columns-types before this columnsCount because count is needed for validation
	columnsCount
)
@@ -48,7 +49,7 @@ const (
//FactoryFromSearchQuery is deprecated because it's for migration purposes. use NewSearchQueryFactory
func FactoryFromSearchQuery(q *SearchQuery) *SearchQueryFactory {
	factory := &SearchQueryFactory{
		columns: Columns_Event,
		columns: q.Columns,
		desc:    q.Desc,
		limit:   q.Limit,
		queries: make([]*query, len(q.Queries)),
@@ -232,6 +233,9 @@ func (q *query) eventTypeFilter() *Filter {
}

func (q *query) aggregateTypeFilter() *Filter {
	if len(q.aggregateTypes) < 1 {
		return nil
	}
	if len(q.aggregateTypes) == 1 {
		return NewFilter(Field_AggregateType, q.aggregateTypes[0], Operation_Equals)
	}

@@ -8,6 +8,7 @@ import (

//SearchQuery is deprecated. Use SearchQueryFactory
type SearchQuery struct {
	Columns Columns
	Limit   uint64
	Desc    bool
	Filters []*Filter
@@ -27,6 +28,11 @@ func NewSearchQuery() *SearchQuery {
	}
}

func (q *SearchQuery) SetColumn(columns Columns) *SearchQuery {
	q.Columns = columns
	return q
}

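SetColumn is what lets the spooler ask the eventstore for instance ids instead of events. The reworked spooler further down builds its query exactly like this (fragment; assumes an es of type v1.Eventstore and a ctx in scope):

// Ask for the distinct instance ids of all events, excluding the empty
// (system) instance id.
query := models.NewSearchQuery().
	SetColumn(models.Columns_InstanceIDs).
	AddQuery().
	ExcludedInstanceIDsFilter("").
	SearchQuery()

ids, err := es.InstanceIDs(ctx, query)
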
func (q *SearchQuery) AddQuery() *Query {
	query := &Query{
		searchQuery: q,

@@ -2,9 +2,9 @@ package query

import (
	"context"
	"runtime/debug"
	"time"

	"github.com/getsentry/sentry-go"
	"github.com/zitadel/logging"

	v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
@@ -17,7 +17,7 @@ const (

type Handler interface {
	ViewModel() string
	EventQuery() (*models.SearchQuery, error)
	EventQuery(instanceIDs ...string) (*models.SearchQuery, error)
	Reduce(*models.Event) error
	OnError(event *models.Event, err error) error
	OnSuccess() error
@@ -37,14 +37,13 @@ func ReduceEvent(handler Handler, event *models.Event) {
		err := recover()

		if err != nil {
			sentry.CurrentHub().Recover(err)
			handler.Subscription().Unsubscribe()
			logging.WithFields("HANDL-SAFe1").Errorf("reduce panicked: %v", err)
			logging.WithFields("cause", err, "stack", string(debug.Stack())).Error("reduce panicked")
		}
	}()
	currentSequence, err := handler.CurrentSequence(event.InstanceID)
	if err != nil {
		logging.New().WithError(err).Warn("unable to get current sequence")
		logging.WithError(err).Warn("unable to get current sequence")
		return
	}

@@ -58,14 +57,14 @@ func ReduceEvent(handler Handler, event *models.Event) {

	unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery)
	if err != nil {
		logging.WithFields("HANDL-L6YH1", "sequence", event.Sequence).Warn("filter failed")
		logging.WithFields("sequence", event.Sequence).Warn("filter failed")
		return
	}

	for _, unprocessedEvent := range unprocessedEvents {
		currentSequence, err := handler.CurrentSequence(unprocessedEvent.InstanceID)
		if err != nil {
			logging.Log("HANDL-BmpkC").WithError(err).Warn("unable to get current sequence")
			logging.WithError(err).Warn("unable to get current sequence")
			return
		}
		if unprocessedEvent.Sequence < currentSequence {
@@ -78,12 +77,12 @@ func ReduceEvent(handler Handler, event *models.Event) {
		}

		err = handler.Reduce(unprocessedEvent)
		logging.WithFields("HANDL-V42TI", "sequence", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
		logging.WithFields("sequence", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
	}
	if len(unprocessedEvents) == eventLimit {
		logging.WithFields("QUERY-BSqe9", "sequence", event.Sequence).Warn("didnt process event")
		logging.WithFields("sequence", event.Sequence).Warn("didnt process event")
		return
	}
	err = handler.Reduce(event)
	logging.WithFields("HANDL-wQDL2", "sequence", event.Sequence).OnError(err).Warn("reduce failed")
	logging.WithFields("sequence", event.Sequence).OnError(err).Warn("reduce failed")
}

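The hunk above swaps the sentry call for logging the panic value together with the full stack from runtime/debug. The bare pattern, runnable in isolation:

package main

import (
	"fmt"
	"runtime/debug"
)

func reduceSafely(reduce func()) {
	defer func() {
		if cause := recover(); cause != nil {
			// Log the panic value and the full stack instead of
			// forwarding the panic to an error tracker.
			fmt.Printf("reduce panicked: %v\n%s", cause, debug.Stack())
		}
	}()
	reduce()
}

func main() {
	reduceSafely(func() { panic("boom") })
	fmt.Println("handler keeps running")
}
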
@@ -11,10 +11,11 @@ import (
)

type Config struct {
	Eventstore        v1.Eventstore
	Locker            Locker
	ViewHandlers      []query.Handler
	ConcurrentWorkers int
	Eventstore          v1.Eventstore
	Locker              Locker
	ViewHandlers        []query.Handler
	ConcurrentWorkers   int
	ConcurrentInstances int
}

func (c *Config) New() *Spooler {
@@ -27,11 +28,12 @@ func (c *Config) New() *Spooler {
	})

	return &Spooler{
		handlers:   c.ViewHandlers,
		lockID:     lockID,
		eventstore: c.Eventstore,
		locker:     c.Locker,
		queue:      make(chan *spooledHandler, len(c.ViewHandlers)),
		workers:    c.ConcurrentWorkers,
		handlers:            c.ViewHandlers,
		lockID:              lockID,
		eventstore:          c.Eventstore,
		locker:              c.Locker,
		queue:               make(chan *spooledHandler, len(c.ViewHandlers)),
		workers:             c.ConcurrentWorkers,
		concurrentInstances: c.ConcurrentInstances,
	}
}

@@ -2,11 +2,11 @@ package spooler

import (
	"context"
	"runtime/debug"
	"strconv"
	"sync"
	"time"

	"github.com/getsentry/sentry-go"
	"github.com/zitadel/logging"

	v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
@@ -19,12 +19,13 @@ import (
const systemID = "system"

type Spooler struct {
	handlers   []query.Handler
	locker     Locker
	lockID     string
	eventstore v1.Eventstore
	workers    int
	queue      chan *spooledHandler
	handlers            []query.Handler
	locker              Locker
	lockID              string
	eventstore          v1.Eventstore
	workers             int
	queue               chan *spooledHandler
	concurrentInstances int
}

type Locker interface {
@@ -33,9 +34,10 @@ type Locker interface {

type spooledHandler struct {
	query.Handler
	locker     Locker
	queuedAt   time.Time
	eventstore v1.Eventstore
	locker              Locker
	queuedAt            time.Time
	eventstore          v1.Eventstore
	concurrentInstances int
}

func (s *Spooler) Start() {
@@ -55,7 +57,7 @@ func (s *Spooler) Start() {
	}
	go func() {
		for _, handler := range s.handlers {
			s.queue <- &spooledHandler{Handler: handler, locker: s.locker, queuedAt: time.Now(), eventstore: s.eventstore}
			s.queue <- &spooledHandler{Handler: handler, locker: s.locker, queuedAt: time.Now(), eventstore: s.eventstore, concurrentInstances: s.concurrentInstances}
		}
	}()
}
@@ -73,7 +75,7 @@ func (s *spooledHandler) load(workerID string) {
		err := recover()

		if err != nil {
			sentry.CurrentHub().Recover(err)
			logging.WithFields("cause", err, "stack", string(debug.Stack())).Error("reduce panicked")
		}
	}()
	ctx, cancel := context.WithCancel(context.Background())
@@ -82,29 +84,50 @@ func (s *spooledHandler) load(workerID string) {

	if <-hasLocked {
		for {
			events, err := s.query(ctx)
			ids, err := s.eventstore.InstanceIDs(ctx, models.NewSearchQuery().SetColumn(models.Columns_InstanceIDs).AddQuery().ExcludedInstanceIDsFilter("").SearchQuery())
			if err != nil {
				errs <- err
				break
			}
			err = s.process(ctx, events, workerID)
			if err != nil {
				errs <- err
				break
			}
			if uint64(len(events)) < s.QueryLimit() {
				// no more events to process
				// stop chan
				if ctx.Err() == nil {
					errs <- nil
			for i := 0; i < len(ids); i = i + s.concurrentInstances {
				max := i + s.concurrentInstances
				if max > len(ids) {
					max = len(ids)
				}
				err = s.processInstances(ctx, workerID, ids[i:max]...)
				if err != nil {
					errs <- err
				}
				break
			}
			if ctx.Err() == nil {
				errs <- nil
			}
			break
		}
	}
	<-ctx.Done()
}

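load now walks the instance ids in batches of concurrentInstances, clamping the upper bound on the final, possibly shorter batch. The slicing pattern in isolation:

package main

import "fmt"

func main() {
	ids := []string{"a", "b", "c", "d", "e"}
	concurrentInstances := 2

	for i := 0; i < len(ids); i += concurrentInstances {
		max := i + concurrentInstances
		if max > len(ids) {
			max = len(ids) // clamp the final, possibly shorter batch
		}
		fmt.Println("batch:", ids[i:max]) // [a b], [c d], [e]
	}
}
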
func (s *spooledHandler) processInstances(ctx context.Context, workerID string, ids ...string) error {
	for {
		events, err := s.query(ctx, ids...)
		if err != nil {
			return err
		}
		if len(events) == 0 {
			return nil
		}
		err = s.process(ctx, events, workerID)
		if err != nil {
			return err
		}
		if uint64(len(events)) < s.QueryLimit() {
			// no more events to process
			return nil
		}
	}
}

func (s *spooledHandler) awaitError(cancel func(), errs chan error, workerID string) {
	select {
	case err := <-errs:
@@ -135,8 +158,8 @@ func (s *spooledHandler) process(ctx context.Context, events []*models.Event, wo
	return err
}

func (s *spooledHandler) query(ctx context.Context) ([]*models.Event, error) {
	query, err := s.EventQuery()
func (s *spooledHandler) query(ctx context.Context, instanceIDs ...string) ([]*models.Event, error) {
	query, err := s.EventQuery(instanceIDs...)
	if err != nil {
		return nil, err
	}

@@ -47,7 +47,7 @@ func (h *testHandler) Subscription() *v1.Subscription {
	return nil
}

func (h *testHandler) EventQuery() (*models.SearchQuery, error) {
func (h *testHandler) EventQuery(instanceIDs ...string) (*models.SearchQuery, error) {
	if h.queryError != nil {
		return nil, h.queryError
	}
@@ -111,6 +111,9 @@ func (es *eventstoreStub) PushAggregates(ctx context.Context, in ...*models.Aggr
func (es *eventstoreStub) LatestSequence(ctx context.Context, in *models.SearchQueryFactory) (uint64, error) {
	return 0, nil
}
func (es *eventstoreStub) InstanceIDs(ctx context.Context, in *models.SearchQuery) ([]string, error) {
	return nil, nil
}
func (es *eventstoreStub) V2() *eventstore.Eventstore {
	return nil
}