fix: scheduling (#3978)

* fix: improve scheduling

* build pre-release

* fix: locker

* fix: user handler and print stack in case of panic in reducer

* chore: remove sentry

* fix: improve handler projection and implement tests

* more tests

* fix: race condition in tests

* Update internal/eventstore/repository/sql/query.go

Co-authored-by: Silvan <silvan.reusser@gmail.com>

* fix: implemented suggested changes

* fix: lock statement

Co-authored-by: Silvan <silvan.reusser@gmail.com>
This commit is contained in:
Livio Spring
2022-07-22 12:08:39 +02:00
committed by GitHub
parent 0cc548e3f8
commit aed7010508
83 changed files with 1494 additions and 1544 deletions

View File

@@ -6,12 +6,14 @@ import (
"strconv"
"strings"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 FOR UPDATE`
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
)
@@ -22,8 +24,8 @@ type instanceSequence struct {
sequence uint64
}
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName)
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs []string) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, pq.StringArray(instanceIDs))
if err != nil {
return nil, err
}

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/eventstore"
)
@@ -123,34 +124,40 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) {
}
}
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"})
for _, instanceID := range instanceIDs {
rows.AddRow(seq, aggregateType, instanceID)
}
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
AddRow(seq, aggregateType, instanceID),
rows,
)
}
}
func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlmock.Sqlmock) {
func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
).
WillReturnError(err)
}
}
func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlmock) {
func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
@@ -158,11 +165,12 @@ func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlm
}
}
func expectCurrentSequenceScanErr(tableName, projection string) func(sqlmock.Sqlmock) {
func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
@@ -286,12 +294,34 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
pq.StringArray{instanceID},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
)
}
}
func expectLockMultipleInstances(lockTable, workerName string, d time.Duration, instanceID1, instanceID2 string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\), \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$5\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$6\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID1,
instanceID2,
pq.StringArray{instanceID1, instanceID2},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -305,12 +335,13 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
pq.StringArray{instanceID},
).
WillReturnResult(driver.ResultNoRows)
}
@@ -322,12 +353,13 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
pq.StringArray{instanceID},
).
WillReturnError(err)
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
@@ -75,84 +74,62 @@ func NewStatementHandler(
bulkLimit: config.BulkLimit,
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
}
h.ProjectionHandler = handler.NewProjectionHandler(config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery)
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery, h.Lock, h.Unlock)
err := h.Init(ctx, config.InitCheck)
logging.OnError(err).Fatal("unable to initialize projections")
go h.Process(
ctx,
h.reduce,
h.Update,
h.Lock,
h.Unlock,
h.SearchQuery,
)
h.Subscribe(h.aggregates...)
return h
}
func (h *StatementHandler) TriggerBulk(ctx context.Context) {
ctx, span := tracing.NewSpan(ctx)
var err error
defer span.EndWithError(err)
err = h.ProjectionHandler.TriggerBulk(ctx, h.Lock, h.Unlock)
logging.OnError(err).WithField("projection", h.ProjectionName).Warn("unable to trigger bulk")
}
func (h *StatementHandler) SearchQuery(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) {
sequences, err := h.currentSequences(ctx, h.client.QueryContext)
func (h *StatementHandler) SearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
sequences, err := h.currentSequences(ctx, h.client.QueryContext, instanceIDs)
if err != nil {
return nil, 0, err
}
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
for _, aggregateType := range h.aggregates {
instances := make([]string, 0)
for _, sequence := range sequences[aggregateType] {
instances = appendToIgnoredInstances(instances, sequence.instanceID)
for _, instanceID := range instanceIDs {
var seq uint64
for _, sequence := range sequences[aggregateType] {
if sequence.instanceID == instanceID {
seq = sequence.sequence
break
}
}
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequence.sequence).
InstanceID(sequence.instanceID)
SequenceGreater(seq).
InstanceID(instanceID)
}
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(0).
ExcludedInstanceID(instances...)
}
return queryBuilder, h.bulkLimit, nil
}
func appendToIgnoredInstances(instances []string, id string) []string {
for _, instance := range instances {
if instance == id {
return instances
}
}
return append(instances, id)
}
//Update implements handler.Update
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
if len(stmts) == 0 {
return nil, nil
return -1, nil
}
instanceIDs := make([]string, 0, len(stmts))
for _, stmt := range stmts {
instanceIDs = appendToInstanceIDs(instanceIDs, stmt.InstanceID)
}
tx, err := h.client.BeginTx(ctx, nil)
if err != nil {
return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
}
sequences, err := h.currentSequences(ctx, tx.QueryContext)
sequences, err := h.currentSequences(ctx, tx.QueryContext, instanceIDs)
if err != nil {
tx.Rollback()
return stmts, err
return -1, err
}
//checks for events between create statement and current sequence
@@ -162,7 +139,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
if err != nil {
tx.Rollback()
return stmts, err
return -1, err
}
stmts = append(previousStmts, stmts...)
}
@@ -173,27 +150,19 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
err = h.updateCurrentSequences(tx, sequences)
if err != nil {
tx.Rollback()
return stmts, err
return -1, err
}
}
if err = tx.Commit(); err != nil {
return stmts, err
return -1, err
}
if lastSuccessfulIdx == -1 && len(stmts) > 0 {
return stmts, handler.ErrSomeStmtsFailed
if lastSuccessfulIdx < len(stmts)-1 {
return lastSuccessfulIdx, handler.ErrSomeStmtsFailed
}
unexecutedStmts = make([]*handler.Statement, len(stmts)-(lastSuccessfulIdx+1))
copy(unexecutedStmts, stmts[lastSuccessfulIdx+1:])
stmts = nil
if len(unexecutedStmts) > 0 {
return unexecutedStmts, handler.ErrSomeStmtsFailed
}
return unexecutedStmts, nil
return lastSuccessfulIdx, nil
}
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
@@ -316,3 +285,12 @@ func updateSequences(sequences currentSequences, stmt *handler.Statement) {
sequence: stmt.Sequence,
})
}
func appendToInstanceIDs(instances []string, id string) []string {
for _, instance := range instances {
if instance == id {
return instances
}
}
return append(instances, id)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"testing"
"time"
@@ -61,9 +62,13 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
aggregates []eventstore.AggregateType
bulkLimit uint64
}
type args struct {
instanceIDs []string
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
@@ -74,13 +79,16 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
aggregates: []eventstore.AggregateType{"testAgg"},
bulkLimit: 5,
},
args: args{
instanceIDs: []string{"instanceID1"},
},
want: want{
limit: 0,
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)
},
expectations: []mockExpectation{
expectCurrentSequenceErr("my_sequences", "my_projection", sql.ErrTxDone),
expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone),
},
SearchQueryBuilder: nil,
},
@@ -93,24 +101,56 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
aggregates: []eventstore.AggregateType{"testAgg"},
bulkLimit: 5,
},
args: args{
instanceIDs: []string{"instanceID1"},
},
want: want{
limit: 5,
isErr: func(err error) bool {
return err == nil
},
expectations: []mockExpectation{
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}),
},
SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes("testAgg").
SequenceGreater(5).
InstanceID("instanceID").
InstanceID("instanceID1").
Builder().
Limit(5),
},
},
{
name: "multiple instances",
fields: fields{
sequenceTable: "my_sequences",
projectionName: "my_projection",
aggregates: []eventstore.AggregateType{"testAgg"},
bulkLimit: 5,
},
args: args{
instanceIDs: []string{"instanceID1", "instanceID2"},
},
want: want{
limit: 5,
isErr: func(err error) bool {
return err == nil
},
expectations: []mockExpectation{
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}),
},
SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes("testAgg").
SequenceGreater(5).
InstanceID("instanceID1").
Or().
AggregateTypes("testAgg").
SequenceGreater(0).
ExcludedInstanceID("instanceID").
SequenceGreater(5).
InstanceID("instanceID2").
Builder().
Limit(5),
},
@@ -140,7 +180,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
expectation(mock)
}
query, limit, err := h.SearchQuery(context.Background())
query, limit, err := h.SearchQuery(context.Background(), tt.args.instanceIDs)
if !tt.want.isErr(err) {
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
return
@@ -211,13 +251,14 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "agg",
sequence: 6,
previousSequence: 0,
instanceID: "instanceID",
}),
},
},
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequenceErr("my_sequences", "my_projection", sql.ErrTxDone),
expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone),
expectRollback(),
},
isErr: func(err error) bool {
@@ -241,13 +282,14 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "agg",
sequence: 6,
previousSequence: 0,
instanceID: "instanceID",
}),
},
},
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectRollback(),
},
isErr: func(err error) bool {
@@ -272,6 +314,7 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "testAgg",
sequence: 7,
previousSequence: 6,
instanceID: "instanceID",
},
[]handler.Column{
{
@@ -284,7 +327,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectCommit(),
},
isErr: func(err error) bool {
@@ -322,7 +365,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(),
@@ -364,7 +407,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(),
@@ -399,7 +442,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
@@ -431,7 +474,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
@@ -470,13 +513,14 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {
return errors.Is(err, nil)
},
stmtsLen: 1,
},
},
}
@@ -488,17 +532,18 @@ func TestStatementHandler_Update(t *testing.T) {
}
defer client.Close()
h := NewStatementHandler(context.Background(), StatementHandlerConfig{
ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
ProjectionName: "my_projection",
HandlerConfig: handler.HandlerConfig{
h := &StatementHandler{
ProjectionHandler: &handler.ProjectionHandler{
Handler: handler.Handler{
Eventstore: tt.fields.eventstore,
},
RequeueEvery: 0,
ProjectionName: "my_projection",
},
SequenceTable: "my_sequences",
Client: client,
})
sequenceTable: "my_sequences",
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, "my_sequences"),
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, "my_sequences"),
client: client,
}
h.aggregates = tt.fields.aggregates
@@ -506,12 +551,12 @@ func TestStatementHandler_Update(t *testing.T) {
expectation(mock)
}
stmts, err := h.Update(tt.args.ctx, tt.args.stmts, tt.args.reduce)
index, err := h.Update(tt.args.ctx, tt.args.stmts, tt.args.reduce)
if !tt.want.isErr(err) {
t.Errorf("StatementHandler.Update() error = %v", err)
}
if err == nil && tt.want.stmtsLen != len(stmts) {
t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, len(stmts))
if err == nil && tt.want.stmtsLen != index {
t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, index)
}
mock.MatchExpectationsInOrder(true)
@@ -696,17 +741,12 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
h := &StatementHandler{
aggregates: tt.fields.aggregates,
}
h.ProjectionHandler = handler.NewProjectionHandler(handler.ProjectionHandlerConfig{
HandlerConfig: handler.HandlerConfig{
h.ProjectionHandler = &handler.ProjectionHandler{
Handler: handler.Handler{
Eventstore: tt.fields.eventstore,
},
ProjectionName: "my_projection",
RequeueEvery: 0,
},
h.reduce,
h.Update,
h.SearchQuery,
)
}
stmts, err := h.fetchPreviousStmts(tt.args.ctx, nil, tt.args.stmtSeq, "", tt.args.sequences, tt.args.reduce)
if !tt.want.isErr(err) {
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
@@ -1311,7 +1351,8 @@ func TestStatementHandler_currentSequence(t *testing.T) {
aggregates []eventstore.AggregateType
}
type args struct {
stmt handler.Statement
stmt handler.Statement
instanceIDs []string
}
type want struct {
expectations []mockExpectation
@@ -1338,7 +1379,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
return errors.Is(err, sql.ErrConnDone)
},
expectations: []mockExpectation{
expectCurrentSequenceErr("my_table", "my_projection", sql.ErrConnDone),
expectCurrentSequenceErr("my_table", "my_projection", nil, sql.ErrConnDone),
},
},
},
@@ -1350,14 +1391,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
aggregates: []eventstore.AggregateType{"agg"},
},
args: args{
stmt: handler.Statement{},
stmt: handler.Statement{},
instanceIDs: []string{"instanceID"},
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, nil)
},
expectations: []mockExpectation{
expectCurrentSequenceNoRows("my_table", "my_projection"),
expectCurrentSequenceNoRows("my_table", "my_projection", []string{"instanceID"}),
},
sequences: currentSequences{},
},
@@ -1370,14 +1412,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
aggregates: []eventstore.AggregateType{"agg"},
},
args: args{
stmt: handler.Statement{},
stmt: handler.Statement{},
instanceIDs: []string{"instanceID"},
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)
},
expectations: []mockExpectation{
expectCurrentSequenceScanErr("my_table", "my_projection"),
expectCurrentSequenceScanErr("my_table", "my_projection", []string{"instanceID"}),
},
sequences: currentSequences{},
},
@@ -1390,14 +1433,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
aggregates: []eventstore.AggregateType{"agg"},
},
args: args{
stmt: handler.Statement{},
stmt: handler.Statement{},
instanceIDs: []string{"instanceID"},
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, nil)
},
expectations: []mockExpectation{
expectCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID"}),
},
sequences: currentSequences{
"agg": []*instanceSequence{
@@ -1409,15 +1453,48 @@ func TestStatementHandler_currentSequence(t *testing.T) {
},
},
},
{
name: "multiple found",
fields: fields{
sequenceTable: "my_table",
projectionName: "my_projection",
aggregates: []eventstore.AggregateType{"agg"},
},
args: args{
stmt: handler.Statement{},
instanceIDs: []string{"instanceID1", "instanceID2"},
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, nil)
},
expectations: []mockExpectation{
expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}),
},
sequences: currentSequences{
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID1",
},
{
sequence: 5,
instanceID: "instanceID2",
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := NewStatementHandler(context.Background(), StatementHandlerConfig{
ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
h := &StatementHandler{
ProjectionHandler: &handler.ProjectionHandler{
ProjectionName: tt.fields.projectionName,
},
SequenceTable: tt.fields.sequenceTable,
})
sequenceTable: tt.fields.sequenceTable,
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, tt.fields.sequenceTable),
}
h.aggregates = tt.fields.aggregates
@@ -1440,7 +1517,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
t.Fatalf("unexpected err in begin: %v", err)
}
seq, err := h.currentSequences(context.Background(), tx.QueryContext)
seq, err := h.currentSequences(context.Background(), tx.QueryContext, tt.args.instanceIDs)
if !tt.want.isErr(err) {
t.Errorf("unexpected error: %v", err)
}
@@ -1615,12 +1692,13 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := NewStatementHandler(context.Background(), StatementHandlerConfig{
ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
h := &StatementHandler{
ProjectionHandler: &handler.ProjectionHandler{
ProjectionName: tt.fields.projectionName,
},
SequenceTable: tt.fields.sequenceTable,
})
sequenceTable: tt.fields.sequenceTable,
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, tt.fields.sequenceTable),
}
h.aggregates = tt.fields.aggregates

View File

@@ -4,8 +4,11 @@ import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"github.com/lib/pq"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/errors"
@@ -14,20 +17,20 @@ import (
const (
lockStmtFormat = "INSERT INTO %[1]s" +
" (locker_id, locked_until, projection_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
" (locker_id, locked_until, projection_name, instance_id) VALUES %[2]s" +
" ON CONFLICT (projection_name, instance_id)" +
" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = $4 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = ANY ($%[3]d) AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
)
type Locker interface {
Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error
Unlock(instanceID string) error
Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error
Unlock(instanceIDs ...string) error
}
type locker struct {
client *sql.DB
lockStmt string
lockStmt func(values string, instances int) string
workerName string
projectionName string
}
@@ -36,25 +39,27 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
workerName, err := id.SonyFlakeGenerator().Next()
logging.OnError(err).Panic("unable to generate lockID")
return &locker{
client: client,
lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
client: client,
lockStmt: func(values string, instances int) string {
return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
},
workerName: workerName,
projectionName: projectionName,
}
}
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error {
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error {
errs := make(chan error)
go h.handleLock(ctx, errs, lockDuration, instanceID)
go h.handleLock(ctx, errs, lockDuration, instanceIDs...)
return errs
}
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceID string) {
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceIDs ...string) {
renewLock := time.NewTimer(0)
for {
select {
case <-renewLock.C:
errs <- h.renewLock(ctx, lockDuration, instanceID)
errs <- h.renewLock(ctx, lockDuration, instanceIDs...)
//refresh the lock 500ms before it times out. 500ms should be enough for one transaction
renewLock.Reset(lockDuration - (500 * time.Millisecond))
case <-ctx.Done():
@@ -65,24 +70,38 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t
}
}
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error {
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
res, err := h.client.ExecContext(ctx, h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) error {
lockStmt, values := h.lockStatement(lockDuration, instanceIDs)
res, err := h.client.ExecContext(ctx, lockStmt, values...)
if err != nil {
return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
}
if rows, _ := res.RowsAffected(); rows == 0 {
return errors.ThrowAlreadyExists(nil, "CRDB-mmi4J", "projection already locked")
}
return nil
}
func (h *locker) Unlock(instanceID string) error {
_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName, instanceID)
func (h *locker) Unlock(instanceIDs ...string) error {
lockStmt, values := h.lockStatement(0, instanceIDs)
_, err := h.client.Exec(lockStmt, values...)
if err != nil {
return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed")
}
return nil
}
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs []string) (string, []interface{}) {
valueQueries := make([]string, len(instanceIDs))
values := make([]interface{}, len(instanceIDs)+4)
values[0] = h.workerName
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
values[1] = lockDuration.Seconds()
values[2] = h.projectionName
for i, instanceID := range instanceIDs {
valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
values[i+3] = instanceID
}
values[len(values)-1] = pq.StringArray(instanceIDs)
return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
}

View File

@@ -32,7 +32,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
lockDuration time.Duration
ctx context.Context
errMock *errsMock
instanceID string
instanceIDs []string
}
tests := []struct {
name string
@@ -56,7 +56,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
successfulIters: 2,
shouldErr: true,
},
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
@@ -74,7 +74,25 @@ func TestStatementHandler_handleLock(t *testing.T) {
errs: make(chan error),
successfulIters: 2,
},
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
name: "success with multiple",
want: want{
expectations: []mockExpectation{
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
},
},
args: args{
lockDuration: 2 * time.Second,
ctx: context.Background(),
errMock: &errsMock{
errs: make(chan error),
successfulIters: 2,
},
instanceIDs: []string{"instanceID1", "instanceID2"},
},
},
}
@@ -88,7 +106,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
projectionName: projectionName,
client: client,
workerName: workerName,
lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
lockStmt: func(values string, instances int) string {
return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
},
}
for _, expectation := range tt.want.expectations {
@@ -99,7 +119,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
go tt.args.errMock.handleErrs(t, cancel)
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceID)
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceIDs...)
<-ctx.Done()
@@ -118,7 +138,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
}
type args struct {
lockDuration time.Duration
instanceID string
instanceIDs []string
}
tests := []struct {
name string
@@ -137,7 +157,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 1 * time.Second,
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
@@ -152,7 +172,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 2 * time.Second,
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
@@ -167,7 +187,22 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 3 * time.Second,
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
name: "success with multiple",
want: want{
expectations: []mockExpectation{
expectLockMultipleInstances(lockTable, workerName, 3, "instanceID1", "instanceID2"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)
},
},
args: args{
lockDuration: 3 * time.Second,
instanceIDs: []string{"instanceID1", "instanceID2"},
},
},
}
@@ -181,14 +216,16 @@ func TestStatementHandler_renewLock(t *testing.T) {
projectionName: projectionName,
client: client,
workerName: workerName,
lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
lockStmt: func(values string, instances int) string {
return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
},
}
for _, expectation := range tt.want.expectations {
expectation(mock)
}
err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceID)
err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceIDs...)
if !tt.want.isErr(err) {
t.Errorf("unexpected error = %v", err)
}
@@ -253,7 +290,9 @@ func TestStatementHandler_Unlock(t *testing.T) {
projectionName: projectionName,
client: client,
workerName: workerName,
lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
lockStmt: func(values string, instances int) string {
return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
},
}
for _, expectation := range tt.want.expectations {