feat(eventstore): increase parallel write capabilities (#5940)

This implementation increases parallel write capabilities of the eventstore.
Please have a look at the technical advisories: [05](https://zitadel.com/docs/support/advisory/a10005) and  [06](https://zitadel.com/docs/support/advisory/a10006).
The implementation of eventstore.push is rewritten and stored events are migrated to a new table `eventstore.events2`.
If you are using cockroach: make sure that the database user of ZITADEL has `VIEWACTIVITY` grant. This is used to query events.
This commit is contained in:
Silvan
2023-10-19 12:19:10 +02:00
committed by GitHub
parent 259faba3f0
commit b5564572bc
791 changed files with 30326 additions and 43202 deletions

View File

@@ -3,6 +3,7 @@ package sql
import (
"context"
"database/sql"
"encoding/json"
"errors"
"regexp"
"strconv"
@@ -13,8 +14,10 @@ import (
"github.com/lib/pq"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
@@ -52,6 +55,7 @@ const (
" aggregate_id," +
" aggregate_version," +
" creation_date," +
" position," +
" event_data," +
" editor_user," +
" editor_service," +
@@ -59,7 +63,8 @@ const (
" instance_id," +
" event_sequence," +
" previous_aggregate_sequence," +
" previous_aggregate_type_sequence" +
" previous_aggregate_type_sequence," +
" in_tx_order" +
") " +
// defines the data to be inserted
"SELECT" +
@@ -67,17 +72,19 @@ const (
" $2::VARCHAR AS aggregate_type," +
" $3::VARCHAR AS aggregate_id," +
" $4::VARCHAR AS aggregate_version," +
" statement_timestamp() AS creation_date," +
" hlc_to_timestamp(cluster_logical_timestamp()) AS creation_date," +
" cluster_logical_timestamp() AS position," +
" $5::JSONB AS event_data," +
" $6::VARCHAR AS editor_user," +
" $7::VARCHAR AS editor_service," +
" COALESCE((resource_owner), $8::VARCHAR) AS resource_owner," +
" $9::VARCHAR AS instance_id," +
" NEXTVAL(CONCAT('eventstore.', (CASE WHEN $9 <> '' THEN CONCAT('i_', $9) ELSE 'system' END), '_seq'))," +
" COALESCE(aggregate_sequence, 0)+1," +
" aggregate_sequence AS previous_aggregate_sequence," +
" aggregate_type_sequence AS previous_aggregate_type_sequence " +
" aggregate_type_sequence AS previous_aggregate_type_sequence," +
" $10 AS in_tx_order " +
"FROM previous_data " +
"RETURNING id, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, creation_date, resource_owner, instance_id"
"RETURNING id, event_sequence, creation_date, resource_owner, instance_id"
uniqueInsert = `INSERT INTO eventstore.unique_constraints
(
@@ -97,93 +104,109 @@ const (
WHERE instance_id = $1`
)
type CRDB struct {
*database.DB
AllowOrderByCreationDate bool
// awaitOpenTransactions ensures event ordering, so we don't events younger that open transactions
var (
awaitOpenTransactionsV1 string
awaitOpenTransactionsV2 string
)
func awaitOpenTransactions(useV1 bool) string {
if useV1 {
return awaitOpenTransactionsV1
}
return awaitOpenTransactionsV2
}
func NewCRDB(client *database.DB, allowOrderByCreationDate bool) *CRDB {
return &CRDB{client, allowOrderByCreationDate}
type CRDB struct {
*database.DB
}
func NewCRDB(client *database.DB) *CRDB {
switch client.Type() {
case "cockroach":
awaitOpenTransactionsV1 = " AND creation_date::TIMESTAMP < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = '" + database.EventstorePusherAppName + "')"
awaitOpenTransactionsV2 = ` AND hlc_to_timestamp("position") < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = '` + database.EventstorePusherAppName + `')`
case "postgres":
awaitOpenTransactionsV1 = ` AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = '` + database.EventstorePusherAppName + `' AND state <> 'idle')`
awaitOpenTransactionsV2 = ` AND "position" < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = '` + database.EventstorePusherAppName + `' AND state <> 'idle')`
}
return &CRDB{client}
}
func (db *CRDB) Health(ctx context.Context) error { return db.Ping() }
// Push adds all events to the eventstreams of the aggregates.
// This call is transaction save. The transaction will be rolled back if one event fails
func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueConstraints ...*repository.UniqueConstraint) error {
err := crdb.ExecuteTx(ctx, db.DB.DB, nil, func(tx *sql.Tx) error {
func (db *CRDB) Push(ctx context.Context, commands ...eventstore.Command) (events []eventstore.Event, err error) {
events = make([]eventstore.Event, len(commands))
err = crdb.ExecuteTx(ctx, db.DB.DB, nil, func(tx *sql.Tx) error {
var uniqueConstraints []*eventstore.UniqueConstraint
for i, command := range commands {
if command.Aggregate().InstanceID == "" {
command.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID()
}
var payload []byte
if command.Payload() != nil {
payload, err = json.Marshal(command.Payload())
if err != nil {
return err
}
}
e := &repository.Event{
Typ: command.Type(),
Data: payload,
EditorUser: command.Creator(),
Version: command.Aggregate().Version,
AggregateID: command.Aggregate().ID,
AggregateType: command.Aggregate().Type,
ResourceOwner: sql.NullString{String: command.Aggregate().ResourceOwner, Valid: command.Aggregate().ResourceOwner != ""},
InstanceID: command.Aggregate().InstanceID,
}
var (
previousAggregateSequence Sequence
previousAggregateTypeSequence Sequence
)
for _, event := range events {
err := tx.QueryRowContext(ctx, crdbInsert,
event.Type,
event.AggregateType,
event.AggregateID,
event.Version,
Data(event.Data),
event.EditorUser,
event.EditorService,
event.ResourceOwner,
event.InstanceID,
).Scan(&event.ID, &event.Sequence, &previousAggregateSequence, &previousAggregateTypeSequence, &event.CreationDate, &event.ResourceOwner, &event.InstanceID)
event.PreviousAggregateSequence = uint64(previousAggregateSequence)
event.PreviousAggregateTypeSequence = uint64(previousAggregateTypeSequence)
e.Type(),
e.Aggregate().Type,
e.Aggregate().ID,
e.Aggregate().Version,
payload,
e.Creator(),
"zitadel",
e.Aggregate().ResourceOwner,
e.Aggregate().InstanceID,
i,
).Scan(&e.ID, &e.Seq, &e.CreationDate, &e.ResourceOwner, &e.InstanceID)
if err != nil {
logging.WithFields(
"aggregate", event.AggregateType,
"aggregateId", event.AggregateID,
"aggregateType", event.AggregateType,
"eventType", event.Type,
"instanceID", event.InstanceID,
"aggregate", e.Aggregate().Type,
"aggregateId", e.Aggregate().ID,
"aggregateType", e.Aggregate().Type,
"eventType", e.Type(),
"instanceID", e.Aggregate().InstanceID,
).WithError(err).Debug("query failed")
return caos_errs.ThrowInternal(err, "SQL-SBP37", "unable to create event")
}
uniqueConstraints = append(uniqueConstraints, command.UniqueConstraints()...)
events[i] = e
}
err := db.handleUniqueConstraints(ctx, tx, uniqueConstraints...)
if err != nil {
return err
}
return nil
return db.handleUniqueConstraints(ctx, tx, uniqueConstraints...)
})
if err != nil && !errors.Is(err, &caos_errs.CaosError{}) {
err = caos_errs.ThrowInternal(err, "SQL-DjgtG", "unable to store events")
}
return err
}
var instanceRegexp = regexp.MustCompile(`eventstore\.i_[0-9a-zA-Z]{1,}_seq`)
func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error {
var sequenceName string
err := db.QueryRowContext(ctx,
func(row *sql.Row) error {
if err := row.Scan(&sequenceName); err != nil || !instanceRegexp.MatchString(sequenceName) {
return caos_errs.ThrowInvalidArgument(err, "SQL-7gtFA", "Errors.InvalidArgument")
}
return nil
},
"SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID,
)
if err != nil {
return err
}
if _, err := db.ExecContext(ctx, "CREATE SEQUENCE "+sequenceName); err != nil {
return caos_errs.ThrowInternal(err, "SQL-7gtFA", "Errors.Internal")
}
return nil
return events, err
}
// handleUniqueConstraints adds or removes unique constraints
func (db *CRDB) handleUniqueConstraints(ctx context.Context, tx *sql.Tx, uniqueConstraints ...*repository.UniqueConstraint) (err error) {
func (db *CRDB) handleUniqueConstraints(ctx context.Context, tx *sql.Tx, uniqueConstraints ...*eventstore.UniqueConstraint) (err error) {
if len(uniqueConstraints) == 0 || (len(uniqueConstraints) == 1 && uniqueConstraints[0] == nil) {
return nil
}
@@ -191,32 +214,32 @@ func (db *CRDB) handleUniqueConstraints(ctx context.Context, tx *sql.Tx, uniqueC
for _, uniqueConstraint := range uniqueConstraints {
uniqueConstraint.UniqueField = strings.ToLower(uniqueConstraint.UniqueField)
switch uniqueConstraint.Action {
case repository.UniqueConstraintAdd:
_, err := tx.ExecContext(ctx, uniqueInsert, uniqueConstraint.UniqueType, uniqueConstraint.UniqueField, uniqueConstraint.InstanceID)
case eventstore.UniqueConstraintAdd:
_, err := tx.ExecContext(ctx, uniqueInsert, uniqueConstraint.UniqueType, uniqueConstraint.UniqueField, authz.GetInstance(ctx).InstanceID())
if err != nil {
logging.WithFields(
"unique_type", uniqueConstraint.UniqueType,
"unique_field", uniqueConstraint.UniqueField).WithError(err).Info("insert unique constraint failed")
if db.isUniqueViolationError(err) {
return caos_errs.ThrowAlreadyExists(err, "SQL-M0dsf", uniqueConstraint.ErrorMessage)
return caos_errs.ThrowAlreadyExists(err, "SQL-wHcEq", uniqueConstraint.ErrorMessage)
}
return caos_errs.ThrowInternal(err, "SQL-dM9ds", "unable to create unique constraint")
}
case repository.UniqueConstraintRemoved:
_, err := tx.ExecContext(ctx, uniqueDelete, uniqueConstraint.UniqueType, uniqueConstraint.UniqueField, uniqueConstraint.InstanceID)
case eventstore.UniqueConstraintRemove:
_, err := tx.ExecContext(ctx, uniqueDelete, uniqueConstraint.UniqueType, uniqueConstraint.UniqueField, authz.GetInstance(ctx).InstanceID())
if err != nil {
logging.WithFields(
"unique_type", uniqueConstraint.UniqueType,
"unique_field", uniqueConstraint.UniqueField).WithError(err).Info("delete unique constraint failed")
return caos_errs.ThrowInternal(err, "SQL-6n88i", "unable to remove unique constraint")
}
case repository.UniqueConstraintInstanceRemoved:
_, err := tx.ExecContext(ctx, uniqueDeleteInstance, uniqueConstraint.InstanceID)
case eventstore.UniqueConstraintInstanceRemove:
_, err := tx.ExecContext(ctx, uniqueDeleteInstance, authz.GetInstance(ctx).InstanceID())
if err != nil {
logging.WithFields(
"instance_id", uniqueConstraint.InstanceID).WithError(err).Info("delete instance unique constraints failed")
"instance_id", authz.GetInstance(ctx).InstanceID()).WithError(err).Info("delete instance unique constraints failed")
return caos_errs.ThrowInternal(err, "SQL-6n88i", "unable to remove unique constraints of instance")
}
}
@@ -225,9 +248,16 @@ func (db *CRDB) handleUniqueConstraints(ctx context.Context, tx *sql.Tx, uniqueC
}
// Filter returns all events matching the given search query
func (crdb *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) {
events = []*repository.Event{}
err = query(ctx, crdb, searchQuery, &events)
func (crdb *CRDB) Filter(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) (events []eventstore.Event, err error) {
events = make([]eventstore.Event, 0, searchQuery.GetLimit())
err = query(ctx, crdb, searchQuery, &events, false)
pgErr := new(pgconn.PgError)
// check events2 not exists
if err != nil && errors.As(err, &pgErr) {
if pgErr.Code == "42P01" {
err = query(ctx, crdb, searchQuery, &events, true)
}
}
if err != nil {
return nil, err
}
@@ -236,19 +266,16 @@ func (crdb *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuer
}
// LatestSequence returns the latest sequence found by the search query
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) {
var seq Sequence
err := query(ctx, db, searchQuery, &seq)
if err != nil {
return 0, err
}
return uint64(seq), nil
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) (float64, error) {
var position sql.NullFloat64
err := query(ctx, db, searchQuery, &position, false)
return position.Float64, err
}
// InstanceIDs returns the instance ids found by the search query
func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQuery) ([]string, error) {
func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) ([]string, error) {
var ids []string
err := query(ctx, db, searchQuery, &ids)
err := query(ctx, db, searchQuery, &ids, false)
if err != nil {
return nil, err
}
@@ -259,70 +286,107 @@ func (db *CRDB) db() *database.DB {
return db.DB
}
func (db *CRDB) orderByEventSequence(desc bool) string {
if db.AllowOrderByCreationDate {
func (db *CRDB) orderByEventSequence(desc, useV1 bool) string {
if useV1 {
if desc {
return " ORDER BY creation_date DESC, event_sequence DESC"
return ` ORDER BY event_sequence DESC`
}
return " ORDER BY creation_date, event_sequence"
return ` ORDER BY event_sequence`
}
if desc {
return " ORDER BY event_sequence DESC"
return ` ORDER BY "position" DESC, in_tx_order DESC`
}
return " ORDER BY event_sequence"
return ` ORDER BY "position", in_tx_order`
}
func (db *CRDB) eventQuery() string {
func (db *CRDB) eventQuery(useV1 bool) string {
if useV1 {
return "SELECT" +
" creation_date" +
", event_type" +
", event_sequence" +
", event_data" +
", editor_user" +
", resource_owner" +
", instance_id" +
", aggregate_type" +
", aggregate_id" +
", aggregate_version" +
" FROM eventstore.events"
}
return "SELECT" +
" creation_date" +
" created_at" +
", event_type" +
", event_sequence" +
", previous_aggregate_sequence" +
", previous_aggregate_type_sequence" +
", event_data" +
", editor_service" +
", editor_user" +
", resource_owner" +
`, "sequence"` +
`, "position"` +
", payload" +
", creator" +
`, "owner"` +
", instance_id" +
", aggregate_type" +
", aggregate_id" +
", aggregate_version" +
" FROM eventstore.events"
", revision" +
" FROM eventstore.events2"
}
func (db *CRDB) maxSequenceQuery() string {
return "SELECT MAX(event_sequence) FROM eventstore.events"
func (db *CRDB) maxSequenceQuery(useV1 bool) string {
if useV1 {
return `SELECT event_sequence FROM eventstore.events`
}
return `SELECT "position" FROM eventstore.events2`
}
func (db *CRDB) instanceIDsQuery() string {
return "SELECT DISTINCT instance_id FROM eventstore.events"
func (db *CRDB) instanceIDsQuery(useV1 bool) string {
table := "eventstore.events2"
if useV1 {
table = "eventstore.events"
}
return "SELECT DISTINCT instance_id FROM " + table
}
func (db *CRDB) columnName(col repository.Field) string {
func (db *CRDB) columnName(col repository.Field, useV1 bool) string {
switch col {
case repository.FieldAggregateID:
return "aggregate_id"
case repository.FieldAggregateType:
return "aggregate_type"
case repository.FieldSequence:
return "event_sequence"
if useV1 {
return "event_sequence"
}
return `"sequence"`
case repository.FieldResourceOwner:
return "resource_owner"
if useV1 {
return "resource_owner"
}
return `"owner"`
case repository.FieldInstanceID:
return "instance_id"
case repository.FieldEditorService:
return "editor_service"
if useV1 {
return "editor_service"
}
return ""
case repository.FieldEditorUser:
return "editor_user"
if useV1 {
return "editor_user"
}
return "creator"
case repository.FieldEventType:
return "event_type"
case repository.FieldEventData:
return "event_data"
if useV1 {
return "event_data"
}
return "payload"
case repository.FieldCreationDate:
return "creation_date"
if useV1 {
return "creation_date"
}
return "created_at"
case repository.FieldPosition:
return `"position"`
default:
return ""
}

File diff suppressed because it is too large Load Diff

View File

@@ -37,14 +37,14 @@ func TestMain(m *testing.M) {
ts.Stop()
}()
if err = initDB(testCRDBClient); err != nil {
if err = initDB(&database.DB{DB: testCRDBClient, Database: &cockroach.Config{Database: "zitadel"}}); err != nil {
logging.WithFields("error", err).Fatal("migrations failed")
}
os.Exit(m.Run())
}
func initDB(db *sql.DB) error {
func initDB(db *database.DB) error {
config := new(database.Config)
config.SetConnector(&cockroach.Config{User: cockroach.User{Username: "zitadel"}, Database: "zitadel"})
@@ -60,11 +60,13 @@ func initDB(db *sql.DB) error {
return err
}
return initialise.VerifyZitadel(db, *config)
}
err = initialise.VerifyZitadel(db, *config)
if err != nil {
return err
}
func fillUniqueData(unique_type, field, instanceID string) error {
_, err := testCRDBClient.Exec("INSERT INTO eventstore.unique_constraints (unique_type, unique_field, instance_id) VALUES ($1, $2, $3)", unique_type, field, instanceID)
// create old events
_, err = db.Exec(oldEventsTable)
return err
}
@@ -76,4 +78,26 @@ func (*testDB) DatabaseName() string { return "db" }
func (*testDB) Username() string { return "user" }
func (*testDB) Type() string { return "type" }
func (*testDB) Type() string { return "cockroach" }
const oldEventsTable = `CREATE TABLE IF NOT EXISTS eventstore.events (
id UUID DEFAULT gen_random_uuid()
, event_type TEXT NOT NULL
, aggregate_type TEXT NOT NULL
, aggregate_id TEXT NOT NULL
, aggregate_version TEXT NOT NULL
, event_sequence BIGINT NOT NULL
, previous_aggregate_sequence BIGINT
, previous_aggregate_type_sequence INT8
, creation_date TIMESTAMPTZ NOT NULL DEFAULT now()
, created_at TIMESTAMPTZ NOT NULL DEFAULT clock_timestamp()
, event_data JSONB
, editor_user TEXT NOT NULL
, editor_service TEXT
, resource_owner TEXT NOT NULL
, instance_id TEXT NOT NULL
, "position" DECIMAL NOT NULL
, in_tx_order INTEGER NOT NULL
, PRIMARY KEY (instance_id, aggregate_type, aggregate_id, event_sequence DESC)
);`

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"github.com/zitadel/logging"
@@ -14,19 +15,20 @@ import (
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
z_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
type querier interface {
columnName(repository.Field) string
columnName(field repository.Field, useV1 bool) string
operation(repository.Operation) string
conditionFormat(repository.Operation) string
placeholder(query string) string
eventQuery() string
maxSequenceQuery() string
instanceIDsQuery() string
eventQuery(useV1 bool) string
maxSequenceQuery(useV1 bool) string
instanceIDsQuery(useV1 bool) string
db() *database.DB
orderByEventSequence(desc bool) string
orderByEventSequence(desc, useV1 bool) string
dialect.Database
}
@@ -52,25 +54,38 @@ func (t *tx) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error,
return rows.Err()
}
func query(ctx context.Context, criteria querier, searchQuery *repository.SearchQuery, dest interface{}) error {
query, rowScanner := prepareColumns(criteria, searchQuery.Columns)
where, values := prepareCondition(criteria, searchQuery.Filters)
func query(ctx context.Context, criteria querier, searchQuery *eventstore.SearchQueryBuilder, dest interface{}, useV1 bool) error {
q, err := repository.QueryFromBuilder(searchQuery)
if err != nil {
return err
}
query, rowScanner := prepareColumns(criteria, q.Columns, useV1)
where, values := prepareConditions(criteria, q, useV1)
if where == "" || query == "" {
return z_errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
}
if searchQuery.Tx == nil {
if travel := prepareTimeTravel(ctx, criteria, searchQuery.AllowTimeTravel); travel != "" {
if q.Tx == nil {
if travel := prepareTimeTravel(ctx, criteria, q.AllowTimeTravel); travel != "" {
query += travel
}
}
query += where
if searchQuery.Columns == repository.ColumnsEvent {
query += criteria.orderByEventSequence(searchQuery.Desc)
// instead of using the max function of the database (which doesn't work for postgres)
// we select the most recent row
if q.Columns == eventstore.ColumnsMaxSequence {
q.Limit = 1
q.Desc = true
}
if searchQuery.Limit > 0 {
values = append(values, searchQuery.Limit)
switch q.Columns {
case eventstore.ColumnsEvent,
eventstore.ColumnsMaxSequence:
query += criteria.orderByEventSequence(q.Desc, useV1)
}
if q.Limit > 0 {
values = append(values, q.Limit)
query += " LIMIT ?"
}
@@ -80,11 +95,11 @@ func query(ctx context.Context, criteria querier, searchQuery *repository.Search
QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error
}
contextQuerier = criteria.db()
if searchQuery.Tx != nil {
contextQuerier = &tx{Tx: searchQuery.Tx}
if q.Tx != nil {
contextQuerier = &tx{Tx: q.Tx}
}
err := contextQuerier.QueryContext(ctx,
err = contextQuerier.QueryContext(ctx,
func(rows *sql.Rows) error {
for rows.Next() {
err := rowScanner(rows.Scan, dest)
@@ -102,14 +117,14 @@ func query(ctx context.Context, criteria querier, searchQuery *repository.Search
return nil
}
func prepareColumns(criteria querier, columns repository.Columns) (string, func(s scan, dest interface{}) error) {
func prepareColumns(criteria querier, columns eventstore.Columns, useV1 bool) (string, func(s scan, dest interface{}) error) {
switch columns {
case repository.ColumnsMaxSequence:
return criteria.maxSequenceQuery(), maxSequenceScanner
case repository.ColumnsInstanceIDs:
return criteria.instanceIDsQuery(), instanceIDsScanner
case repository.ColumnsEvent:
return criteria.eventQuery(), eventsScanner
case eventstore.ColumnsMaxSequence:
return criteria.maxSequenceQuery(useV1), maxSequenceScanner
case eventstore.ColumnsInstanceIDs:
return criteria.instanceIDsQuery(useV1), instanceIDsScanner
case eventstore.ColumnsEvent:
return criteria.eventQuery(useV1), eventsScanner(useV1)
default:
return "", nil
}
@@ -124,11 +139,11 @@ func prepareTimeTravel(ctx context.Context, criteria querier, allow bool) string
}
func maxSequenceScanner(row scan, dest interface{}) (err error) {
sequence, ok := dest.(*Sequence)
position, ok := dest.(*sql.NullFloat64)
if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
return z_errors.ThrowInvalidArgumentf(nil, "SQL-NBjA9", "type must be sql.NullInt64 got: %T", dest)
}
err = row(sequence)
err = row(position)
if err == nil || errors.Is(err, sql.ErrNoRows) {
return nil
}
@@ -151,84 +166,139 @@ func instanceIDsScanner(scanner scan, dest interface{}) (err error) {
return nil
}
func eventsScanner(scanner scan, dest interface{}) (err error) {
events, ok := dest.(*[]*repository.Event)
if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
}
var (
previousAggregateSequence Sequence
previousAggregateTypeSequence Sequence
)
data := make(Data, 0)
event := new(repository.Event)
err = scanner(
&event.CreationDate,
&event.Type,
&event.Sequence,
&previousAggregateSequence,
&previousAggregateTypeSequence,
&data,
&event.EditorService,
&event.EditorUser,
&event.ResourceOwner,
&event.InstanceID,
&event.AggregateType,
&event.AggregateID,
&event.Version,
)
if err != nil {
logging.New().WithError(err).Warn("unable to scan row")
return z_errors.ThrowInternal(err, "SQL-M0dsf", "unable to scan row")
}
event.PreviousAggregateSequence = uint64(previousAggregateSequence)
event.PreviousAggregateTypeSequence = uint64(previousAggregateTypeSequence)
event.Data = make([]byte, len(data))
copy(event.Data, data)
*events = append(*events, event)
return nil
}
func prepareCondition(criteria querier, filters [][]*repository.Filter) (clause string, values []interface{}) {
values = make([]interface{}, 0, len(filters))
if len(filters) == 0 {
return clause, values
}
clauses := make([]string, len(filters))
for idx, filter := range filters {
subClauses := make([]string, 0, len(filter))
for _, f := range filter {
value := f.Value
switch value.(type) {
case map[string]interface{}:
var err error
value, err = json.Marshal(value)
if err != nil {
logging.WithError(err).Warn("unable to marshal search value")
continue
}
}
subClauses = append(subClauses, getCondition(criteria, f))
if subClauses[len(subClauses)-1] == "" {
return "", nil
}
values = append(values, value)
func eventsScanner(useV1 bool) func(scanner scan, dest interface{}) (err error) {
return func(scanner scan, dest interface{}) (err error) {
events, ok := dest.(*[]eventstore.Event)
if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
}
clauses[idx] = "( " + strings.Join(subClauses, " AND ") + " )"
event := new(repository.Event)
data := sql.RawBytes{}
position := new(sql.NullFloat64)
if useV1 {
err = scanner(
&event.CreationDate,
&event.Typ,
&event.Seq,
&data,
&event.EditorUser,
&event.ResourceOwner,
&event.InstanceID,
&event.AggregateType,
&event.AggregateID,
&event.Version,
)
} else {
var revision uint8
err = scanner(
&event.CreationDate,
&event.Typ,
&event.Seq,
position,
&data,
&event.EditorUser,
&event.ResourceOwner,
&event.InstanceID,
&event.AggregateType,
&event.AggregateID,
&revision,
)
event.Version = eventstore.Version("v" + strconv.Itoa(int(revision)))
}
if err != nil {
logging.New().WithError(err).Warn("unable to scan row")
return z_errors.ThrowInternal(err, "SQL-M0dsf", "unable to scan row")
}
event.Data = make([]byte, len(data))
copy(event.Data, data)
event.Pos = position.Float64
*events = append(*events, event)
return nil
}
return " WHERE " + strings.Join(clauses, " OR "), values
}
func getCondition(cond querier, filter *repository.Filter) (condition string) {
field := cond.columnName(filter.Field)
func prepareConditions(criteria querier, query *repository.SearchQuery, useV1 bool) (string, []any) {
clauses, args := prepareQuery(criteria, useV1, query.InstanceID, query.ExcludedInstances)
if clauses != "" && len(query.SubQueries) > 0 {
clauses += " AND "
}
subClauses := make([]string, len(query.SubQueries))
for i, filters := range query.SubQueries {
var subArgs []any
subClauses[i], subArgs = prepareQuery(criteria, useV1, filters...)
// an error is thrown in [query]
if subClauses[i] == "" {
return "", nil
}
if len(query.SubQueries) > 1 && len(subArgs) > 1 {
subClauses[i] = "(" + subClauses[i] + ")"
}
args = append(args, subArgs...)
}
if len(subClauses) == 1 {
clauses += subClauses[0]
} else if len(subClauses) > 1 {
clauses += "(" + strings.Join(subClauses, " OR ") + ")"
}
additionalClauses, additionalArgs := prepareQuery(criteria, useV1, query.Position, query.Owner, query.Sequence, query.CreatedAt, query.Creator)
if additionalClauses != "" {
if clauses != "" {
clauses += " AND "
}
clauses += additionalClauses
args = append(args, additionalArgs...)
}
if query.AwaitOpenTransactions {
clauses += awaitOpenTransactions(useV1)
}
if clauses == "" {
return "", nil
}
return " WHERE " + clauses, args
}
func prepareQuery(criteria querier, useV1 bool, filters ...*repository.Filter) (_ string, args []any) {
clauses := make([]string, 0, len(filters))
args = make([]any, 0, len(filters))
for _, filter := range filters {
if filter == nil {
continue
}
arg := filter.Value
// marshal if payload filter
if filter.Field == repository.FieldEventData {
var err error
arg, err = json.Marshal(arg)
if err != nil {
logging.WithError(err).Warn("unable to marshal search value")
continue
}
}
clauses = append(clauses, getCondition(criteria, filter, useV1))
// if mapping failed an error is thrown in [query]
if clauses[len(clauses)-1] == "" {
return "", nil
}
args = append(args, arg)
}
return strings.Join(clauses, " AND "), args
}
func getCondition(cond querier, filter *repository.Filter, useV1 bool) (condition string) {
field := cond.columnName(filter.Field, useV1)
operation := cond.operation(filter.Operation)
if field == "" || operation == "" {
return ""

View File

@@ -5,13 +5,16 @@ import (
"database/sql"
"database/sql/driver"
"reflect"
"strconv"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/cockroach"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
@@ -32,38 +35,38 @@ func Test_getCondition(t *testing.T) {
{
name: "greater",
args: args{filter: repository.NewFilter(repository.FieldSequence, 0, repository.OperationGreater)},
want: "event_sequence > ?",
want: `"sequence" > ?`,
},
{
name: "less",
args: args{filter: repository.NewFilter(repository.FieldSequence, 5000, repository.OperationLess)},
want: "event_sequence < ?",
want: `"sequence" < ?`,
},
{
name: "in list",
args: args{filter: repository.NewFilter(repository.FieldAggregateType, []repository.AggregateType{"movies", "actors"}, repository.OperationIn)},
args: args{filter: repository.NewFilter(repository.FieldAggregateType, []eventstore.AggregateType{"movies", "actors"}, repository.OperationIn)},
want: "aggregate_type = ANY(?)",
},
{
name: "invalid operation",
args: args{filter: repository.NewFilter(repository.FieldAggregateType, []repository.AggregateType{"movies", "actors"}, repository.Operation(-1))},
args: args{filter: repository.NewFilter(repository.FieldAggregateType, []eventstore.AggregateType{"movies", "actors"}, repository.Operation(-1))},
want: "",
},
{
name: "invalid field",
args: args{filter: repository.NewFilter(repository.Field(-1), []repository.AggregateType{"movies", "actors"}, repository.OperationEquals)},
args: args{filter: repository.NewFilter(repository.Field(-1), []eventstore.AggregateType{"movies", "actors"}, repository.OperationEquals)},
want: "",
},
{
name: "invalid field and operation",
args: args{filter: repository.NewFilter(repository.Field(-1), []repository.AggregateType{"movies", "actors"}, repository.Operation(-1))},
args: args{filter: repository.NewFilter(repository.Field(-1), []eventstore.AggregateType{"movies", "actors"}, repository.Operation(-1))},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if got := getCondition(db, tt.args.filter); got != tt.want {
if got := getCondition(db, tt.args.filter, false); got != tt.want {
t.Errorf("getCondition() = %v, want %v", got, tt.want)
}
})
@@ -75,9 +78,10 @@ func Test_prepareColumns(t *testing.T) {
dbRow []interface{}
}
type args struct {
columns repository.Columns
columns eventstore.Columns
dest interface{}
dbErr error
useV1 bool
}
type res struct {
query string
@@ -92,7 +96,7 @@ func Test_prepareColumns(t *testing.T) {
}{
{
name: "invalid columns",
args: args{columns: repository.Columns(-1)},
args: args{columns: eventstore.Columns(-1)},
res: res{
query: "",
dbErr: func(err error) bool { return err == nil },
@@ -101,64 +105,114 @@ func Test_prepareColumns(t *testing.T) {
{
name: "max column",
args: args{
columns: repository.ColumnsMaxSequence,
dest: new(Sequence),
columns: eventstore.ColumnsMaxSequence,
dest: new(sql.NullFloat64),
useV1: true,
},
res: res{
query: "SELECT MAX(event_sequence) FROM eventstore.events",
expected: Sequence(5),
query: `SELECT event_sequence FROM eventstore.events`,
expected: sql.NullFloat64{Float64: 43, Valid: true},
},
fields: fields{
dbRow: []interface{}{Sequence(5)},
dbRow: []interface{}{sql.NullFloat64{Float64: 43, Valid: true}},
},
},
{
name: "max column v2",
args: args{
columns: eventstore.ColumnsMaxSequence,
dest: new(sql.NullFloat64),
},
res: res{
query: `SELECT "position" FROM eventstore.events2`,
expected: sql.NullFloat64{Float64: 43, Valid: true},
},
fields: fields{
dbRow: []interface{}{sql.NullFloat64{Float64: 43, Valid: true}},
},
},
{
name: "max sequence wrong dest type",
args: args{
columns: repository.ColumnsMaxSequence,
columns: eventstore.ColumnsMaxSequence,
dest: new(uint64),
},
res: res{
query: "SELECT MAX(event_sequence) FROM eventstore.events",
query: `SELECT "position" FROM eventstore.events2`,
dbErr: errors.IsErrorInvalidArgument,
},
},
{
name: "events",
args: args{
columns: repository.ColumnsEvent,
dest: &[]*repository.Event{},
columns: eventstore.ColumnsEvent,
dest: &[]eventstore.Event{},
useV1: true,
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
expected: []*repository.Event{
{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
query: `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events`,
expected: []eventstore.Event{
&repository.Event{AggregateID: "hodor", AggregateType: "user", Seq: 5, Data: make(sql.RawBytes, 0)},
},
},
fields: fields{
dbRow: []interface{}{time.Time{}, repository.EventType(""), uint64(5), Sequence(0), Sequence(0), Data(nil), "", "", sql.NullString{String: ""}, "", repository.AggregateType("user"), "hodor", repository.Version("")},
dbRow: []interface{}{time.Time{}, eventstore.EventType(""), uint64(5), sql.RawBytes(nil), "", sql.NullString{}, "", eventstore.AggregateType("user"), "hodor", eventstore.Version("")},
},
},
{
name: "events v2",
args: args{
columns: eventstore.ColumnsEvent,
dest: &[]eventstore.Event{},
},
res: res{
query: `SELECT created_at, event_type, "sequence", "position", payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2`,
expected: []eventstore.Event{
&repository.Event{AggregateID: "hodor", AggregateType: "user", Seq: 5, Pos: 42, Data: make(sql.RawBytes, 0), Version: "v1"},
},
},
fields: fields{
dbRow: []interface{}{time.Time{}, eventstore.EventType(""), uint64(5), sql.NullFloat64{Float64: 42, Valid: true}, sql.RawBytes(nil), "", sql.NullString{}, "", eventstore.AggregateType("user"), "hodor", uint8(1)},
},
},
{
name: "event null position",
args: args{
columns: eventstore.ColumnsEvent,
dest: &[]eventstore.Event{},
},
res: res{
query: `SELECT created_at, event_type, "sequence", "position", payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2`,
expected: []eventstore.Event{
&repository.Event{AggregateID: "hodor", AggregateType: "user", Seq: 5, Pos: 0, Data: make(sql.RawBytes, 0), Version: "v1"},
},
},
fields: fields{
dbRow: []interface{}{time.Time{}, eventstore.EventType(""), uint64(5), sql.NullFloat64{Float64: 0, Valid: false}, sql.RawBytes(nil), "", sql.NullString{}, "", eventstore.AggregateType("user"), "hodor", uint8(1)},
},
},
{
name: "events wrong dest type",
args: args{
columns: repository.ColumnsEvent,
columns: eventstore.ColumnsEvent,
dest: []*repository.Event{},
useV1: true,
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
query: `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events`,
dbErr: errors.IsErrorInvalidArgument,
},
},
{
name: "event query error",
args: args{
columns: repository.ColumnsEvent,
dest: &[]*repository.Event{},
columns: eventstore.ColumnsEvent,
dest: &[]eventstore.Event{},
dbErr: sql.ErrConnDone,
useV1: true,
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
query: `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events`,
dbErr: errors.IsInternal,
},
},
@@ -166,7 +220,7 @@ func Test_prepareColumns(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
crdb := &CRDB{}
query, rowScanner := prepareColumns(crdb, tt.args.columns)
query, rowScanner := prepareColumns(crdb, tt.args.columns, tt.args.useV1)
if query != tt.res.query {
t.Errorf("prepareColumns() got = %s, want %s", query, tt.res.query)
}
@@ -184,8 +238,13 @@ func Test_prepareColumns(t *testing.T) {
if tt.res.dbErr != nil && tt.res.dbErr(err) {
return
}
if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) {
t.Errorf("unexpected result from rowScanner \nwant: %+v \ngot: %+v", tt.fields.dbRow, reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface())
if equalizer, ok := tt.res.expected.(interface{ Equal(time.Time) bool }); ok {
equalizer.Equal(tt.args.dest.(*sql.NullTime).Time)
return
}
got := reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface()
if !reflect.DeepEqual(got, tt.res.expected) {
t.Errorf("unexpected result from rowScanner \nwant: %+v \ngot: %+v", tt.res.expected, got)
}
})
}
@@ -200,6 +259,13 @@ func prepareTestScan(err error, res []interface{}) scan {
return errors.ThrowInvalidArgumentf(nil, "SQL-NML1q", "expected len %d got %d", len(res), len(dests))
}
for i, r := range res {
_, ok := dests[i].(*eventstore.Version)
if ok {
val, ok := r.(uint8)
if ok {
r = eventstore.Version("" + strconv.Itoa(int(val)))
}
}
reflect.ValueOf(dests[i]).Elem().Set(reflect.ValueOf(r))
}
@@ -209,7 +275,8 @@ func prepareTestScan(err error, res []interface{}) scan {
func Test_prepareCondition(t *testing.T) {
type args struct {
filters [][]*repository.Filter
query *repository.SearchQuery
useV1 bool
}
type res struct {
clause string
@@ -223,7 +290,18 @@ func Test_prepareCondition(t *testing.T) {
{
name: "nil filters",
args: args{
filters: nil,
query: &repository.SearchQuery{},
useV1: true,
},
res: res{
clause: "",
values: nil,
},
},
{
name: "nil filters v2",
args: args{
query: &repository.SearchQuery{},
},
res: res{
clause: "",
@@ -233,7 +311,22 @@ func Test_prepareCondition(t *testing.T) {
{
name: "empty filters",
args: args{
filters: [][]*repository.Filter{},
query: &repository.SearchQuery{
SubQueries: [][]*repository.Filter{},
},
useV1: true,
},
res: res{
clause: "",
values: nil,
},
},
{
name: "empty filters v2",
args: args{
query: &repository.SearchQuery{
SubQueries: [][]*repository.Filter{},
},
},
res: res{
clause: "",
@@ -243,9 +336,28 @@ func Test_prepareCondition(t *testing.T) {
{
name: "invalid condition",
args: args{
filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateID, "wrong", repository.Operation(-1)),
query: &repository.SearchQuery{
SubQueries: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateID, "wrong", repository.Operation(-1)),
},
},
},
useV1: true,
},
res: res{
clause: "",
values: nil,
},
},
{
name: "invalid condition v2",
args: args{
query: &repository.SearchQuery{
SubQueries: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateID, "wrong", repository.Operation(-1)),
},
},
},
},
@@ -257,38 +369,82 @@ func Test_prepareCondition(t *testing.T) {
{
name: "array as condition value",
args: args{
filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, []repository.AggregateType{"user", "org"}, repository.OperationIn),
query: &repository.SearchQuery{
AwaitOpenTransactions: true,
SubQueries: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, []eventstore.AggregateType{"user", "org"}, repository.OperationIn),
},
},
},
useV1: true,
},
res: res{
clause: " WHERE aggregate_type = ANY(?) AND creation_date::TIMESTAMP < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = 'zitadel_es_pusher')",
values: []interface{}{[]eventstore.AggregateType{"user", "org"}},
},
},
{
name: "array as condition value v2",
args: args{
query: &repository.SearchQuery{
AwaitOpenTransactions: true,
SubQueries: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, []eventstore.AggregateType{"user", "org"}, repository.OperationIn),
},
},
},
},
res: res{
clause: " WHERE ( aggregate_type = ANY(?) )",
values: []interface{}{[]repository.AggregateType{"user", "org"}},
clause: ` WHERE aggregate_type = ANY(?) AND hlc_to_timestamp("position") < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = 'zitadel_es_pusher')`,
values: []interface{}{[]eventstore.AggregateType{"user", "org"}},
},
},
{
name: "multiple filters",
args: args{
filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, []repository.AggregateType{"user", "org"}, repository.OperationIn),
repository.NewFilter(repository.FieldAggregateID, "1234", repository.OperationEquals),
repository.NewFilter(repository.FieldEventType, []repository.EventType{"user.created", "org.created"}, repository.OperationIn),
query: &repository.SearchQuery{
AwaitOpenTransactions: true,
SubQueries: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, []eventstore.AggregateType{"user", "org"}, repository.OperationIn),
repository.NewFilter(repository.FieldAggregateID, "1234", repository.OperationEquals),
repository.NewFilter(repository.FieldEventType, []eventstore.EventType{"user.created", "org.created"}, repository.OperationIn),
},
},
},
useV1: true,
},
res: res{
clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) AND creation_date::TIMESTAMP < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = 'zitadel_es_pusher')",
values: []interface{}{[]eventstore.AggregateType{"user", "org"}, "1234", []eventstore.EventType{"user.created", "org.created"}},
},
},
{
name: "multiple filters v2",
args: args{
query: &repository.SearchQuery{
AwaitOpenTransactions: true,
SubQueries: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, []eventstore.AggregateType{"user", "org"}, repository.OperationIn),
repository.NewFilter(repository.FieldAggregateID, "1234", repository.OperationEquals),
repository.NewFilter(repository.FieldEventType, []eventstore.EventType{"user.created", "org.created"}, repository.OperationIn),
},
},
},
},
res: res{
clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )",
values: []interface{}{[]repository.AggregateType{"user", "org"}, "1234", []repository.EventType{"user.created", "org.created"}},
clause: ` WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) AND hlc_to_timestamp("position") < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = 'zitadel_es_pusher')`,
values: []interface{}{[]eventstore.AggregateType{"user", "org"}, "1234", []eventstore.EventType{"user.created", "org.created"}},
},
},
}
crdb := NewCRDB(&database.DB{Database: new(cockroach.Config)})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
crdb := &CRDB{}
gotClause, gotValues := prepareCondition(crdb, tt.args.filters)
gotClause, gotValues := prepareConditions(crdb, tt.args.query, tt.args.useV1)
if gotClause != tt.res.clause {
t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause)
}
@@ -307,10 +463,10 @@ func Test_prepareCondition(t *testing.T) {
func Test_query_events_with_crdb(t *testing.T) {
type args struct {
searchQuery *repository.SearchQuery
searchQuery *eventstore.SearchQueryBuilder
}
type fields struct {
existingEvents []*repository.Event
existingEvents []eventstore.Command
client *sql.DB
}
type res struct {
@@ -326,18 +482,14 @@ func Test_query_events_with_crdb(t *testing.T) {
{
name: "aggregate type filter no events",
args: args{
searchQuery: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, "not found", repository.OperationEquals),
},
},
},
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes("not found").
Builder(),
},
fields: fields{
client: testCRDBClient,
existingEvents: []*repository.Event{
existingEvents: []eventstore.Command{
generateEvent(t, "300"),
generateEvent(t, "300"),
generateEvent(t, "300"),
@@ -351,18 +503,14 @@ func Test_query_events_with_crdb(t *testing.T) {
{
name: "aggregate type filter events found",
args: args{
searchQuery: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, t.Name(), repository.OperationEquals),
},
},
},
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(eventstore.AggregateType(t.Name())).
Builder(),
},
fields: fields{
client: testCRDBClient,
existingEvents: []*repository.Event{
existingEvents: []eventstore.Command{
generateEvent(t, "301"),
generateEvent(t, "302"),
generateEvent(t, "302"),
@@ -377,19 +525,15 @@ func Test_query_events_with_crdb(t *testing.T) {
{
name: "aggregate type and id filter events found",
args: args{
searchQuery: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, t.Name(), repository.OperationEquals),
repository.NewFilter(repository.FieldAggregateID, "303", repository.OperationEquals),
},
},
},
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(eventstore.AggregateType(t.Name())).
AggregateIDs("303").
Builder(),
},
fields: fields{
client: testCRDBClient,
existingEvents: []*repository.Event{
existingEvents: []eventstore.Command{
generateEvent(t, "303"),
generateEvent(t, "303"),
generateEvent(t, "303"),
@@ -405,18 +549,12 @@ func Test_query_events_with_crdb(t *testing.T) {
{
name: "resource owner filter events found",
args: args{
searchQuery: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldResourceOwner, "caos", repository.OperationEquals),
},
},
},
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner("caos"),
},
fields: fields{
client: testCRDBClient,
existingEvents: []*repository.Event{
existingEvents: []eventstore.Command{
generateEvent(t, "306", func(e *repository.Event) { e.ResourceOwner = sql.NullString{String: "caos", Valid: true} }),
generateEvent(t, "307", func(e *repository.Event) { e.ResourceOwner = sql.NullString{String: "caos", Valid: true} }),
generateEvent(t, "308", func(e *repository.Event) { e.ResourceOwner = sql.NullString{String: "caos", Valid: true} }),
@@ -429,90 +567,26 @@ func Test_query_events_with_crdb(t *testing.T) {
},
wantErr: false,
},
{
name: "editor service filter events found",
args: args{
searchQuery: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldEditorService, "MANAGEMENT-API", repository.OperationEquals),
repository.NewFilter(repository.FieldEditorService, "ADMIN-API", repository.OperationEquals),
},
},
},
},
fields: fields{
client: testCRDBClient,
existingEvents: []*repository.Event{
generateEvent(t, "307", func(e *repository.Event) { e.EditorService = "MANAGEMENT-API" }),
generateEvent(t, "307", func(e *repository.Event) { e.EditorService = "MANAGEMENT-API" }),
generateEvent(t, "308", func(e *repository.Event) { e.EditorService = "ADMIN-API" }),
generateEvent(t, "309", func(e *repository.Event) { e.EditorService = "AUTHAPI" }),
generateEvent(t, "309", func(e *repository.Event) { e.EditorService = "AUTHAPI" }),
},
},
res: res{
eventCount: 3,
},
wantErr: false,
},
{
name: "editor user filter events found",
args: args{
searchQuery: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldEditorUser, "adlerhurst", repository.OperationEquals),
repository.NewFilter(repository.FieldEditorUser, "nobody", repository.OperationEquals),
repository.NewFilter(repository.FieldEditorUser, "", repository.OperationEquals),
},
},
},
},
fields: fields{
client: testCRDBClient,
existingEvents: []*repository.Event{
generateEvent(t, "310", func(e *repository.Event) { e.EditorUser = "adlerhurst" }),
generateEvent(t, "310", func(e *repository.Event) { e.EditorUser = "adlerhurst" }),
generateEvent(t, "310", func(e *repository.Event) { e.EditorUser = "nobody" }),
generateEvent(t, "311", func(e *repository.Event) { e.EditorUser = "" }),
generateEvent(t, "311", func(e *repository.Event) { e.EditorUser = "" }),
generateEvent(t, "312", func(e *repository.Event) { e.EditorUser = "fforootd" }),
generateEvent(t, "312", func(e *repository.Event) { e.EditorUser = "fforootd" }),
},
},
res: res{
eventCount: 5,
},
wantErr: false,
},
{
name: "event type filter events found",
args: args{
searchQuery: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldEventType, repository.EventType("user.created"), repository.OperationEquals),
repository.NewFilter(repository.FieldEventType, repository.EventType("user.updated"), repository.OperationEquals),
},
},
},
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
EventTypes("user.created", "user.updated").
Builder(),
},
fields: fields{
client: testCRDBClient,
existingEvents: []*repository.Event{
generateEvent(t, "311", func(e *repository.Event) { e.Type = "user.created" }),
generateEvent(t, "311", func(e *repository.Event) { e.Type = "user.updated" }),
generateEvent(t, "311", func(e *repository.Event) { e.Type = "user.deactivated" }),
generateEvent(t, "311", func(e *repository.Event) { e.Type = "user.locked" }),
generateEvent(t, "312", func(e *repository.Event) { e.Type = "user.created" }),
generateEvent(t, "312", func(e *repository.Event) { e.Type = "user.updated" }),
generateEvent(t, "312", func(e *repository.Event) { e.Type = "user.deactivated" }),
generateEvent(t, "312", func(e *repository.Event) { e.Type = "user.reactivated" }),
generateEvent(t, "313", func(e *repository.Event) { e.Type = "user.locked" }),
existingEvents: []eventstore.Command{
generateEvent(t, "311", func(e *repository.Event) { e.Typ = "user.created" }),
generateEvent(t, "311", func(e *repository.Event) { e.Typ = "user.updated" }),
generateEvent(t, "311", func(e *repository.Event) { e.Typ = "user.deactivated" }),
generateEvent(t, "311", func(e *repository.Event) { e.Typ = "user.locked" }),
generateEvent(t, "312", func(e *repository.Event) { e.Typ = "user.created" }),
generateEvent(t, "312", func(e *repository.Event) { e.Typ = "user.updated" }),
generateEvent(t, "312", func(e *repository.Event) { e.Typ = "user.deactivated" }),
generateEvent(t, "312", func(e *repository.Event) { e.Typ = "user.reactivated" }),
generateEvent(t, "313", func(e *repository.Event) { e.Typ = "user.locked" }),
},
},
res: res{
@@ -523,11 +597,11 @@ func Test_query_events_with_crdb(t *testing.T) {
{
name: "fail because no filter",
args: args{
searchQuery: &repository.SearchQuery{},
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.Columns(-1)),
},
fields: fields{
client: testCRDBClient,
existingEvents: []*repository.Event{},
existingEvents: []eventstore.Command{},
},
res: res{
eventCount: 0,
@@ -542,17 +616,16 @@ func Test_query_events_with_crdb(t *testing.T) {
DB: tt.fields.client,
Database: new(testDB),
},
AllowOrderByCreationDate: true,
}
// setup initial data for query
if err := db.Push(context.Background(), tt.fields.existingEvents); err != nil {
if _, err := db.Push(context.Background(), tt.fields.existingEvents...); err != nil {
t.Errorf("error in setup = %v", err)
return
}
events := []*repository.Event{}
if err := query(context.Background(), db, tt.args.searchQuery, &events); (err != nil) != tt.wantErr {
events := []eventstore.Event{}
if err := query(context.Background(), db, tt.args.searchQuery, &events, true); (err != nil) != tt.wantErr {
t.Errorf("CRDB.query() error = %v, wantErr %v", err, tt.wantErr)
}
})
@@ -561,7 +634,7 @@ func Test_query_events_with_crdb(t *testing.T) {
func Test_query_events_mocked(t *testing.T) {
type args struct {
query *repository.SearchQuery
query *eventstore.SearchQueryBuilder
dest interface{}
}
type res struct {
@@ -580,24 +653,17 @@ func Test_query_events_mocked(t *testing.T) {
name: "with order by desc",
args: args{
dest: &[]*repository.Event{},
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Desc: true,
Filters: [][]*repository.Filter{
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("user"),
Operation: repository.OperationEquals,
},
},
},
},
query: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderDesc().
AwaitOpenTransactions().
AddQuery().
AggregateTypes("user").
Builder(),
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")},
`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = 'zitadel_es_pusher'\) ORDER BY event_sequence DESC`,
[]driver.Value{eventstore.AggregateType("user")},
),
},
res: res{
@@ -608,25 +674,18 @@ func Test_query_events_mocked(t *testing.T) {
name: "with limit",
args: args{
dest: &[]*repository.Event{},
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Desc: false,
Limit: 5,
Filters: [][]*repository.Filter{
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("user"),
Operation: repository.OperationEquals,
},
},
},
},
query: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderAsc().
AwaitOpenTransactions().
Limit(5).
AddQuery().
AggregateTypes("user").
Builder(),
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date, event_sequence LIMIT \$2`,
[]driver.Value{repository.AggregateType("user"), uint64(5)},
`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = 'zitadel_es_pusher'\) ORDER BY event_sequence LIMIT \$2`,
[]driver.Value{eventstore.AggregateType("user"), uint64(5)},
),
},
res: res{
@@ -637,25 +696,18 @@ func Test_query_events_mocked(t *testing.T) {
name: "with limit and order by desc",
args: args{
dest: &[]*repository.Event{},
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Desc: true,
Limit: 5,
Filters: [][]*repository.Filter{
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("user"),
Operation: repository.OperationEquals,
},
},
},
},
query: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderDesc().
AwaitOpenTransactions().
Limit(5).
AddQuery().
AggregateTypes("user").
Builder(),
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC LIMIT \$2`,
[]driver.Value{repository.AggregateType("user"), uint64(5)},
`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = 'zitadel_es_pusher'\) ORDER BY event_sequence DESC LIMIT \$2`,
[]driver.Value{eventstore.AggregateType("user"), uint64(5)},
),
},
res: res{
@@ -666,26 +718,19 @@ func Test_query_events_mocked(t *testing.T) {
name: "with limit and order by desc as of system time",
args: args{
dest: &[]*repository.Event{},
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Desc: true,
Limit: 5,
AllowTimeTravel: true,
Filters: [][]*repository.Filter{
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("user"),
Operation: repository.OperationEquals,
},
},
},
},
query: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderDesc().
AwaitOpenTransactions().
Limit(5).
AllowTimeTravel().
AddQuery().
AggregateTypes("user").
Builder(),
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC LIMIT \$2`,
[]driver.Value{repository.AggregateType("user"), uint64(5)},
`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = 'zitadel_es_pusher'\) ORDER BY event_sequence DESC LIMIT \$2`,
[]driver.Value{eventstore.AggregateType("user"), uint64(5)},
),
},
res: res{
@@ -696,25 +741,18 @@ func Test_query_events_mocked(t *testing.T) {
name: "error sql conn closed",
args: args{
dest: &[]*repository.Event{},
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Desc: true,
Limit: 0,
Filters: [][]*repository.Filter{
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("user"),
Operation: repository.OperationEquals,
},
},
},
},
query: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderDesc().
AwaitOpenTransactions().
Limit(0).
AddQuery().
AggregateTypes("user").
Builder(),
},
fields: fields{
mock: newMockClient(t).expectQueryErr(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")},
`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = 'zitadel_es_pusher'\) ORDER BY event_sequence DESC`,
[]driver.Value{eventstore.AggregateType("user")},
sql.ErrConnDone),
},
res: res{
@@ -725,26 +763,19 @@ func Test_query_events_mocked(t *testing.T) {
name: "error unexpected dest",
args: args{
dest: nil,
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Desc: true,
Limit: 0,
Filters: [][]*repository.Filter{
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("user"),
Operation: repository.OperationEquals,
},
},
},
},
query: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderDesc().
AwaitOpenTransactions().
Limit(0).
AddQuery().
AggregateTypes("user").
Builder(),
},
fields: fields{
mock: newMockClient(t).expectQueryScanErr(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")},
&repository.Event{Sequence: 100}),
`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = 'zitadel_es_pusher'\) ORDER BY event_sequence DESC`,
[]driver.Value{eventstore.AggregateType("user")},
&repository.Event{Seq: 100}),
},
res: res{
wantErr: true,
@@ -753,25 +784,7 @@ func Test_query_events_mocked(t *testing.T) {
{
name: "error no columns",
args: args{
query: &repository.SearchQuery{
Columns: repository.Columns(-1),
},
},
res: res{
wantErr: true,
},
},
{
name: "invalid condition",
args: args{
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Filters: [][]*repository.Filter{
{
{},
},
},
},
query: eventstore.NewSearchQueryBuilder(eventstore.Columns(-1)),
},
res: res{
wantErr: true,
@@ -781,37 +794,21 @@ func Test_query_events_mocked(t *testing.T) {
name: "with subqueries",
args: args{
dest: &[]*repository.Event{},
query: &repository.SearchQuery{
Columns: repository.ColumnsEvent,
Desc: true,
Limit: 5,
Filters: [][]*repository.Filter{
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("user"),
Operation: repository.OperationEquals,
},
},
{
{
Field: repository.FieldAggregateType,
Value: repository.AggregateType("org"),
Operation: repository.OperationEquals,
},
{
Field: repository.FieldAggregateID,
Value: "asdf42",
Operation: repository.OperationEquals,
},
},
},
},
query: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
OrderDesc().
AwaitOpenTransactions().
Limit(5).
AddQuery().
AggregateTypes("user").
Or().
AggregateTypes("org").
AggregateIDs("asdf42").
Builder(),
},
fields: fields{
mock: newMockClient(t).expectQuery(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) OR \( aggregate_type = \$2 AND aggregate_id = \$3 \) ORDER BY creation_date DESC, event_sequence DESC LIMIT \$4`,
[]driver.Value{repository.AggregateType("user"), repository.AggregateType("org"), "asdf42", uint64(5)},
`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \(aggregate_type = \$1 OR \(aggregate_type = \$2 AND aggregate_id = \$3\)\) AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = 'zitadel_es_pusher'\) ORDER BY event_sequence DESC LIMIT \$4`,
[]driver.Value{eventstore.AggregateType("user"), eventstore.AggregateType("org"), "asdf42", uint64(5)},
),
},
res: res{
@@ -819,19 +816,14 @@ func Test_query_events_mocked(t *testing.T) {
},
},
}
crdb := NewCRDB(&database.DB{Database: new(testDB)})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
crdb := &CRDB{
DB: &database.DB{
Database: new(testDB),
},
AllowOrderByCreationDate: true,
}
if tt.fields.mock != nil {
crdb.DB.DB = tt.fields.mock.client
}
err := query(context.Background(), crdb, tt.args.query, tt.args.dest)
err := query(context.Background(), crdb, tt.args.query, tt.args.dest, true)
if (err != nil) != tt.res.wantErr {
t.Errorf("query() error = %v, wantErr %v", err, tt.res.wantErr)
}
@@ -856,9 +848,9 @@ func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.V
m.mock.ExpectBegin()
query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...)
m.mock.ExpectCommit()
rows := sqlmock.NewRows([]string{"event_sequence"})
rows := sqlmock.NewRows([]string{"sequence"})
for _, event := range events {
rows = rows.AddRow(event.Sequence)
rows = rows.AddRow(event.Seq)
}
query.WillReturnRows(rows).RowsWillBeClosed()
return m
@@ -868,9 +860,9 @@ func (m *dbMock) expectQueryScanErr(t *testing.T, expectedQuery string, args []d
m.mock.ExpectBegin()
query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...)
m.mock.ExpectRollback()
rows := sqlmock.NewRows([]string{"event_sequence"})
rows := sqlmock.NewRows([]string{"sequence"})
for _, event := range events {
rows = rows.AddRow(event.Sequence)
rows = rows.AddRow(event.Seq)
}
query.WillReturnRows(rows).RowsWillBeClosed()
return m

View File

@@ -1,49 +0,0 @@
package sql
import (
"database/sql/driver"
)
// Data represents a byte array that may be null.
// Data implements the sql.Scanner interface
type Data []byte
// Scan implements the Scanner interface.
func (data *Data) Scan(value interface{}) error {
if value == nil {
*data = nil
return nil
}
*data = Data(value.([]byte))
return nil
}
// Value implements the driver Valuer interface.
func (data Data) Value() (driver.Value, error) {
if len(data) == 0 {
return nil, nil
}
return []byte(data), nil
}
// Sequence represents a number that may be null.
// Sequence implements the sql.Scanner interface
type Sequence uint64
// Scan implements the Scanner interface.
func (seq *Sequence) Scan(value interface{}) error {
if value == nil {
*seq = 0
return nil
}
*seq = Sequence(value.(int64))
return nil
}
// Value implements the driver Valuer interface.
func (seq Sequence) Value() (driver.Value, error) {
if seq == 0 {
return nil, nil
}
return int64(seq), nil
}