mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 19:17:32 +00:00
feat(storage): read only transactions (#6417)
feat(storage): read only transactions for queries (#6415) * fix: tests * bastle wie en grosse * fix(database): scan as callback * fix tests * fix merge failures * remove as of system time * refactor: remove unused test * refacotr: remove unused lines
This commit is contained in:
@@ -12,9 +12,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
|
||||
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
|
||||
currentSequenceStmtWithoutLockFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2)`
|
||||
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
|
||||
)
|
||||
|
||||
type currentSequences map[eventstore.AggregateType][]*instanceSequence
|
||||
@@ -24,41 +25,38 @@ type instanceSequence struct {
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs database.StringArray) (currentSequences, error) {
|
||||
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, instanceIDs)
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, isTx bool, query func(context.Context, func(*sql.Rows) error, string, ...interface{}) error, instanceIDs database.StringArray) (currentSequences, error) {
|
||||
stmt := h.currentSequenceStmt
|
||||
if !isTx {
|
||||
stmt = h.currentSequenceWithoutLockStmt
|
||||
}
|
||||
|
||||
sequences := make(currentSequences, len(h.aggregates))
|
||||
err := query(ctx,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var (
|
||||
aggregateType eventstore.AggregateType
|
||||
sequence uint64
|
||||
instanceID string
|
||||
)
|
||||
|
||||
err := rows.Scan(&sequence, &aggregateType, &instanceID)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
|
||||
}
|
||||
|
||||
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
|
||||
sequence: sequence,
|
||||
instanceID: instanceID,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
stmt, h.ProjectionName, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
sequences := make(currentSequences, len(h.aggregates))
|
||||
for rows.Next() {
|
||||
var (
|
||||
aggregateType eventstore.AggregateType
|
||||
sequence uint64
|
||||
instanceID string
|
||||
)
|
||||
|
||||
err = rows.Scan(&sequence, &aggregateType, &instanceID)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
|
||||
}
|
||||
|
||||
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
|
||||
sequence: sequence,
|
||||
instanceID: instanceID,
|
||||
})
|
||||
}
|
||||
|
||||
if err = rows.Close(); err != nil {
|
||||
return nil, errors.ThrowInternal(err, "CRDB-h5i5m", "close rows failed")
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, errors.ThrowInternal(err, "CRDB-O8zig", "errors in scanning rows")
|
||||
}
|
||||
|
||||
return sequences, nil
|
||||
}
|
||||
|
||||
|
@@ -124,13 +124,17 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) {
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
|
||||
func expectCurrentSequence(isTx bool, tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"})
|
||||
for _, instanceID := range instanceIDs {
|
||||
rows.AddRow(seq, aggregateType, instanceID)
|
||||
}
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
|
||||
if isTx {
|
||||
stmt += " FOR UPDATE"
|
||||
}
|
||||
m.ExpectQuery(stmt).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
@@ -141,9 +145,13 @@ func expectCurrentSequence(tableName, projection string, seq uint64, aggregateTy
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
|
||||
func expectCurrentSequenceErr(isTx bool, tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
|
||||
if isTx {
|
||||
stmt += " FOR UPDATE"
|
||||
}
|
||||
m.ExpectQuery(stmt).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
|
@@ -36,13 +36,14 @@ type StatementHandler struct {
|
||||
*handler.ProjectionHandler
|
||||
Locker
|
||||
|
||||
client *database.DB
|
||||
sequenceTable string
|
||||
currentSequenceStmt string
|
||||
updateSequencesBaseStmt string
|
||||
maxFailureCount uint
|
||||
failureCountStmt string
|
||||
setFailureCountStmt string
|
||||
client *database.DB
|
||||
sequenceTable string
|
||||
currentSequenceStmt string
|
||||
currentSequenceWithoutLockStmt string
|
||||
updateSequencesBaseStmt string
|
||||
maxFailureCount uint
|
||||
failureCountStmt string
|
||||
setFailureCountStmt string
|
||||
|
||||
aggregates []eventstore.AggregateType
|
||||
reduces map[eventstore.EventType]handler.Reduce
|
||||
@@ -77,20 +78,21 @@ func NewStatementHandler(
|
||||
}
|
||||
|
||||
h := StatementHandler{
|
||||
client: config.Client,
|
||||
sequenceTable: config.SequenceTable,
|
||||
maxFailureCount: config.MaxFailureCount,
|
||||
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
|
||||
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
|
||||
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
|
||||
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
|
||||
aggregates: aggregateTypes,
|
||||
reduces: reduces,
|
||||
bulkLimit: config.BulkLimit,
|
||||
Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName),
|
||||
initCheck: config.InitCheck,
|
||||
initialized: make(chan bool),
|
||||
reduceScheduledPseudoEvent: reduceScheduledPseudoEvent,
|
||||
client: config.Client,
|
||||
sequenceTable: config.SequenceTable,
|
||||
maxFailureCount: config.MaxFailureCount,
|
||||
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
|
||||
currentSequenceWithoutLockStmt: fmt.Sprintf(currentSequenceStmtWithoutLockFormat, config.SequenceTable),
|
||||
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
|
||||
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
|
||||
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
|
||||
aggregates: aggregateTypes,
|
||||
reduces: reduces,
|
||||
bulkLimit: config.BulkLimit,
|
||||
Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName),
|
||||
initCheck: config.InitCheck,
|
||||
initialized: make(chan bool),
|
||||
reduceScheduledPseudoEvent: reduceScheduledPseudoEvent,
|
||||
}
|
||||
|
||||
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.searchQuery, h.Lock, h.Unlock, h.initialized, reduceScheduledPseudoEvent)
|
||||
@@ -114,7 +116,7 @@ func (h *StatementHandler) searchQuery(ctx context.Context, instanceIDs []string
|
||||
}
|
||||
|
||||
func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
sequences, err := h.currentSequences(ctx, h.client.QueryContext, instanceIDs)
|
||||
sequences, err := h.currentSequences(ctx, false, h.client.QueryContext, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -140,6 +142,26 @@ func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []stri
|
||||
return queryBuilder, h.bulkLimit, nil
|
||||
}
|
||||
|
||||
type transaction struct {
|
||||
*sql.Tx
|
||||
}
|
||||
|
||||
func (t *transaction) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
|
||||
rows, err := t.Tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
logging.OnError(closeErr).Info("rows.Close failed")
|
||||
}()
|
||||
|
||||
if err = scan(rows); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// Update implements handler.Update
|
||||
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
|
||||
if len(stmts) == 0 {
|
||||
@@ -154,7 +176,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
|
||||
return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
|
||||
}
|
||||
|
||||
sequences, err := h.currentSequences(ctx, tx.QueryContext, instanceIDs)
|
||||
sequences, err := h.currentSequences(ctx, true, (&transaction{Tx: tx}).QueryContext, instanceIDs)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return -1, err
|
||||
|
@@ -90,7 +90,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
|
||||
return errors.Is(err, sql.ErrTxDone)
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone),
|
||||
expectBegin(),
|
||||
expectCurrentSequenceErr(false, "my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone),
|
||||
expectRollback(),
|
||||
},
|
||||
SearchQueryBuilder: nil,
|
||||
},
|
||||
@@ -112,7 +114,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
|
||||
return err == nil
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}),
|
||||
expectBegin(),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}),
|
||||
expectCommit(),
|
||||
},
|
||||
SearchQueryBuilder: eventstore.
|
||||
NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
@@ -142,7 +146,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
|
||||
return err == nil
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}),
|
||||
expectBegin(),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}),
|
||||
expectCommit(),
|
||||
},
|
||||
SearchQueryBuilder: eventstore.
|
||||
NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
@@ -216,6 +222,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
|
||||
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(query, tt.want.SearchQueryBuilder) {
|
||||
t.Errorf("unexpected query: expected %v, got %v", tt.want.SearchQueryBuilder, query)
|
||||
}
|
||||
@@ -289,7 +296,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone),
|
||||
expectCurrentSequenceErr(false, "my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone),
|
||||
expectRollback(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -321,7 +328,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectRollback(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -360,7 +367,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectCommit(),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
@@ -399,7 +406,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
|
||||
expectSavePoint(),
|
||||
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
|
||||
expectSavePointRelease(),
|
||||
@@ -442,7 +449,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
|
||||
expectSavePoint(),
|
||||
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
|
||||
expectSavePointRelease(),
|
||||
@@ -478,7 +485,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
|
||||
expectCommit(),
|
||||
},
|
||||
@@ -511,7 +518,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
|
||||
expectCommit(),
|
||||
},
|
||||
@@ -551,7 +558,7 @@ func TestStatementHandler_Update(t *testing.T) {
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectBegin(),
|
||||
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
|
||||
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
|
||||
expectCommit(),
|
||||
},
|
||||
@@ -1425,7 +1432,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
return errors.Is(err, sql.ErrConnDone)
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequenceErr("my_table", "my_projection", nil, sql.ErrConnDone),
|
||||
expectCurrentSequenceErr(true, "my_table", "my_projection", nil, sql.ErrConnDone),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1487,7 +1494,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
return errors.Is(err, nil)
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID"}),
|
||||
expectCurrentSequence(true, "my_table", "my_projection", 5, "agg", []string{"instanceID"}),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": []*instanceSequence{
|
||||
@@ -1515,7 +1522,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
return errors.Is(err, nil)
|
||||
},
|
||||
expectations: []mockExpectation{
|
||||
expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}),
|
||||
expectCurrentSequence(true, "my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}),
|
||||
},
|
||||
sequences: currentSequences{
|
||||
"agg": []*instanceSequence{
|
||||
@@ -1563,7 +1570,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
|
||||
t.Fatalf("unexpected err in begin: %v", err)
|
||||
}
|
||||
|
||||
seq, err := h.currentSequences(context.Background(), tx.QueryContext, tt.args.instanceIDs)
|
||||
seq, err := h.currentSequences(context.Background(), true, (&transaction{Tx: tx}).QueryContext, tt.args.instanceIDs)
|
||||
if !tt.want.isErr(err) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
@@ -161,13 +161,18 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
|
||||
var instanceRegexp = regexp.MustCompile(`eventstore\.i_[0-9a-zA-Z]{1,}_seq`)
|
||||
|
||||
func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error {
|
||||
row := db.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID)
|
||||
if row.Err() != nil {
|
||||
return caos_errs.ThrowInvalidArgument(row.Err(), "SQL-7gtFA", "Errors.InvalidArgument")
|
||||
}
|
||||
var sequenceName string
|
||||
if err := row.Scan(&sequenceName); err != nil || !instanceRegexp.MatchString(sequenceName) {
|
||||
return caos_errs.ThrowInvalidArgument(err, "SQL-7gtFA", "Errors.InvalidArgument")
|
||||
err := db.QueryRowContext(ctx,
|
||||
func(row *sql.Row) error {
|
||||
if err := row.Scan(&sequenceName); err != nil || !instanceRegexp.MatchString(sequenceName) {
|
||||
return caos_errs.ThrowInvalidArgument(err, "SQL-7gtFA", "Errors.InvalidArgument")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
"SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := db.ExecContext(ctx, "CREATE SEQUENCE "+sequenceName); err != nil {
|
||||
@@ -220,9 +225,9 @@ func (db *CRDB) handleUniqueConstraints(ctx context.Context, tx *sql.Tx, uniqueC
|
||||
}
|
||||
|
||||
// Filter returns all events matching the given search query
|
||||
func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) {
|
||||
func (crdb *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) {
|
||||
events = []*repository.Event{}
|
||||
err = query(ctx, db, searchQuery, &events)
|
||||
err = query(ctx, crdb, searchQuery, &events)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -250,8 +255,8 @@ func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQ
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (db *CRDB) db() *sql.DB {
|
||||
return db.DB.DB
|
||||
func (db *CRDB) db() *database.DB {
|
||||
return db.DB
|
||||
}
|
||||
|
||||
func (db *CRDB) orderByEventSequence(desc bool) string {
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
z_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
@@ -24,13 +25,33 @@ type querier interface {
|
||||
eventQuery() string
|
||||
maxSequenceQuery() string
|
||||
instanceIDsQuery() string
|
||||
db() *sql.DB
|
||||
db() *database.DB
|
||||
orderByEventSequence(desc bool) string
|
||||
dialect.Database
|
||||
}
|
||||
|
||||
type scan func(dest ...interface{}) error
|
||||
|
||||
type tx struct {
|
||||
*sql.Tx
|
||||
}
|
||||
|
||||
func (t *tx) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
|
||||
rows, err := t.Tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
logging.OnError(closeErr).Info("rows.Close failed")
|
||||
}()
|
||||
|
||||
if err = scan(rows); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func query(ctx context.Context, criteria querier, searchQuery *repository.SearchQuery, dest interface{}) error {
|
||||
query, rowScanner := prepareColumns(criteria, searchQuery.Columns)
|
||||
where, values := prepareCondition(criteria, searchQuery.Filters)
|
||||
@@ -56,26 +77,27 @@ func query(ctx context.Context, criteria querier, searchQuery *repository.Search
|
||||
query = criteria.placeholder(query)
|
||||
|
||||
var contextQuerier interface {
|
||||
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
|
||||
QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error
|
||||
}
|
||||
contextQuerier = criteria.db()
|
||||
if searchQuery.Tx != nil {
|
||||
contextQuerier = searchQuery.Tx
|
||||
contextQuerier = &tx{Tx: searchQuery.Tx}
|
||||
}
|
||||
|
||||
rows, err := contextQuerier.QueryContext(ctx, query, values...)
|
||||
err := contextQuerier.QueryContext(ctx,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
err := rowScanner(rows.Scan, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}, query, values...)
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Info("query failed")
|
||||
return z_errors.ThrowInternal(err, "SQL-KyeAx", "unable to filter events")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
err = rowScanner(rows.Scan, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -741,7 +741,7 @@ func Test_query_events_mocked(t *testing.T) {
|
||||
},
|
||||
},
|
||||
fields: fields{
|
||||
mock: newMockClient(t).expectQuery(t,
|
||||
mock: newMockClient(t).expectQueryScanErr(t,
|
||||
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC`,
|
||||
[]driver.Value{repository.AggregateType("user")},
|
||||
&repository.Event{Sequence: 100}),
|
||||
@@ -853,7 +853,21 @@ type dbMock struct {
|
||||
}
|
||||
|
||||
func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock {
|
||||
m.mock.ExpectBegin()
|
||||
query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...)
|
||||
m.mock.ExpectCommit()
|
||||
rows := sqlmock.NewRows([]string{"event_sequence"})
|
||||
for _, event := range events {
|
||||
rows = rows.AddRow(event.Sequence)
|
||||
}
|
||||
query.WillReturnRows(rows).RowsWillBeClosed()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *dbMock) expectQueryScanErr(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock {
|
||||
m.mock.ExpectBegin()
|
||||
query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...)
|
||||
m.mock.ExpectRollback()
|
||||
rows := sqlmock.NewRows([]string{"event_sequence"})
|
||||
for _, event := range events {
|
||||
rows = rows.AddRow(event.Sequence)
|
||||
@@ -863,6 +877,7 @@ func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.V
|
||||
}
|
||||
|
||||
func (m *dbMock) expectQueryErr(t *testing.T, expectedQuery string, args []driver.Value, err error) *dbMock {
|
||||
m.mock.ExpectBegin()
|
||||
m.mock.ExpectQuery(expectedQuery).WithArgs(args...).WillReturnError(err)
|
||||
return m
|
||||
}
|
||||
|
@@ -127,9 +127,11 @@ func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, ev
|
||||
for i := 0; i < eventCount; i++ {
|
||||
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
|
||||
}
|
||||
db.mock.ExpectBegin()
|
||||
db.mock.ExpectQuery(expectedFilterEventsLimitFormat).
|
||||
WithArgs(aggregateType, limit).
|
||||
WillReturnRows(rows)
|
||||
db.mock.ExpectCommit()
|
||||
return db
|
||||
}
|
||||
|
||||
@@ -138,8 +140,10 @@ func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *
|
||||
for i := eventCount; i > 0; i-- {
|
||||
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
|
||||
}
|
||||
db.mock.ExpectBegin()
|
||||
db.mock.ExpectQuery(expectedFilterEventsDescFormat).
|
||||
WillReturnRows(rows)
|
||||
db.mock.ExpectCommit()
|
||||
return db
|
||||
}
|
||||
|
||||
@@ -148,9 +152,11 @@ func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID
|
||||
for i := limit; i > 0; i-- {
|
||||
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
|
||||
}
|
||||
db.mock.ExpectBegin()
|
||||
db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit).
|
||||
WithArgs(aggregateType, aggregateID, limit).
|
||||
WillReturnRows(rows)
|
||||
db.mock.ExpectCommit()
|
||||
return db
|
||||
}
|
||||
|
||||
@@ -159,28 +165,36 @@ func (db *dbMock) expectFilterEventsAggregateIDTypeLimit(aggregateType, aggregat
|
||||
for i := limit; i > 0; i-- {
|
||||
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
|
||||
}
|
||||
db.mock.ExpectBegin()
|
||||
db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit).
|
||||
WithArgs(aggregateType, aggregateID, limit).
|
||||
WillReturnRows(rows)
|
||||
db.mock.ExpectCommit()
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock {
|
||||
db.mock.ExpectBegin()
|
||||
db.mock.ExpectQuery(expectedGetAllEvents).
|
||||
WillReturnError(returnedErr)
|
||||
db.mock.ExpectRollback()
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock {
|
||||
db.mock.ExpectBegin()
|
||||
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`).
|
||||
WithArgs(aggregateType).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence))
|
||||
db.mock.ExpectCommit()
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock {
|
||||
db.mock.ExpectBegin()
|
||||
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`).
|
||||
WithArgs(aggregateType).WillReturnError(err)
|
||||
// db.mock.ExpectRollback()
|
||||
return db
|
||||
}
|
||||
|
||||
|
@@ -3,12 +3,13 @@ package sql
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
errs "github.com/zitadel/zitadel/internal/errors"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
@@ -24,72 +25,74 @@ func (db *SQL) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFac
|
||||
return db.filter(ctx, db.client, searchQuery)
|
||||
}
|
||||
|
||||
func (sql *SQL) filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
|
||||
query, limit, values, rowScanner := sql.buildQuery(ctx, db, searchQuery)
|
||||
func (server *SQL) filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
|
||||
query, limit, values, rowScanner := server.buildQuery(ctx, db, searchQuery)
|
||||
if query == "" {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
|
||||
return nil, errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
|
||||
}
|
||||
|
||||
rows, err := db.Query(query, values...)
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Info("query failed")
|
||||
return nil, errors.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
events = make([]*es_models.Event, 0, limit)
|
||||
err = db.QueryContext(ctx,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
event := new(es_models.Event)
|
||||
err := rowScanner(rows.Scan, event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
event := new(es_models.Event)
|
||||
err := rowScanner(rows.Scan, event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
events = append(events, event)
|
||||
events = append(events, event)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
query, values...,
|
||||
)
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Info("query failed")
|
||||
return nil, errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) {
|
||||
query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory)
|
||||
if query == "" {
|
||||
return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
|
||||
return 0, errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
|
||||
}
|
||||
row := db.client.QueryRow(query, values...)
|
||||
sequence := new(Sequence)
|
||||
err := rowScanner(row.Scan, sequence)
|
||||
if err != nil {
|
||||
err := db.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
||||
return rowScanner(row.Scan, sequence)
|
||||
}, query, values...)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
logging.New().WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Info("query failed")
|
||||
return 0, errors.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence")
|
||||
return 0, errs.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence")
|
||||
}
|
||||
return uint64(*sequence), nil
|
||||
}
|
||||
|
||||
func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) ([]string, error) {
|
||||
func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (ids []string, err error) {
|
||||
query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory)
|
||||
if query == "" {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory")
|
||||
return nil, errs.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory")
|
||||
}
|
||||
|
||||
rows, err := db.client.Query(query, values...)
|
||||
err = db.client.QueryContext(ctx,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var id string
|
||||
err := rowScanner(rows.Scan, &id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
query, values...)
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Info("query failed")
|
||||
return nil, errors.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
ids := make([]string, 0)
|
||||
|
||||
for rows.Next() {
|
||||
var id string
|
||||
err := rowScanner(rows.Scan, &id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ids = append(ids, id)
|
||||
return nil, errs.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids")
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
|
@@ -130,6 +130,7 @@ func TestSQL_Filter(t *testing.T) {
|
||||
if (err != nil) != tt.res.wantErr {
|
||||
t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr)
|
||||
}
|
||||
|
||||
if tt.res.eventsLen != 0 && len(events) != tt.res.eventsLen {
|
||||
t.Errorf("events has wrong length got: %d want %d", len(events), tt.res.eventsLen)
|
||||
}
|
||||
@@ -221,10 +222,12 @@ func TestSQL_LatestSequence(t *testing.T) {
|
||||
sql := &SQL{
|
||||
client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)},
|
||||
}
|
||||
|
||||
sequence, err := sql.LatestSequence(context.Background(), tt.args.searchQuery)
|
||||
if (err != nil) != tt.res.wantErr {
|
||||
t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr)
|
||||
}
|
||||
|
||||
if tt.res.sequence != sequence {
|
||||
t.Errorf("events has wrong length got: %d want %d", sequence, tt.res.sequence)
|
||||
}
|
||||
|
Reference in New Issue
Block a user