From 5811a7b6a58150ca4304ac1fe825ff6679eb3bac Mon Sep 17 00:00:00 2001 From: Silvan Date: Fri, 26 Apr 2024 17:05:21 +0200 Subject: [PATCH] refactor(v2): init eventstore package (#7806) * refactor(v2): init database package * refactor(v2): init eventstore package * add mock package * test query constructors * option based push analog to query --- internal/v2/eventstore/aggregate.go | 24 + internal/v2/eventstore/current_sequence.go | 29 + internal/v2/eventstore/event.go | 36 + internal/v2/eventstore/event_store.go | 41 + internal/v2/eventstore/postgres/event.go | 64 + internal/v2/eventstore/postgres/intent.go | 42 + .../v2/eventstore/postgres/intent_test.go | 122 ++ internal/v2/eventstore/postgres/push.go | 245 +++ internal/v2/eventstore/postgres/push_test.go | 1292 +++++++++++++++ internal/v2/eventstore/postgres/query.go | 289 ++++ internal/v2/eventstore/postgres/query_test.go | 1380 +++++++++++++++++ internal/v2/eventstore/postgres/storage.go | 28 + internal/v2/eventstore/push.go | 190 +++ internal/v2/eventstore/query.go | 756 +++++++++ internal/v2/eventstore/query_test.go | 1063 +++++++++++++ internal/v2/eventstore/unique_constraint.go | 80 + 16 files changed, 5681 insertions(+) create mode 100644 internal/v2/eventstore/aggregate.go create mode 100644 internal/v2/eventstore/current_sequence.go create mode 100644 internal/v2/eventstore/event.go create mode 100644 internal/v2/eventstore/event_store.go create mode 100644 internal/v2/eventstore/postgres/event.go create mode 100644 internal/v2/eventstore/postgres/intent.go create mode 100644 internal/v2/eventstore/postgres/intent_test.go create mode 100644 internal/v2/eventstore/postgres/push.go create mode 100644 internal/v2/eventstore/postgres/push_test.go create mode 100644 internal/v2/eventstore/postgres/query.go create mode 100644 internal/v2/eventstore/postgres/query_test.go create mode 100644 internal/v2/eventstore/postgres/storage.go create mode 100644 internal/v2/eventstore/push.go create mode 100644 
internal/v2/eventstore/query.go create mode 100644 internal/v2/eventstore/query_test.go create mode 100644 internal/v2/eventstore/unique_constraint.go diff --git a/internal/v2/eventstore/aggregate.go b/internal/v2/eventstore/aggregate.go new file mode 100644 index 0000000000..c4ab597aef --- /dev/null +++ b/internal/v2/eventstore/aggregate.go @@ -0,0 +1,24 @@ +package eventstore + +type Aggregate struct { + ID string + Type string + Instance string + Owner string +} + +func (agg *Aggregate) Equals(aggregate *Aggregate) bool { + if aggregate.ID != "" && aggregate.ID != agg.ID { + return false + } + if aggregate.Type != "" && aggregate.Type != agg.Type { + return false + } + if aggregate.Instance != "" && aggregate.Instance != agg.Instance { + return false + } + if aggregate.Owner != "" && aggregate.Owner != agg.Owner { + return false + } + return true +} diff --git a/internal/v2/eventstore/current_sequence.go b/internal/v2/eventstore/current_sequence.go new file mode 100644 index 0000000000..3fcdcf5904 --- /dev/null +++ b/internal/v2/eventstore/current_sequence.go @@ -0,0 +1,29 @@ +package eventstore + +type CurrentSequence func(current uint32) bool + +func CheckSequence(current uint32, check CurrentSequence) bool { + if check == nil { + return true + } + return check(current) +} + +// SequenceIgnore doesn't check the current sequence +func SequenceIgnore() CurrentSequence { + return nil +} + +// SequenceMatches exactly the provided sequence +func SequenceMatches(sequence uint32) CurrentSequence { + return func(current uint32) bool { + return current == sequence + } +} + +// SequenceAtLeast matches the given sequence <= the current sequence +func SequenceAtLeast(sequence uint32) CurrentSequence { + return func(current uint32) bool { + return current >= sequence + } +} diff --git a/internal/v2/eventstore/event.go b/internal/v2/eventstore/event.go new file mode 100644 index 0000000000..b452093305 --- /dev/null +++ b/internal/v2/eventstore/event.go @@ -0,0 +1,36 @@ 
+package eventstore + +import "time" + +type Event[P any] struct { + Aggregate Aggregate + CreatedAt time.Time + Creator string + Position GlobalPosition + Revision uint16 + Sequence uint32 + Type string + Payload P +} + +type StoragePayload interface { + Unmarshal(ptr any) error +} + +func EventFromStorage[E Event[P], P any](event *Event[StoragePayload]) (*E, error) { + var payload P + + if err := event.Payload.Unmarshal(&payload); err != nil { + return nil, err + } + return &E{ + Aggregate: event.Aggregate, + CreatedAt: event.CreatedAt, + Creator: event.Creator, + Position: event.Position, + Revision: event.Revision, + Sequence: event.Sequence, + Type: event.Type, + Payload: payload, + }, nil +} diff --git a/internal/v2/eventstore/event_store.go b/internal/v2/eventstore/event_store.go new file mode 100644 index 0000000000..fe70bb36a3 --- /dev/null +++ b/internal/v2/eventstore/event_store.go @@ -0,0 +1,41 @@ +package eventstore + +import ( + "context" +) + +func NewEventstore(querier Querier, pusher Pusher) *EventStore { + return &EventStore{ + Pusher: pusher, + Querier: querier, + } +} + +func NewEventstoreFromOne(o one) *EventStore { + return NewEventstore(o, o) +} + +type EventStore struct { + Pusher + Querier +} + +type one interface { + Pusher + Querier +} + +type healthier interface { + Health(ctx context.Context) error +} + +type GlobalPosition struct { + Position float64 + InPositionOrder uint32 +} + +type Reducer interface { + Reduce(events ...*Event[StoragePayload]) error +} + +type Reduce func(events ...*Event[StoragePayload]) error diff --git a/internal/v2/eventstore/postgres/event.go b/internal/v2/eventstore/postgres/event.go new file mode 100644 index 0000000000..9970dd14ea --- /dev/null +++ b/internal/v2/eventstore/postgres/event.go @@ -0,0 +1,64 @@ +package postgres + +import ( + "encoding/json" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/v2/eventstore" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func 
intentToCommands(intent *intent) (commands []*command, err error) { + commands = make([]*command, len(intent.Commands())) + + for i, cmd := range intent.Commands() { + var payload unmarshalPayload + if cmd.Payload() != nil { + payload, err = json.Marshal(cmd.Payload()) + if err != nil { + logging.WithError(err).Warn("marshal payload failed") + return nil, zerrors.ThrowInternal(err, "POSTG-MInPK", "Errors.Internal") + } + } + + commands[i] = &command{ + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *intent.Aggregate(), + Creator: cmd.Creator(), + Revision: cmd.Revision(), + Type: cmd.Type(), + // always add at least 1 to the currently stored sequence + Sequence: intent.sequence + uint32(i) + 1, + Payload: payload, + }, + intent: intent, + uniqueConstraints: cmd.UniqueConstraints(), + } + } + + return commands, nil +} + +type command struct { + *eventstore.Event[eventstore.StoragePayload] + + intent *intent + uniqueConstraints []*eventstore.UniqueConstraint +} + +var _ eventstore.StoragePayload = (unmarshalPayload)(nil) + +type unmarshalPayload []byte + +// Unmarshal implements eventstore.StoragePayload. 
+func (p unmarshalPayload) Unmarshal(ptr any) error { + if len(p) == 0 { + return nil + } + if err := json.Unmarshal(p, ptr); err != nil { + return zerrors.ThrowInternal(err, "POSTG-u8qVo", "Errors.Internal") + } + + return nil +} diff --git a/internal/v2/eventstore/postgres/intent.go b/internal/v2/eventstore/postgres/intent.go new file mode 100644 index 0000000000..9ab259ada8 --- /dev/null +++ b/internal/v2/eventstore/postgres/intent.go @@ -0,0 +1,42 @@ +package postgres + +import ( + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/v2/eventstore" +) + +type intent struct { + *eventstore.PushAggregate + + sequence uint32 +} + +func makeIntents(pushIntent *eventstore.PushIntent) []*intent { + res := make([]*intent, len(pushIntent.Aggregates())) + + for i, aggregate := range pushIntent.Aggregates() { + res[i] = &intent{PushAggregate: aggregate} + } + + return res +} + +func intentByAggregate(intents []*intent, aggregate *eventstore.Aggregate) *intent { + for _, intent := range intents { + if intent.PushAggregate.Aggregate().Equals(aggregate) { + return intent + } + } + logging.WithFields("instance", aggregate.Instance, "owner", aggregate.Owner, "type", aggregate.Type, "id", aggregate.ID).Panic("no intent found") + return nil +} + +func checkSequences(intents []*intent) bool { + for _, intent := range intents { + if !eventstore.CheckSequence(intent.sequence, intent.PushAggregate.CurrentSequence()) { + return false + } + } + return true +} diff --git a/internal/v2/eventstore/postgres/intent_test.go b/internal/v2/eventstore/postgres/intent_test.go new file mode 100644 index 0000000000..93b3aa2162 --- /dev/null +++ b/internal/v2/eventstore/postgres/intent_test.go @@ -0,0 +1,122 @@ +package postgres + +import ( + "testing" + + "github.com/zitadel/zitadel/internal/v2/eventstore" +) + +func Test_checkSequences(t *testing.T) { + type args struct { + intents []*intent + } + tests := []struct { + name string + args args + want bool + }{ + { + name: 
"ignore", + args: args{ + intents: []*intent{ + { + sequence: 1, + PushAggregate: eventstore.NewPushAggregate( + "", "", "", + eventstore.IgnoreCurrentSequence(), + ), + }, + }, + }, + want: true, + }, + { + name: "ignores", + args: args{ + intents: []*intent{ + { + sequence: 1, + PushAggregate: eventstore.NewPushAggregate( + "", "", "", + eventstore.IgnoreCurrentSequence(), + ), + }, + { + sequence: 1, + PushAggregate: eventstore.NewPushAggregate( + "", "", "", + ), + }, + }, + }, + want: true, + }, + { + name: "matches", + args: args{ + intents: []*intent{ + { + sequence: 0, + PushAggregate: eventstore.NewPushAggregate( + "", "", "", + eventstore.CurrentSequenceMatches(0), + ), + }, + }, + }, + want: true, + }, + { + name: "does not match", + args: args{ + intents: []*intent{ + { + sequence: 1, + PushAggregate: eventstore.NewPushAggregate( + "", "", "", + eventstore.CurrentSequenceMatches(2), + ), + }, + }, + }, + want: false, + }, + { + name: "at least", + args: args{ + intents: []*intent{ + { + sequence: 10, + PushAggregate: eventstore.NewPushAggregate( + "", "", "", + eventstore.CurrentSequenceAtLeast(0), + ), + }, + }, + }, + want: true, + }, + { + name: "at least too low", + args: args{ + intents: []*intent{ + { + sequence: 1, + PushAggregate: eventstore.NewPushAggregate( + "", "", "", + eventstore.CurrentSequenceAtLeast(2), + ), + }, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := checkSequences(tt.args.intents); got != tt.want { + t.Errorf("checkSequences() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/v2/eventstore/postgres/push.go b/internal/v2/eventstore/postgres/push.go new file mode 100644 index 0000000000..7ae64fd41d --- /dev/null +++ b/internal/v2/eventstore/postgres/push.go @@ -0,0 +1,245 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/telemetry/tracing" + 
"github.com/zitadel/zitadel/internal/v2/database" + "github.com/zitadel/zitadel/internal/v2/eventstore" + "github.com/zitadel/zitadel/internal/zerrors" +) + +// Push implements eventstore.Pusher. +func (s *Storage) Push(ctx context.Context, intent *eventstore.PushIntent) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + tx := intent.Tx() + if tx == nil { + tx, err = s.client.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: false}) + if err != nil { + return err + } + defer func() { + err = database.CloseTx(tx, err) + }() + } + + // allows smaller wait times on query side for instances which are not actively writing + if err := setAppName(ctx, tx, "es_pusher_"+intent.Instance()); err != nil { + return err + } + + intents, err := lockAggregates(ctx, tx, intent) + if err != nil { + return err + } + + if !checkSequences(intents) { + return zerrors.ThrowInvalidArgument(nil, "POSTG-KOM6E", "Errors.Internal.Eventstore.SequenceNotMatched") + } + + commands := make([]*command, 0, len(intents)) + for _, intent := range intents { + additionalCommands, err := intentToCommands(intent) + if err != nil { + return err + } + commands = append(commands, additionalCommands...) 
+ } + + err = uniqueConstraints(ctx, tx, commands) + if err != nil { + return err + } + + return push(ctx, tx, intent, commands) +} + +// setAppName for the the current transaction +func setAppName(ctx context.Context, tx *sql.Tx, name string) error { + _, err := tx.ExecContext(ctx, "SET LOCAL application_name TO $1", name) + if err != nil { + logging.WithFields("name", name).WithError(err).Debug("setting app name failed") + return zerrors.ThrowInternal(err, "POSTG-G3OmZ", "Errors.Internal") + } + + return nil +} + +func lockAggregates(ctx context.Context, tx *sql.Tx, intent *eventstore.PushIntent) (_ []*intent, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + var stmt database.Statement + + stmt.WriteString("WITH existing AS (") + for i, aggregate := range intent.Aggregates() { + if i > 0 { + stmt.WriteString(" UNION ALL ") + } + stmt.WriteString(`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = `) + stmt.WriteArgs(intent.Instance()) + stmt.WriteString(` AND aggregate_type = `) + stmt.WriteArgs(aggregate.Type()) + stmt.WriteString(` AND aggregate_id = `) + stmt.WriteArgs(aggregate.ID()) + stmt.WriteString(` AND owner = `) + stmt.WriteArgs(aggregate.Owner()) + stmt.WriteString(` ORDER BY "sequence" DESC LIMIT 1)`) + } + stmt.WriteString(") SELECT e.instance_id, e.owner, e.aggregate_type, e.aggregate_id, e.sequence FROM eventstore.events2 e JOIN existing ON e.instance_id = existing.instance_id AND e.aggregate_type = existing.aggregate_type AND e.aggregate_id = existing.aggregate_id AND e.sequence = existing.sequence FOR UPDATE") + + //nolint:rowserrcheck + // rows is checked by database.MapRowsToObject + rows, err := tx.QueryContext(ctx, stmt.String(), stmt.Args()...) 
+ if err != nil { + return nil, err + } + + res := makeIntents(intent) + + err = database.MapRowsToObject(rows, func(scan func(dest ...any) error) error { + var sequence sql.Null[uint32] + agg := new(eventstore.Aggregate) + + err := scan( + &agg.Instance, + &agg.Owner, + &agg.Type, + &agg.ID, + &sequence, + ) + if err != nil { + return err + } + + intentByAggregate(res, agg).sequence = sequence.V + + return nil + }) + if err != nil { + return nil, err + } + + return res, nil +} + +func push(ctx context.Context, tx *sql.Tx, reducer eventstore.Reducer, commands []*command) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + var stmt database.Statement + + stmt.WriteString(`INSERT INTO eventstore.events2 (instance_id, "owner", aggregate_type, aggregate_id, revision, creator, event_type, payload, "sequence", in_tx_order, created_at, "position") VALUES `) + for i, cmd := range commands { + if i > 0 { + stmt.WriteString(", ") + } + + cmd.Position.InPositionOrder = uint32(i) + stmt.WriteString(`(`) + stmt.WriteArgs( + cmd.Aggregate.Instance, + cmd.Aggregate.Owner, + cmd.Aggregate.Type, + cmd.Aggregate.ID, + cmd.Revision, + cmd.Creator, + cmd.Type, + cmd.Payload, + cmd.Sequence, + i, + ) + stmt.WriteString(", statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())") + stmt.WriteString(`)`) + } + stmt.WriteString(` RETURNING created_at, "position"`) + + //nolint:rowserrcheck + // rows is checked by database.MapRowsToObject + rows, err := tx.QueryContext(ctx, stmt.String(), stmt.Args()...) 
+ if err != nil { + return err + } + + var i int + return database.MapRowsToObject(rows, func(scan func(dest ...any) error) error { + defer func() { i++ }() + + err := scan( + &commands[i].CreatedAt, + &commands[i].Position.Position, + ) + if err != nil { + return err + } + return reducer.Reduce(commands[i].Event) + }) +} + +func uniqueConstraints(ctx context.Context, tx *sql.Tx, commands []*command) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + var stmt database.Statement + + for _, cmd := range commands { + if len(cmd.uniqueConstraints) == 0 { + continue + } + for _, constraint := range cmd.uniqueConstraints { + stmt.Reset() + instance := cmd.Aggregate.Instance + if constraint.IsGlobal { + instance = "" + } + switch constraint.Action { + case eventstore.UniqueConstraintAdd: + stmt.WriteString(`INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES (`) + stmt.WriteArgs(instance, constraint.UniqueType, constraint.UniqueField) + stmt.WriteRune(')') + case eventstore.UniqueConstraintInstanceRemove: + stmt.WriteString(`DELETE FROM eventstore.unique_constraints WHERE instance_id = `) + stmt.WriteArgs(instance) + case eventstore.UniqueConstraintRemove: + stmt.WriteString(`DELETE FROM eventstore.unique_constraints WHERE `) + stmt.WriteString(deleteUniqueConstraintClause) + stmt.AppendArgs( + instance, + constraint.UniqueType, + constraint.UniqueField, + ) + } + _, err := tx.ExecContext(ctx, stmt.String(), stmt.Args()...) 
+ if err != nil { + logging.WithFields("action", constraint.Action).Warn("handling of unique constraint failed") + errMessage := constraint.ErrorMessage + if errMessage == "" { + errMessage = "Errors.Internal" + } + return zerrors.ThrowAlreadyExists(err, "POSTG-QzjyP", errMessage) + } + } + } + + return nil +} + +// the query is so complex because we accidentally stored unique constraint case sensitive +// the query checks first if there is a case sensitive match and afterwards if there is a case insensitive match +var deleteUniqueConstraintClause = ` +(instance_id = $1 AND unique_type = $2 AND unique_field = ( + SELECT unique_field from ( + SELECT instance_id, unique_type, unique_field + FROM eventstore.unique_constraints + WHERE instance_id = $1 AND unique_type = $2 AND unique_field = $3 + UNION ALL + SELECT instance_id, unique_type, unique_field + FROM eventstore.unique_constraints + WHERE instance_id = $1 AND unique_type = $2 AND unique_field = LOWER($3) + ) AS case_insensitive_constraints LIMIT 1) +)` diff --git a/internal/v2/eventstore/postgres/push_test.go b/internal/v2/eventstore/postgres/push_test.go new file mode 100644 index 0000000000..819add334f --- /dev/null +++ b/internal/v2/eventstore/postgres/push_test.go @@ -0,0 +1,1292 @@ +package postgres + +import ( + "context" + "database/sql/driver" + "errors" + "reflect" + "testing" + "time" + + "github.com/zitadel/zitadel/internal/v2/database/mock" + "github.com/zitadel/zitadel/internal/v2/eventstore" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func Test_uniqueConstraints(t *testing.T) { + type args struct { + commands []*command + expectations []mock.Expectation + } + execErr := errors.New("exec err") + tests := []struct { + name string + args args + assertErr func(t *testing.T, err error) bool + }{ + { + name: "no commands", + args: args{ + commands: []*command{}, + expectations: []mock.Expectation{}, + }, + assertErr: expectNoErr, + }, + { + name: "command without constraints", + args: args{ + 
commands: []*command{ + {}, + }, + expectations: []mock.Expectation{}, + }, + assertErr: expectNoErr, + }, + { + name: "add 1 constraint 1 command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewAddEventUniqueConstraint("test", "id", "error"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + "INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES ($1, $2, $3)", + mock.WithExecArgs("instance", "test", "id"), + mock.WithExecRowsAffected(1), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "add 1 global constraint 1 command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewAddGlobalUniqueConstraint("test", "id", "error"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + "INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES ($1, $2, $3)", + mock.WithExecArgs("", "test", "id"), + mock.WithExecRowsAffected(1), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "add 2 constraint 1 command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewAddEventUniqueConstraint("test", "id", "error"), + eventstore.NewAddEventUniqueConstraint("test", "id2", "error"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + "INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES ($1, $2, $3)", + mock.WithExecArgs("instance", "test", "id"), + 
mock.WithExecRowsAffected(1), + ), + mock.ExpectExec( + "INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES ($1, $2, $3)", + mock.WithExecArgs("instance", "test", "id2"), + mock.WithExecRowsAffected(1), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "add 1 constraint per command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewAddEventUniqueConstraint("test", "id", "error"), + }, + }, + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewAddEventUniqueConstraint("test", "id2", "error"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + "INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES ($1, $2, $3)", + mock.WithExecArgs("instance", "test", "id"), + mock.WithExecRowsAffected(1), + ), + mock.ExpectExec( + "INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES ($1, $2, $3)", + mock.WithExecArgs("instance", "test", "id2"), + mock.WithExecRowsAffected(1), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "remove instance constraints 1 command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewRemoveInstanceUniqueConstraints(), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + "DELETE FROM eventstore.unique_constraints WHERE instance_id = $1", + mock.WithExecArgs("instance"), + mock.WithExecRowsAffected(10), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "remove instance 
constraints 2 commands", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewRemoveInstanceUniqueConstraints(), + }, + }, + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewRemoveInstanceUniqueConstraints(), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + "DELETE FROM eventstore.unique_constraints WHERE instance_id = $1", + mock.WithExecArgs("instance"), + mock.WithExecRowsAffected(10), + ), + mock.ExpectExec( + "DELETE FROM eventstore.unique_constraints WHERE instance_id = $1", + mock.WithExecArgs("instance"), + mock.WithExecRowsAffected(0), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "remove 1 constraint 1 command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewRemoveUniqueConstraint("test", "id"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + `DELETE FROM eventstore.unique_constraints WHERE (instance_id = $1 AND unique_type = $2 AND unique_field = ( SELECT unique_field from ( SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = $3 UNION ALL SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = LOWER($3) ) AS case_insensitive_constraints LIMIT 1) )`, + mock.WithExecArgs("instance", "test", "id"), + mock.WithExecRowsAffected(1), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "remove 1 global constraint 1 
command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewRemoveGlobalUniqueConstraint("test", "id"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + `DELETE FROM eventstore.unique_constraints WHERE (instance_id = $1 AND unique_type = $2 AND unique_field = ( SELECT unique_field from ( SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = $3 UNION ALL SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = LOWER($3) ) AS case_insensitive_constraints LIMIT 1) )`, + mock.WithExecArgs("", "test", "id"), + mock.WithExecRowsAffected(1), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "remove 2 constraints 1 command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewRemoveUniqueConstraint("test", "id"), + eventstore.NewRemoveUniqueConstraint("test", "id2"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + `DELETE FROM eventstore.unique_constraints WHERE (instance_id = $1 AND unique_type = $2 AND unique_field = ( SELECT unique_field from ( SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = $3 UNION ALL SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = LOWER($3) ) AS case_insensitive_constraints LIMIT 1) )`, + mock.WithExecArgs("instance", "test", "id"), + mock.WithExecRowsAffected(1), + ), + 
mock.ExpectExec( + `DELETE FROM eventstore.unique_constraints WHERE (instance_id = $1 AND unique_type = $2 AND unique_field = ( SELECT unique_field from ( SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = $3 UNION ALL SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = LOWER($3) ) AS case_insensitive_constraints LIMIT 1) )`, + mock.WithExecArgs("instance", "test", "id2"), + mock.WithExecRowsAffected(1), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "remove 1 constraints per command", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewRemoveUniqueConstraint("test", "id"), + }, + }, + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewRemoveUniqueConstraint("test", "id2"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + `DELETE FROM eventstore.unique_constraints WHERE (instance_id = $1 AND unique_type = $2 AND unique_field = ( SELECT unique_field from ( SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = $3 UNION ALL SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = LOWER($3) ) AS case_insensitive_constraints LIMIT 1) )`, + mock.WithExecArgs("instance", "test", "id"), + mock.WithExecRowsAffected(1), + ), + mock.ExpectExec( + `DELETE FROM eventstore.unique_constraints WHERE (instance_id = $1 AND unique_type = $2 AND unique_field = ( SELECT 
unique_field from ( SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = $3 UNION ALL SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints WHERE instance_id = $1 AND unique_type = $2 AND unique_field = LOWER($3) ) AS case_insensitive_constraints LIMIT 1) )`, + mock.WithExecArgs("instance", "test", "id2"), + mock.WithExecRowsAffected(1), + ), + }, + }, + assertErr: expectNoErr, + }, + { + name: "exec fails no error specified", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewAddEventUniqueConstraint("test", "id", ""), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + "INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES ($1, $2, $3)", + mock.WithExecArgs("instance", "test", "id"), + mock.WithExecErr(execErr), + ), + }, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, zerrors.ThrowAlreadyExists(execErr, "POSTG-QzjyP", "Errors.Internal")) + if !is { + t.Errorf("no error expected got: %v", err) + } + return is + }, + }, + { + name: "exec fails error specified", + args: args{ + commands: []*command{ + { + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + Instance: "instance", + }, + }, + uniqueConstraints: []*eventstore.UniqueConstraint{ + eventstore.NewAddEventUniqueConstraint("test", "id", "My.Error"), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectExec( + "INSERT INTO eventstore.unique_constraints (instance_id, unique_type, unique_field) VALUES ($1, $2, $3)", + mock.WithExecArgs("instance", "test", "id"), + mock.WithExecErr(execErr), + ), + }, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, 
zerrors.ThrowAlreadyExists(execErr, "POSTG-QzjyP", "My.Error")) + if !is { + t.Errorf("no error expected got: %v", err) + } + return is + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dbMock := mock.NewSQLMock(t, append([]mock.Expectation{mock.ExpectBegin(nil)}, tt.args.expectations...)...) + tx, err := dbMock.DB.Begin() + if err != nil { + t.Errorf("unexpected error in begin: %v", err) + t.FailNow() + } + err = uniqueConstraints(context.Background(), tx, tt.args.commands) + tt.assertErr(t, err) + dbMock.Assert(t) + }) + } +} + +var errReduce = errors.New("reduce err") + +func Test_lockAggregates(t *testing.T) { + type args struct { + pushIntent *eventstore.PushIntent + expectations []mock.Expectation + } + type want struct { + intents []*intent + assertErr func(t *testing.T, err error) bool + } + tests := []struct { + name string + args args + want want + }{ + { + name: "1 intent", + args: args{ + pushIntent: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ), + expectations: []mock.Expectation{ + mock.ExpectQuery( + `WITH existing AS ((SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 AND owner = $4 ORDER BY "sequence" DESC LIMIT 1)) SELECT e.instance_id, e.owner, e.aggregate_type, e.aggregate_id, e.sequence FROM eventstore.events2 e JOIN existing ON e.instance_id = existing.instance_id AND e.aggregate_type = existing.aggregate_type AND e.aggregate_id = existing.aggregate_id AND e.sequence = existing.sequence FOR UPDATE`, + mock.WithQueryArgs("instance", "testType", "testID", "owner"), + mock.WithQueryResult( + []string{"instance_id", "owner", "aggregate_type", "aggregate_id", "sequence"}, + [][]driver.Value{ + { + "instance", + "owner", + "testType", + "testID", + 42, + }, + }, + ), + ), + }, + }, + want: want{ + intents: []*intent{ + { + PushAggregate: 
eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + sequence: 42, + }, + }, + assertErr: expectNoErr, + }, + }, + { + name: "two intents", + args: args{ + pushIntent: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + eventstore.AppendAggregate("owner", "myType", "id"), + ), + expectations: []mock.Expectation{ + mock.ExpectQuery( + `WITH existing AS ((SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 AND owner = $4 ORDER BY "sequence" DESC LIMIT 1) UNION ALL (SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $5 AND aggregate_type = $6 AND aggregate_id = $7 AND owner = $8 ORDER BY "sequence" DESC LIMIT 1)) SELECT e.instance_id, e.owner, e.aggregate_type, e.aggregate_id, e.sequence FROM eventstore.events2 e JOIN existing ON e.instance_id = existing.instance_id AND e.aggregate_type = existing.aggregate_type AND e.aggregate_id = existing.aggregate_id AND e.sequence = existing.sequence FOR UPDATE`, + mock.WithQueryArgs( + "instance", "testType", "testID", "owner", + "instance", "myType", "id", "owner", + ), + mock.WithQueryResult( + []string{"instance_id", "owner", "aggregate_type", "aggregate_id", "sequence"}, + [][]driver.Value{ + { + "instance", + "owner", + "testType", + "testID", + 42, + }, + { + "instance", + "owner", + "myType", + "id", + 17, + }, + }, + ), + ), + }, + }, + want: want{ + intents: []*intent{ + { + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + sequence: 42, + }, + { + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "myType", "id"), + ).Aggregates()[0], + sequence: 17, + }, + }, + assertErr: expectNoErr, + }, + }, + { + name: "1 intent 
aggregate not found", + args: args{ + pushIntent: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ), + expectations: []mock.Expectation{ + mock.ExpectQuery( + `WITH existing AS ((SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 AND owner = $4 ORDER BY "sequence" DESC LIMIT 1)) SELECT e.instance_id, e.owner, e.aggregate_type, e.aggregate_id, e.sequence FROM eventstore.events2 e JOIN existing ON e.instance_id = existing.instance_id AND e.aggregate_type = existing.aggregate_type AND e.aggregate_id = existing.aggregate_id AND e.sequence = existing.sequence FOR UPDATE`, + mock.WithQueryArgs("instance", "testType", "testID", "owner"), + mock.WithQueryResult( + []string{"instance_id", "owner", "aggregate_type", "aggregate_id", "sequence"}, + [][]driver.Value{}, + ), + ), + }, + }, + want: want{ + intents: []*intent{ + { + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + sequence: 0, + }, + }, + assertErr: expectNoErr, + }, + }, + { + name: "two intents none found", + args: args{ + pushIntent: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + eventstore.AppendAggregate("owner", "myType", "id"), + ), + expectations: []mock.Expectation{ + mock.ExpectQuery( + `WITH existing AS ((SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 AND owner = $4 ORDER BY "sequence" DESC LIMIT 1) UNION ALL (SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $5 AND aggregate_type = $6 AND aggregate_id = $7 AND owner = $8 ORDER BY "sequence" DESC LIMIT 1)) SELECT e.instance_id, e.owner, e.aggregate_type, e.aggregate_id, e.sequence FROM 
eventstore.events2 e JOIN existing ON e.instance_id = existing.instance_id AND e.aggregate_type = existing.aggregate_type AND e.aggregate_id = existing.aggregate_id AND e.sequence = existing.sequence FOR UPDATE`, + mock.WithQueryArgs( + "instance", "testType", "testID", "owner", + "instance", "myType", "id", "owner", + ), + mock.WithQueryResult( + []string{"instance_id", "owner", "aggregate_type", "aggregate_id", "sequence"}, + [][]driver.Value{}, + ), + ), + }, + }, + want: want{ + intents: []*intent{ + { + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + sequence: 0, + }, + { + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "myType", "id"), + ).Aggregates()[0], + sequence: 0, + }, + }, + assertErr: expectNoErr, + }, + }, + { + name: "two intents 1 found", + args: args{ + pushIntent: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + eventstore.AppendAggregate("owner", "myType", "id"), + ), + expectations: []mock.Expectation{ + mock.ExpectQuery( + `WITH existing AS ((SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 AND owner = $4 ORDER BY "sequence" DESC LIMIT 1) UNION ALL (SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $5 AND aggregate_type = $6 AND aggregate_id = $7 AND owner = $8 ORDER BY "sequence" DESC LIMIT 1)) SELECT e.instance_id, e.owner, e.aggregate_type, e.aggregate_id, e.sequence FROM eventstore.events2 e JOIN existing ON e.instance_id = existing.instance_id AND e.aggregate_type = existing.aggregate_type AND e.aggregate_id = existing.aggregate_id AND e.sequence = existing.sequence FOR UPDATE`, + mock.WithQueryArgs( + "instance", "testType", "testID", "owner", + "instance", "myType", "id", "owner", + ), + 
mock.WithQueryResult( + []string{"instance_id", "owner", "aggregate_type", "aggregate_id", "sequence"}, + [][]driver.Value{ + { + "instance", + "owner", + "myType", + "id", + 17, + }, + }, + ), + ), + }, + }, + want: want{ + intents: []*intent{ + { + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + sequence: 0, + }, + { + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "myType", "id"), + ).Aggregates()[0], + sequence: 17, + }, + }, + assertErr: expectNoErr, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dbMock := mock.NewSQLMock(t, append([]mock.Expectation{mock.ExpectBegin(nil)}, tt.args.expectations...)...) + tx, err := dbMock.DB.Begin() + if err != nil { + t.Errorf("unexpected error in begin: %v", err) + t.FailNow() + } + got, err := lockAggregates(context.Background(), tx, tt.args.pushIntent) + tt.want.assertErr(t, err) + dbMock.Assert(t) + if len(got) != len(tt.want.intents) { + t.Errorf("unexpected length of intents %d, want: %d", len(got), len(tt.want.intents)) + return + } + for i, gotten := range got { + assertIntent(t, gotten, tt.want.intents[i]) + } + }) + } +} + +func assertIntent(t *testing.T, got, want *intent) { + if got.sequence != want.sequence { + t.Errorf("unexpected sequence %d want %d", got.sequence, want.sequence) + } + assertPushAggregate(t, got.PushAggregate, want.PushAggregate) +} + +func assertPushAggregate(t *testing.T, got, want *eventstore.PushAggregate) { + if !reflect.DeepEqual(got.Type(), want.Type()) { + t.Errorf("unexpected Type %v, want: %v", got.Type(), want.Type()) + } + if !reflect.DeepEqual(got.ID(), want.ID()) { + t.Errorf("unexpected ID %v, want: %v", got.ID(), want.ID()) + } + if !reflect.DeepEqual(got.Owner(), want.Owner()) { + t.Errorf("unexpected Owner %v, want: %v", got.Owner(), want.Owner()) + } + if !reflect.DeepEqual(got.Commands(), want.Commands()) 
{ + t.Errorf("unexpected Commands %v, want: %v", got.Commands(), want.Commands()) + } + if !reflect.DeepEqual(got.Aggregate(), want.Aggregate()) { + t.Errorf("unexpected Aggregate %v, want: %v", got.Aggregate(), want.Aggregate()) + } + if !reflect.DeepEqual(got.CurrentSequence(), want.CurrentSequence()) { + t.Errorf("unexpected CurrentSequence %v, want: %v", got.CurrentSequence(), want.CurrentSequence()) + } +} + +func Test_push(t *testing.T) { + type args struct { + commands []*command + expectations []mock.Expectation + reducer *testReducer + } + type want struct { + assertErr func(t *testing.T, err error) bool + } + tests := []struct { + name string + args args + want want + }{ + { + name: "1 aggregate 1 command", + args: args{ + reducer: &testReducer{ + expectedReduces: 1, + }, + commands: []*command{ + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type", + Sequence: 1, + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectQuery( + `INSERT INTO eventstore.events2 (instance_id, "owner", aggregate_type, aggregate_id, revision, creator, event_type, payload, "sequence", in_tx_order, created_at, "position") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())) RETURNING created_at, "position"`, + mock.WithQueryArgs( + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type", + nil, + uint32(1), + 0, + ), + mock.WithQueryResult( + []string{"created_at", "position"}, + [][]driver.Value{ + { + time.Now(), + float64(123), + }, + }, + ), + ), + }, + }, + want: want{ + assertErr: expectNoErr, + }, + }, + { + name: 
"1 aggregate 2 commands", + args: args{ + reducer: &testReducer{ + expectedReduces: 2, + }, + commands: []*command{ + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type", + Sequence: 1, + }, + }, + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type2", + Sequence: 2, + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectQuery( + `INSERT INTO eventstore.events2 (instance_id, "owner", aggregate_type, aggregate_id, revision, creator, event_type, payload, "sequence", in_tx_order, created_at, "position") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())), ($11, $12, $13, $14, $15, $16, $17, $18, $19, $20, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())) RETURNING created_at, "position"`, + mock.WithQueryArgs( + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type", + nil, + uint32(1), + 0, + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type2", + nil, + uint32(2), + 1, + ), + mock.WithQueryResult( + []string{"created_at", "position"}, + [][]driver.Value{ + { + time.Now(), + float64(123), + }, + { + time.Now(), + float64(123.1), + }, + }, + ), + ), + }, + }, + want: want{ + assertErr: expectNoErr, + 
}, + }, + { + name: "1 command per aggregate 2 aggregates", + args: args{ + reducer: &testReducer{ + expectedReduces: 2, + }, + commands: []*command{ + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type", + Sequence: 1, + }, + }, + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: eventstore.Aggregate{ + ID: "id2", + Type: "type2", + Instance: "instance", + Owner: "owner", + }, + Creator: "gigi", + Revision: 1, + Type: "test.type2", + Sequence: 10, + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectQuery( + `INSERT INTO eventstore.events2 (instance_id, "owner", aggregate_type, aggregate_id, revision, creator, event_type, payload, "sequence", in_tx_order, created_at, "position") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())), ($11, $12, $13, $14, $15, $16, $17, $18, $19, $20, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())) RETURNING created_at, "position"`, + mock.WithQueryArgs( + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type", + nil, + uint32(1), + 0, + "instance", + "owner", + "type2", + "id2", + uint16(1), + "gigi", + "test.type2", + nil, + uint32(10), + 1, + ), + mock.WithQueryResult( + []string{"created_at", "position"}, + [][]driver.Value{ + { + time.Now(), + float64(123), + }, + { + time.Now(), + float64(123.1), + }, + }, + ), + ), + }, + }, + want: want{ + assertErr: expectNoErr, + }, + 
}, + { + name: "1 aggregate 1 command with payload", + args: args{ + reducer: &testReducer{ + expectedReduces: 1, + }, + commands: []*command{ + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type", + Sequence: 1, + Payload: unmarshalPayload(`{"name": "gigi"}`), + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectQuery( + `INSERT INTO eventstore.events2 (instance_id, "owner", aggregate_type, aggregate_id, revision, creator, event_type, payload, "sequence", in_tx_order, created_at, "position") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())) RETURNING created_at, "position"`, + mock.WithQueryArgs( + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type", + unmarshalPayload(`{"name": "gigi"}`), + uint32(1), + 0, + ), + mock.WithQueryResult( + []string{"created_at", "position"}, + [][]driver.Value{ + { + time.Now(), + float64(123), + }, + }, + ), + ), + }, + }, + want: want{ + assertErr: expectNoErr, + }, + }, + { + name: "command reducer", + args: args{ + reducer: &testReducer{ + expectedReduces: 1, + }, + commands: []*command{ + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type", + Sequence: 1, + }, + }, + }, + expectations: 
[]mock.Expectation{ + mock.ExpectQuery( + `INSERT INTO eventstore.events2 (instance_id, "owner", aggregate_type, aggregate_id, revision, creator, event_type, payload, "sequence", in_tx_order, created_at, "position") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())) RETURNING created_at, "position"`, + mock.WithQueryArgs( + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type", + nil, + uint32(1), + 0, + ), + mock.WithQueryResult( + []string{"created_at", "position"}, + [][]driver.Value{ + { + time.Now(), + float64(123), + }, + }, + ), + ), + }, + }, + want: want{ + assertErr: expectNoErr, + }, + }, + { + name: "command reducer err", + args: args{ + reducer: &testReducer{ + expectedReduces: 1, + shouldErr: true, + }, + commands: []*command{ + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type", + Sequence: 1, + }, + }, + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type2", + Sequence: 2, + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectQuery( + `INSERT INTO eventstore.events2 (instance_id, "owner", aggregate_type, aggregate_id, revision, creator, event_type, payload, "sequence", in_tx_order, created_at, "position") VALUES ($1, $2, 
$3, $4, $5, $6, $7, $8, $9, $10, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())), ($11, $12, $13, $14, $15, $16, $17, $18, $19, $20, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())) RETURNING created_at, "position"`, + mock.WithQueryArgs( + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type", + nil, + uint32(1), + 0, + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type2", + nil, + uint32(2), + 1, + ), + mock.WithQueryResult( + []string{"created_at", "position"}, + [][]driver.Value{ + { + time.Now(), + float64(123), + }, + { + time.Now(), + float64(123.1), + }, + }, + ), + ), + }, + }, + want: want{ + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errReduce) + if !is { + t.Errorf("no error expected got: %v", err) + } + return is + }, + }, + }, + { + name: "1 aggregate 2 commands", + args: args{ + reducer: &testReducer{ + expectedReduces: 2, + }, + commands: []*command{ + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type", + Sequence: 1, + }, + }, + { + intent: &intent{ + PushAggregate: eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0], + }, + Event: &eventstore.Event[eventstore.StoragePayload]{ + Aggregate: *eventstore.NewPushIntent( + "instance", + eventstore.AppendAggregate("owner", "testType", "testID"), + ).Aggregates()[0].Aggregate(), + Creator: "gigi", + Revision: 1, + Type: "test.type2", + Sequence: 2, + }, + }, + }, + expectations: []mock.Expectation{ + mock.ExpectQuery( + `INSERT INTO eventstore.events2 
(instance_id, "owner", aggregate_type, aggregate_id, revision, creator, event_type, payload, "sequence", in_tx_order, created_at, "position") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())), ($11, $12, $13, $14, $15, $16, $17, $18, $19, $20, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())) RETURNING created_at, "position"`, + mock.WithQueryArgs( + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type", + nil, + uint32(1), + 0, + "instance", + "owner", + "testType", + "testID", + uint16(1), + "gigi", + "test.type2", + nil, + uint32(2), + 1, + ), + mock.WithQueryResult( + []string{"created_at", "position"}, + [][]driver.Value{ + { + time.Now(), + float64(123), + }, + { + time.Now(), + float64(123.1), + }, + }, + ), + ), + }, + }, + want: want{ + assertErr: expectNoErr, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dbMock := mock.NewSQLMock(t, append([]mock.Expectation{mock.ExpectBegin(nil)}, tt.args.expectations...)...) 
+ tx, err := dbMock.DB.Begin() + if err != nil { + t.Errorf("unexpected error in begin: %v", err) + t.FailNow() + } + err = push(context.Background(), tx, tt.args.reducer, tt.args.commands) + tt.want.assertErr(t, err) + dbMock.Assert(t) + if tt.args.reducer != nil { + tt.args.reducer.assert(t) + } + }) + } +} + +func expectNoErr(t *testing.T, err error) bool { + is := err == nil + if !is { + t.Errorf("no error expected got: %v", err) + } + return is +} diff --git a/internal/v2/eventstore/postgres/query.go b/internal/v2/eventstore/postgres/query.go new file mode 100644 index 0000000000..608b31e533 --- /dev/null +++ b/internal/v2/eventstore/postgres/query.go @@ -0,0 +1,289 @@ +package postgres + +import ( + "context" + "database/sql" + "slices" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/v2/database" + "github.com/zitadel/zitadel/internal/v2/eventstore" +) + +func (s *Storage) Query(ctx context.Context, query *eventstore.Query) (eventCount int, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + var stmt database.Statement + writeQuery(&stmt, query) + + if query.Tx() != nil { + return executeQuery(ctx, query.Tx(), &stmt, query) + } + + return executeQuery(ctx, s.client.DB, &stmt, query) +} + +func executeQuery(ctx context.Context, tx database.Querier, stmt *database.Statement, reducer eventstore.Reducer) (eventCount int, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + //nolint:rowserrcheck + // rows is checked by database.MapRowsToObject + rows, err := tx.QueryContext(ctx, stmt.String(), stmt.Args()...) 
+ if err != nil { + return 0, err + } + + err = database.MapRowsToObject(rows, func(scan func(dest ...any) error) error { + e := new(eventstore.Event[eventstore.StoragePayload]) + + var payload sql.Null[[]byte] + + err := scan( + &e.CreatedAt, + &e.Type, + &e.Sequence, + &e.Position.Position, + &e.Position.InPositionOrder, + &payload, + &e.Creator, + &e.Aggregate.Owner, + &e.Aggregate.Instance, + &e.Aggregate.Type, + &e.Aggregate.ID, + &e.Revision, + ) + if err != nil { + return err + } + e.Payload = unmarshalPayload(payload.V) + eventCount++ + + return reducer.Reduce(e) + }) + + return eventCount, err +} + +var ( + selectColumns = `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision` + // TODO: condition must know if it's args are named parameters or not + // instancePlaceholder = database.Placeholder("@instance_id") +) + +func writeQuery(stmt *database.Statement, query *eventstore.Query) { + stmt.WriteString(selectColumns) + // stmt.SetNamedArg(instancePlaceholder, query.Instance()) + + stmt.WriteString(" FROM (") + writeFilters(stmt, query.Filters()) + stmt.WriteRune(')') + writePagination(stmt, query.Pagination()) +} + +var from = " FROM eventstore.events2" + +func writeFilters(stmt *database.Statement, filters []*eventstore.Filter) { + if len(filters) == 0 { + logging.Fatal("query does not contain filters") + } + + for i, filter := range filters { + if i > 0 { + stmt.WriteString(" UNION ALL ") + } + stmt.WriteRune('(') + stmt.WriteString(selectColumns) + stmt.WriteString(from) + + writeFilter(stmt, filter) + + stmt.WriteString(")") + } +} + +func writeFilter(stmt *database.Statement, filter *eventstore.Filter) { + stmt.WriteString(" WHERE ") + filter.Parent().Instance().Write(stmt, "instance_id") + + writeAggregateFilters(stmt, filter.AggregateFilters()) + writePagination(stmt, filter.Pagination()) +} + +func writePagination(stmt *database.Statement, pagination 
*eventstore.Pagination) { + writePosition(stmt, pagination.Position()) + writeOrdering(stmt, pagination.Desc()) + if pagination.Pagination() != nil { + pagination.Pagination().Write(stmt) + } +} + +func writePosition(stmt *database.Statement, position *eventstore.PositionCondition) { + if position == nil { + return + } + + max := position.Max() + min := position.Min() + + stmt.WriteString(" AND ") + + if max != nil { + if max.InPositionOrder > 0 { + stmt.WriteString("((") + database.NewNumberEquals(max.Position).Write(stmt, "position") + stmt.WriteString(" AND ") + database.NewNumberLess(max.InPositionOrder).Write(stmt, "in_tx_order") + stmt.WriteRune(')') + stmt.WriteString(" OR ") + } + database.NewNumberLess(max.Position).Write(stmt, "position") + if max.InPositionOrder > 0 { + stmt.WriteRune(')') + } + } + + if max != nil && min != nil { + stmt.WriteString(" AND ") + } + + if min != nil { + if min.InPositionOrder > 0 { + stmt.WriteString("((") + database.NewNumberEquals(min.Position).Write(stmt, "position") + stmt.WriteString(" AND ") + database.NewNumberGreater(min.InPositionOrder).Write(stmt, "in_tx_order") + stmt.WriteRune(')') + stmt.WriteString(" OR ") + } + database.NewNumberGreater(min.Position).Write(stmt, "position") + if min.InPositionOrder > 0 { + stmt.WriteRune(')') + } + } +} + +func writeAggregateFilters(stmt *database.Statement, filters []*eventstore.AggregateFilter) { + if len(filters) == 0 { + return + } + + stmt.WriteString(" AND ") + if len(filters) > 1 { + stmt.WriteRune('(') + } + for i, filter := range filters { + if i > 0 { + stmt.WriteString(" OR ") + } + writeAggregateFilter(stmt, filter) + } + if len(filters) > 1 { + stmt.WriteRune(')') + } +} + +func writeAggregateFilter(stmt *database.Statement, filter *eventstore.AggregateFilter) { + conditions := definedConditions([]*condition{ + {column: "aggregate_type", condition: filter.Type()}, + {column: "aggregate_id", condition: filter.IDs()}, + }) + + if len(conditions) > 1 || 
len(filter.Events()) > 0 { + stmt.WriteRune('(') + } + + writeConditions( + stmt, + conditions, + " AND ", + ) + writeEventFilters(stmt, filter.Events()) + + if len(conditions) > 1 || len(filter.Events()) > 0 { + stmt.WriteRune(')') + } +} + +func writeEventFilters(stmt *database.Statement, filters []*eventstore.EventFilter) { + if len(filters) == 0 { + return + } + + stmt.WriteString(" AND ") + if len(filters) > 1 { + stmt.WriteRune('(') + } + + for i, filter := range filters { + if i > 0 { + stmt.WriteString(" OR ") + } + writeEventFilter(stmt, filter) + } + + if len(filters) > 1 { + stmt.WriteRune(')') + } +} + +func writeEventFilter(stmt *database.Statement, filter *eventstore.EventFilter) { + conditions := definedConditions([]*condition{ + {column: "event_type", condition: filter.Types()}, + {column: "created_at", condition: filter.CreatedAt()}, + {column: "sequence", condition: filter.Sequence()}, + {column: "revision", condition: filter.Revision()}, + {column: "creator", condition: filter.Creators()}, + }) + + if len(conditions) > 1 { + stmt.WriteRune('(') + } + + writeConditions( + stmt, + conditions, + " AND ", + ) + + if len(conditions) > 1 { + stmt.WriteRune(')') + } +} + +type condition struct { + column string + condition database.Condition +} + +func writeConditions(stmt *database.Statement, conditions []*condition, sep string) { + var i int + for _, cond := range conditions { + if i > 0 { + stmt.WriteString(sep) + } + cond.condition.Write(stmt, cond.column) + i++ + } +} + +func definedConditions(conditions []*condition) []*condition { + return slices.DeleteFunc(conditions, func(cond *condition) bool { + return cond.condition == nil + }) +} + +func writeOrdering(stmt *database.Statement, descending bool) { + stmt.WriteString(" ORDER BY position") + if descending { + stmt.WriteString(" DESC") + } + + stmt.WriteString(", in_tx_order") + if descending { + stmt.WriteString(" DESC") + } +} diff --git a/internal/v2/eventstore/postgres/query_test.go 
b/internal/v2/eventstore/postgres/query_test.go new file mode 100644 index 0000000000..c6e2c6f8a3 --- /dev/null +++ b/internal/v2/eventstore/postgres/query_test.go @@ -0,0 +1,1380 @@ +package postgres + +import ( + "context" + "database/sql/driver" + "errors" + "reflect" + "testing" + "time" + + "github.com/zitadel/zitadel/internal/v2/database" + "github.com/zitadel/zitadel/internal/v2/database/mock" + "github.com/zitadel/zitadel/internal/v2/eventstore" +) + +func Test_writeOrdering(t *testing.T) { + type args struct { + descending bool + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "asc", + args: args{ + descending: false, + }, + want: wantQuery{ + query: " ORDER BY position, in_tx_order", + }, + }, + { + name: "desc", + args: args{ + descending: true, + }, + want: wantQuery{ + query: " ORDER BY position DESC, in_tx_order DESC", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + writeOrdering(&stmt, tt.args.descending) + assertQuery(t, &stmt, tt.want) + }) + } +} + +func Test_writeConditionsIfSet(t *testing.T) { + type args struct { + conditions []*condition + sep string + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "no condition", + args: args{ + conditions: []*condition{}, + sep: " AND ", + }, + want: wantQuery{ + query: "", + args: []any{}, + }, + }, + { + name: "1 condition set", + args: args{ + conditions: []*condition{ + {column: "column", condition: database.NewTextEqual("asdf")}, + }, + sep: " AND ", + }, + want: wantQuery{ + query: "column = $1", + args: []any{"asdf"}, + }, + }, + { + name: "multiple conditions set", + args: args{ + conditions: []*condition{ + {column: "column1", condition: database.NewTextEqual("asdf")}, + {column: "column2", condition: database.NewNumberAtLeast(12)}, + {column: "column3", condition: database.NewNumberBetween(1, 100)}, + }, + sep: " AND ", + }, + want: wantQuery{ + query: "column1 = 
$1 AND column2 >= $2 AND column3 >= $3 AND column3 <= $4", + args: []any{"asdf", 12, 1, 100}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + writeConditions(&stmt, tt.args.conditions, tt.args.sep) + assertQuery(t, &stmt, tt.want) + }) + } +} + +func Test_writeEventFilter(t *testing.T) { + now := time.Now() + type args struct { + filter *eventstore.EventFilter + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "no filters", + args: args{ + filter: &eventstore.EventFilter{}, + }, + want: wantQuery{ + query: "", + args: []any{}, + }, + }, + { + name: "event_type", + args: args{ + filter: eventstore.NewEventFilter( + eventstore.SetEventType("user.added"), + ), + }, + want: wantQuery{ + query: "event_type = $1", + args: []any{"user.added"}, + }, + }, + { + name: "created_at", + args: args{ + filter: eventstore.NewEventFilter( + eventstore.EventCreatedAtEquals(now), + ), + }, + want: wantQuery{ + query: "created_at = $1", + args: []any{now}, + }, + }, + { + name: "created_at between", + args: args{ + filter: eventstore.NewEventFilter( + eventstore.EventCreatedAtBetween(now, now.Add(time.Second)), + ), + }, + want: wantQuery{ + query: "created_at >= $1 AND created_at <= $2", + args: []any{now, now.Add(time.Second)}, + }, + }, + { + name: "sequence", + args: args{ + filter: eventstore.NewEventFilter( + eventstore.EventSequenceEquals(100), + ), + }, + want: wantQuery{ + query: "sequence = $1", + args: []any{uint32(100)}, + }, + }, + { + name: "sequence between", + args: args{ + filter: eventstore.NewEventFilter( + eventstore.EventSequenceBetween(0, 10), + ), + }, + want: wantQuery{ + query: "sequence >= $1 AND sequence <= $2", + args: []any{uint32(0), uint32(10)}, + }, + }, + { + name: "revision", + args: args{ + filter: eventstore.NewEventFilter( + eventstore.EventRevisionAtLeast(2), + ), + }, + want: wantQuery{ + query: "revision >= $1", + args: []any{uint16(2)}, + }, + }, 
+ { + name: "creator", + args: args{ + filter: eventstore.NewEventFilter( + eventstore.EventCreatorsEqual("user-123"), + ), + }, + want: wantQuery{ + query: "creator = $1", + args: []any{"user-123"}, + }, + }, + { + name: "all", + args: args{ + filter: eventstore.NewEventFilter( + eventstore.SetEventType("user.added"), + eventstore.EventCreatedAtAtLeast(now), + eventstore.EventSequenceGreater(10), + eventstore.EventRevisionEquals(1), + eventstore.EventCreatorsEqual("user-123"), + ), + }, + want: wantQuery{ + query: "(event_type = $1 AND created_at >= $2 AND sequence > $3 AND revision = $4 AND creator = $5)", + args: []any{"user.added", now, uint32(10), uint16(1), "user-123"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + writeEventFilter(&stmt, tt.args.filter) + assertQuery(t, &stmt, tt.want) + }) + } +} + +func Test_writeEventFilters(t *testing.T) { + type args struct { + filters []*eventstore.EventFilter + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "no filters", + args: args{}, + want: wantQuery{ + query: "", + args: []any{}, + }, + }, + { + name: "1 filter", + args: args{ + filters: []*eventstore.EventFilter{ + eventstore.NewEventFilter( + eventstore.SetEventType("user.added"), + ), + }, + }, + want: wantQuery{ + query: " AND event_type = $1", + args: []any{"user.added"}, + }, + }, + { + name: "multiple filters", + args: args{ + filters: []*eventstore.EventFilter{ + eventstore.NewEventFilter( + eventstore.SetEventType("user.added"), + ), + eventstore.NewEventFilter( + eventstore.SetEventType("org.added"), + eventstore.EventSequenceGreater(4), + ), + }, + }, + want: wantQuery{ + query: " AND (event_type = $1 OR (event_type = $2 AND sequence > $3))", + args: []any{"user.added", "org.added", uint32(4)}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + writeEventFilters(&stmt, tt.args.filters) + 
assertQuery(t, &stmt, tt.want) + }) + } +} + +func Test_writeAggregateFilter(t *testing.T) { + type args struct { + filter *eventstore.AggregateFilter + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "minimal", + args: args{ + filter: eventstore.NewAggregateFilter( + "user", + ), + }, + want: wantQuery{ + query: "aggregate_type = $1", + args: []any{"user"}, + }, + }, + { + name: "all on aggregate", + args: args{ + filter: eventstore.NewAggregateFilter( + "user", + eventstore.SetAggregateID("234"), + ), + }, + want: wantQuery{ + query: "(aggregate_type = $1 AND aggregate_id = $2)", + args: []any{"user", "234"}, + }, + }, + { + name: "1 event filter minimal aggregate", + args: args{ + filter: eventstore.NewAggregateFilter( + "user", + eventstore.AppendEvent( + eventstore.SetEventType("user.added"), + ), + ), + }, + want: wantQuery{ + query: "(aggregate_type = $1 AND event_type = $2)", + args: []any{"user", "user.added"}, + }, + }, + { + name: "1 event filter all aggregate", + args: args{ + filter: eventstore.NewAggregateFilter( + "user", + eventstore.SetAggregateID("123"), + eventstore.AppendEvent( + eventstore.SetEventType("user.added"), + ), + ), + }, + want: wantQuery{ + query: "(aggregate_type = $1 AND aggregate_id = $2 AND event_type = $3)", + args: []any{"user", "123", "user.added"}, + }, + }, + { + name: "1 event filter with multiple conditions all aggregate", + args: args{ + filter: eventstore.NewAggregateFilter( + "user", + eventstore.SetAggregateID("123"), + eventstore.AppendEvent( + eventstore.SetEventType("user.added"), + eventstore.EventSequenceGreater(1), + ), + ), + }, + want: wantQuery{ + query: "(aggregate_type = $1 AND aggregate_id = $2 AND (event_type = $3 AND sequence > $4))", + args: []any{"user", "123", "user.added", uint32(1)}, + }, + }, + { + name: "2 event filters all aggregate", + args: args{ + filter: eventstore.NewAggregateFilter( + "user", + eventstore.SetAggregateID("123"), + eventstore.AppendEvent( + 
eventstore.SetEventType("user.added"), + ), + eventstore.AppendEvent( + eventstore.EventSequenceGreater(1), + ), + ), + }, + want: wantQuery{ + query: "(aggregate_type = $1 AND aggregate_id = $2 AND (event_type = $3 OR sequence > $4))", + args: []any{"user", "123", "user.added", uint32(1)}, + }, + }, + { + name: "2 event filters with multiple conditions all aggregate", + args: args{ + filter: eventstore.NewAggregateFilter( + "user", + eventstore.SetAggregateID("123"), + eventstore.AppendEvents( + eventstore.NewEventFilter( + eventstore.SetEventType("user.added"), + eventstore.EventSequenceGreater(1), + ), + ), + eventstore.AppendEvent( + eventstore.SetEventType("user.changed"), + eventstore.EventSequenceGreater(4), + ), + ), + }, + want: wantQuery{ + query: "(aggregate_type = $1 AND aggregate_id = $2 AND ((event_type = $3 AND sequence > $4) OR (event_type = $5 AND sequence > $6)))", + args: []any{"user", "123", "user.added", uint32(1), "user.changed", uint32(4)}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + writeAggregateFilter(&stmt, tt.args.filter) + assertQuery(t, &stmt, tt.want) + }) + } +} + +func Test_writeAggregateFilters(t *testing.T) { + type args struct { + filters []*eventstore.AggregateFilter + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "no filters", + args: args{}, + want: wantQuery{ + query: "", + args: []any{}, + }, + }, + { + name: "1 filter", + args: args{ + filters: []*eventstore.AggregateFilter{ + eventstore.NewAggregateFilter("user"), + }, + }, + want: wantQuery{ + query: " AND aggregate_type = $1", + args: []any{"user"}, + }, + }, + { + name: "multiple filters", + args: args{ + filters: []*eventstore.AggregateFilter{ + eventstore.NewAggregateFilter("user"), + eventstore.NewAggregateFilter("org", + eventstore.AppendEvent( + eventstore.SetEventType("org.added"), + ), + ), + }, + }, + want: wantQuery{ + query: " AND (aggregate_type = $1 OR 
(aggregate_type = $2 AND event_type = $3))", + args: []any{"user", "org", "org.added"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + writeAggregateFilters(&stmt, tt.args.filters) + assertQuery(t, &stmt, tt.want) + }) + } +} + +func Test_writeFilter(t *testing.T) { + type args struct { + filter *eventstore.Filter + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "empty filters", + args: args{ + filter: eventstore.NewFilter(), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 ORDER BY position, in_tx_order", + args: []any{"i1"}, + }, + }, + { + name: "descending", + args: args{ + filter: eventstore.NewFilter( + eventstore.FilterPagination( + eventstore.Descending(), + ), + ), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 ORDER BY position DESC, in_tx_order DESC", + args: []any{"i1"}, + }, + }, + { + name: "database pagination", + args: args{ + filter: eventstore.NewFilter( + eventstore.FilterPagination( + eventstore.Limit(10), + eventstore.Offset(3), + ), + ), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 ORDER BY position, in_tx_order LIMIT $2 OFFSET $3", + args: []any{"i1", uint32(10), uint32(3)}, + }, + }, + { + name: "position pagination", + args: args{ + filter: eventstore.NewFilter( + eventstore.FilterPagination( + eventstore.PositionGreater(123.4, 0), + ), + ), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 AND position > $2 ORDER BY position, in_tx_order", + args: []any{"i1", 123.4}, + }, + }, + { + name: "position pagination between", + args: args{ + filter: eventstore.NewFilter( + eventstore.FilterPagination( + // eventstore.PositionGreater(123.4, 0), + // eventstore.PositionLess(125.4, 10), + eventstore.PositionBetween( + &eventstore.GlobalPosition{Position: 123.4}, + &eventstore.GlobalPosition{Position: 125.4, InPositionOrder: 10}, + ), + ), + ), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 AND 
((position = $2 AND in_tx_order < $3) OR position < $4) AND position > $5 ORDER BY position, in_tx_order", + args: []any{"i1", 125.4, uint32(10), 125.4, 123.4}, + // TODO: (adlerhurst) would require some refactoring to reuse existing args + // query: " WHERE instance_id = $1 AND position > $2 AND ((position = $3 AND in_tx_order < $4) OR position < $3) ORDER BY position, in_tx_order", + // args: []any{"i1", 123.4, 125.4, uint32(10)}, + }, + }, + { + name: "position and inPositionOrder pagination", + args: args{ + filter: eventstore.NewFilter( + eventstore.FilterPagination( + eventstore.PositionGreater(123.4, 12), + ), + ), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 AND ((position = $2 AND in_tx_order > $3) OR position > $4) ORDER BY position, in_tx_order", + args: []any{"i1", 123.4, uint32(12), 123.4}, + }, + }, + { + name: "pagination", + args: args{ + filter: eventstore.NewFilter( + eventstore.FilterPagination( + eventstore.Limit(10), + eventstore.Offset(3), + eventstore.PositionGreater(123.4, 12), + ), + ), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 AND ((position = $2 AND in_tx_order > $3) OR position > $4) ORDER BY position, in_tx_order LIMIT $5 OFFSET $6", + args: []any{"i1", 123.4, uint32(12), 123.4, uint32(10), uint32(3)}, + }, + }, + { + name: "aggregate and pagination", + args: args{ + filter: eventstore.NewFilter( + eventstore.FilterPagination( + eventstore.Limit(10), + eventstore.Offset(3), + eventstore.PositionGreater(123.4, 12), + ), + eventstore.AppendAggregateFilter("user"), + ), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 AND aggregate_type = $2 AND ((position = $3 AND in_tx_order > $4) OR position > $5) ORDER BY position, in_tx_order LIMIT $6 OFFSET $7", + args: []any{"i1", "user", 123.4, uint32(12), 123.4, uint32(10), uint32(3)}, + }, + }, + { + name: "aggregates and pagination", + args: args{ + filter: eventstore.NewFilter( + eventstore.FilterPagination( + eventstore.Limit(10), + eventstore.Offset(3), + 
eventstore.PositionGreater(123.4, 12), + ), + eventstore.AppendAggregateFilter("user"), + eventstore.AppendAggregateFilter( + "org", + eventstore.SetAggregateID("o1"), + ), + ), + }, + want: wantQuery{ + query: " WHERE instance_id = $1 AND (aggregate_type = $2 OR (aggregate_type = $3 AND aggregate_id = $4)) AND ((position = $5 AND in_tx_order > $6) OR position > $7) ORDER BY position, in_tx_order LIMIT $8 OFFSET $9", + args: []any{"i1", "user", "org", "o1", 123.4, uint32(12), 123.4, uint32(10), uint32(3)}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + eventstore.NewQuery("i1", nil, eventstore.AppendFilters(tt.args.filter)) + + writeFilter(&stmt, tt.args.filter) + assertQuery(t, &stmt, tt.want) + }) + } +} + +func Test_writeQuery(t *testing.T) { + type args struct { + query *eventstore.Query + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "empty filter", + args: args{ + query: eventstore.NewQuery( + "i1", + nil, + eventstore.AppendFilters( + eventstore.NewFilter(), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM ((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 ORDER BY position, in_tx_order)) ORDER BY position, in_tx_order`, + args: []any{"i1"}, + }, + }, + { + name: "1 filter", + args: args{ + query: eventstore.NewQuery( + "i1", + nil, + eventstore.AppendFilters( + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "user", + eventstore.AggregateIDs("a", "b"), + ), + ), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM 
((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND (aggregate_type = $2 AND aggregate_id = ANY($3)) ORDER BY position, in_tx_order)) ORDER BY position, in_tx_order`, + args: []any{"i1", "user", []string{"a", "b"}}, + }, + }, + { + name: "multiple filters", + args: args{ + query: eventstore.NewQuery( + "i1", + nil, + eventstore.AppendFilters( + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "user", + eventstore.AggregateIDs("a", "b"), + ), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "org", + eventstore.AppendEvent( + eventstore.SetEventType("org.added"), + ), + ), + ), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM ((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND (aggregate_type = $2 AND aggregate_id = ANY($3)) ORDER BY position, in_tx_order) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $4 AND (aggregate_type = $5 AND event_type = $6) ORDER BY position, in_tx_order)) ORDER BY position, in_tx_order`, + args: []any{"i1", "user", []string{"a", "b"}, "i1", "org", "org.added"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + writeQuery(&stmt, tt.args.query) + assertQuery(t, &stmt, tt.want) + }) + } +} + +func Test_writeQueryUse_examples(t *testing.T) { + type args struct { + query *eventstore.Query + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + 
name: "aggregate type", + args: args{ + query: eventstore.NewQuery( + "instance", + nil, + eventstore.AppendFilters( + eventstore.NewFilter( + eventstore.AppendAggregateFilter("aggregate"), + ), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM ((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 ORDER BY position, in_tx_order)) ORDER BY position, in_tx_order`, + args: []any{ + "instance", + "aggregate", + }, + }, + }, + { + name: "descending", + args: args{ + query: eventstore.NewQuery( + "instance", + nil, + eventstore.QueryPagination( + eventstore.Descending(), + ), + eventstore.AppendFilter( + eventstore.AppendAggregateFilter("aggregate"), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM ((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 ORDER BY position DESC, in_tx_order DESC)) ORDER BY position DESC, in_tx_order DESC`, + args: []any{ + "instance", + "aggregate", + }, + }, + }, + { + name: "multiple aggregates", + args: args{ + query: eventstore.NewQuery( + "instance", + nil, + eventstore.AppendFilters( + eventstore.NewFilter( + eventstore.AppendAggregateFilter("agg1"), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter("agg2"), + eventstore.AppendAggregateFilter("agg3"), + ), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, 
aggregate_type, aggregate_id, revision FROM ((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 ORDER BY position, in_tx_order) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $3 AND (aggregate_type = $4 OR aggregate_type = $5) ORDER BY position, in_tx_order)) ORDER BY position, in_tx_order`, + args: []any{ + "instance", + "agg1", + "instance", + "agg2", + "agg3", + }, + }, + }, + { + name: "multiple aggregates with ids", + args: args{ + query: eventstore.NewQuery( + "instance", + nil, + eventstore.AppendFilters( + eventstore.NewFilter( + eventstore.AppendAggregateFilter("agg1", eventstore.SetAggregateID("id")), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter("agg2", eventstore.SetAggregateID("id2")), + eventstore.AppendAggregateFilter("agg3"), + ), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM ((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND (aggregate_type = $2 AND aggregate_id = $3) ORDER BY position, in_tx_order) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $4 AND ((aggregate_type = $5 AND aggregate_id = $6) OR aggregate_type = $7) ORDER BY position, in_tx_order)) ORDER BY position, in_tx_order`, + args: []any{ + "instance", + "agg1", + "id", + "instance", + "agg2", + "id2", + 
"agg3", + }, + }, + }, + { + name: "multiple event queries and multiple filter in queries", + args: args{ + query: eventstore.NewQuery( + "instance", + nil, + eventstore.AppendFilter( + eventstore.AppendAggregateFilter( + "agg1", + eventstore.AggregateIDs("1", "2"), + ), + eventstore.AppendAggregateFilter( + "agg2", + eventstore.SetAggregateID("3"), + ), + eventstore.AppendAggregateFilter( + "agg3", + eventstore.SetAggregateID("3"), + ), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM ((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND ((aggregate_type = $2 AND aggregate_id = ANY($3)) OR (aggregate_type = $4 AND aggregate_id = $5) OR (aggregate_type = $6 AND aggregate_id = $7)) ORDER BY position, in_tx_order)) ORDER BY position, in_tx_order`, + args: []any{ + "instance", + "agg1", + []string{"1", "2"}, + "agg2", + "3", + "agg3", + "3", + }, + }, + }, + { + name: "milestones", + args: args{ + query: eventstore.NewQuery( + "instance", + nil, + eventstore.AppendFilters( + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "instance", + eventstore.AppendEvent( + eventstore.SetEventType("instance.added"), + ), + ), + eventstore.FilterPagination( + eventstore.Limit(1), + ), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "instance", + eventstore.AppendEvent( + eventstore.SetEventType("instance.removed"), + ), + ), + eventstore.FilterPagination( + eventstore.Limit(1), + ), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "instance", + eventstore.AppendEvent( + eventstore.SetEventType("instance.domain.primary.set"), + eventstore.EventCreatorsNotContains("", "SYSTEM"), + ), + ), + eventstore.FilterPagination( + eventstore.Limit(1), + ), + 
), + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "project", + eventstore.AppendEvent( + eventstore.SetEventType("project.added"), + eventstore.EventCreatorsNotContains("", "SYSTEM"), + ), + ), + eventstore.FilterPagination( + eventstore.Limit(1), + ), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "project", + eventstore.AppendEvent( + eventstore.EventCreatorsNotContains("", "SYSTEM"), + eventstore.SetEventType("project.application.added"), + ), + ), + eventstore.FilterPagination( + eventstore.Limit(1), + ), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "user", + eventstore.AppendEvent( + eventstore.SetEventType("user.token.added"), + ), + ), + eventstore.FilterPagination( + // used because we need to check for first login and an app which is not console + eventstore.PositionGreater(12, 4), + ), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "instance", + eventstore.AppendEvent( + eventstore.SetEventTypes( + "instance.idp.config.added", + "instance.idp.oauth.added", + "instance.idp.oidc.added", + "instance.idp.jwt.added", + "instance.idp.azure.added", + "instance.idp.github.added", + "instance.idp.github.enterprise.added", + "instance.idp.gitlab.added", + "instance.idp.gitlab.selfhosted.added", + "instance.idp.google.added", + "instance.idp.ldap.added", + "instance.idp.config.apple.added", + "instance.idp.saml.added", + ), + ), + ), + eventstore.AppendAggregateFilter( + "org", + eventstore.AppendEvent( + eventstore.SetEventTypes( + "org.idp.config.added", + "org.idp.oauth.added", + "org.idp.oidc.added", + "org.idp.jwt.added", + "org.idp.azure.added", + "org.idp.github.added", + "org.idp.github.enterprise.added", + "org.idp.gitlab.added", + "org.idp.gitlab.selfhosted.added", + "org.idp.google.added", + "org.idp.ldap.added", + "org.idp.config.apple.added", + "org.idp.saml.added", + ), + ), + ), + eventstore.FilterPagination( + eventstore.Limit(1), + ), + ), + eventstore.NewFilter( + 
eventstore.AppendAggregateFilter( + "instance", + eventstore.AppendEvent( + eventstore.SetEventType("instance.login.policy.idp.added"), + ), + ), + eventstore.AppendAggregateFilter( + "org", + eventstore.AppendEvent( + eventstore.SetEventType("org.login.policy.idp.added"), + ), + ), + eventstore.FilterPagination( + eventstore.Limit(1), + ), + ), + eventstore.NewFilter( + eventstore.AppendAggregateFilter( + "instance", + eventstore.AppendEvent( + eventstore.SetEventType("instance.smtp.config.added"), + eventstore.EventCreatorsNotContains("", "SYSTEM", ""), + ), + ), + eventstore.FilterPagination( + eventstore.Limit(1), + ), + ), + ), + ), + }, + want: wantQuery{ + query: `SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM ((SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND (aggregate_type = $2 AND event_type = $3) ORDER BY position, in_tx_order LIMIT $4) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $5 AND (aggregate_type = $6 AND event_type = $7) ORDER BY position, in_tx_order LIMIT $8) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $9 AND (aggregate_type = $10 AND (event_type = $11 AND NOT(creator = ANY($12)))) ORDER BY position, in_tx_order LIMIT $13) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $14 AND (aggregate_type = $15 AND (event_type = $16 AND 
NOT(creator = ANY($17)))) ORDER BY position, in_tx_order LIMIT $18) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $19 AND (aggregate_type = $20 AND (event_type = $21 AND NOT(creator = ANY($22)))) ORDER BY position, in_tx_order LIMIT $23) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $24 AND (aggregate_type = $25 AND event_type = $26) AND ((position = $27 AND in_tx_order > $28) OR position > $29) ORDER BY position, in_tx_order) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $30 AND ((aggregate_type = $31 AND event_type = ANY($32)) OR (aggregate_type = $33 AND event_type = ANY($34))) ORDER BY position, in_tx_order LIMIT $35) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $36 AND ((aggregate_type = $37 AND event_type = $38) OR (aggregate_type = $39 AND event_type = $40)) ORDER BY position, in_tx_order LIMIT $41) UNION ALL (SELECT created_at, event_type, "sequence", "position", in_tx_order, payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $42 AND (aggregate_type = $43 AND (event_type = $44 AND NOT(creator = ANY($45)))) ORDER BY position, in_tx_order LIMIT $46)) ORDER BY position, in_tx_order`, + args: []any{ + "instance", + "instance", + "instance.added", + uint32(1), + "instance", + "instance", + "instance.removed", + uint32(1), + "instance", + "instance", + 
"instance.domain.primary.set", + []string{"", "SYSTEM"}, + uint32(1), + "instance", + "project", + "project.added", + []string{"", "SYSTEM"}, + uint32(1), + "instance", + "project", + "project.application.added", + []string{"", "SYSTEM"}, + uint32(1), + "instance", + "user", + "user.token.added", + float64(12), + uint32(4), + float64(12), + "instance", + "instance", + []string{"instance.idp.config.added", "instance.idp.oauth.added", "instance.idp.oidc.added", "instance.idp.jwt.added", "instance.idp.azure.added", "instance.idp.github.added", "instance.idp.github.enterprise.added", "instance.idp.gitlab.added", "instance.idp.gitlab.selfhosted.added", "instance.idp.google.added", "instance.idp.ldap.added", "instance.idp.config.apple.added", "instance.idp.saml.added"}, + "org", + []string{"org.idp.config.added", "org.idp.oauth.added", "org.idp.oidc.added", "org.idp.jwt.added", "org.idp.azure.added", "org.idp.github.added", "org.idp.github.enterprise.added", "org.idp.gitlab.added", "org.idp.gitlab.selfhosted.added", "org.idp.google.added", "org.idp.ldap.added", "org.idp.config.apple.added", "org.idp.saml.added"}, + uint32(1), + "instance", + "instance", + "instance.login.policy.idp.added", + "org", + "org.login.policy.idp.added", + uint32(1), + "instance", + "instance", + "instance.smtp.config.added", + []string{"", "SYSTEM", ""}, + uint32(1), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var stmt database.Statement + writeQuery(&stmt, tt.args.query) + assertQuery(t, &stmt, tt.want) + }) + } +} + +type wantQuery struct { + query string + args []any +} + +func assertQuery(t *testing.T, stmt *database.Statement, want wantQuery) bool { + t.Helper() + ok := true + + defer func() { + if !ok { + t.Logf("generated statement: %s\n", stmt.Debug()) + } + }() + + got := stmt.String() + if got != want.query { + t.Errorf("unexpected query:\n want: %q\n got: %q", want.query, got) + ok = false + } + + if len(want.args) != len(stmt.Args()) { + 
t.Errorf("unexpected length of args, want: %d got: %d", len(want.args), len(stmt.Args())) + return false + } + + for i, arg := range want.args { + if !reflect.DeepEqual(arg, stmt.Args()[i]) { + t.Errorf("unexpected arg at %d, want %v got: %v", i, arg, stmt.Args()[i]) + ok = false + } + } + + return ok +} + +var _ eventstore.Reducer = (*testReducer)(nil) + +type testReducer struct { + expectedReduces int + reduceCount int + shouldErr bool +} + +// Reduce implements eventstore.Reducer. +func (r *testReducer) Reduce(events ...*eventstore.Event[eventstore.StoragePayload]) error { + if r == nil { + return nil + } + r.reduceCount++ + if r.shouldErr { + return errReduce + } + return nil +} + +func (r *testReducer) assert(t *testing.T) { + if r.expectedReduces == r.reduceCount { + return + } + + t.Errorf("unexpected reduces, want %d, got %d", r.expectedReduces, r.reduceCount) +} + +func Test_executeQuery(t *testing.T) { + type args struct { + values [][]driver.Value + reducer *testReducer + } + type want struct { + eventCount int + assertErr func(t *testing.T, err error) bool + } + tests := []struct { + name string + args args + want want + }{ + { + name: "no result", + args: args{ + values: [][]driver.Value{}, + reducer: &testReducer{}, + }, + want: want{ + eventCount: 0, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, nil) + if !is { + t.Errorf("no error expected got: %v", err) + } + return is + }, + }, + }, + { + name: "1 event without payload", + args: args{ + values: [][]driver.Value{ + { + time.Now(), + "event.type", + uint32(23), + float64(123), + uint32(0), + nil, + "gigi", + "owner", + "instance", + "aggregate.type", + "aggregate.id", + uint16(1), + }, + }, + reducer: &testReducer{ + expectedReduces: 1, + }, + }, + want: want{ + eventCount: 1, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, nil) + if !is { + t.Errorf("no error expected got: %v", err) + } + return is + }, + }, + }, + { + name: "1 event with 
payload", + args: args{ + values: [][]driver.Value{ + { + time.Now(), + "event.type", + uint32(23), + float64(123), + uint32(0), + []byte(`{"name": "gigi"}`), + "gigi", + "owner", + "instance", + "aggregate.type", + "aggregate.id", + uint16(1), + }, + }, + reducer: &testReducer{ + expectedReduces: 1, + }, + }, + want: want{ + eventCount: 1, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, nil) + if !is { + t.Errorf("no error expected got: %v", err) + } + return is + }, + }, + }, + { + name: "multiple events", + args: args{ + values: [][]driver.Value{ + { + time.Now(), + "event.type", + uint32(23), + float64(123), + uint32(0), + nil, + "gigi", + "owner", + "instance", + "aggregate.type", + "aggregate.id", + uint16(1), + }, + { + time.Now(), + "event.type", + uint32(24), + float64(124), + uint32(0), + []byte(`{"name": "gigi"}`), + "gigi", + "owner", + "instance", + "aggregate.type", + "aggregate.id", + uint16(1), + }, + }, + reducer: &testReducer{ + expectedReduces: 2, + }, + }, + want: want{ + eventCount: 2, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, nil) + if !is { + t.Errorf("no error expected got: %v", err) + } + return is + }, + }, + }, + { + name: "reduce error", + args: args{ + values: [][]driver.Value{ + { + time.Now(), + "event.type", + uint32(23), + float64(123), + uint32(0), + nil, + "gigi", + "owner", + "instance", + "aggregate.type", + "aggregate.id", + uint16(1), + }, + { + time.Now(), + "event.type", + uint32(24), + float64(124), + uint32(0), + []byte(`{"name": "gigi"}`), + "gigi", + "owner", + "instance", + "aggregate.type", + "aggregate.id", + uint16(1), + }, + }, + reducer: &testReducer{ + expectedReduces: 1, + shouldErr: true, + }, + }, + want: want{ + eventCount: 1, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errReduce) + if !is { + t.Errorf("no error expected got: %v", err) + } + return is + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + mockDB := mock.NewSQLMock(t, + mock.ExpectQuery( + "", + mock.WithQueryResult( + []string{"created_at", "event_type", "sequence", "position", "in_tx_order", "payload", "creator", "owner", "instance_id", "aggregate_type", "aggregate_id", "revision"}, + tt.args.values, + ), + ), + ) + gotEventCount, err := executeQuery(context.Background(), mockDB.DB, &database.Statement{}, tt.args.reducer) + tt.want.assertErr(t, err) + if gotEventCount != tt.want.eventCount { + t.Errorf("executeQuery() = %v, want %v", gotEventCount, tt.want.eventCount) + } + }) + } +} diff --git a/internal/v2/eventstore/postgres/storage.go b/internal/v2/eventstore/postgres/storage.go new file mode 100644 index 0000000000..d2bf2a1195 --- /dev/null +++ b/internal/v2/eventstore/postgres/storage.go @@ -0,0 +1,28 @@ +package postgres + +import ( + "context" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/v2/eventstore" +) + +var ( + _ eventstore.Pusher = (*Storage)(nil) + _ eventstore.Querier = (*Storage)(nil) +) + +type Storage struct { + client *database.DB +} + +func New(client *database.DB) *Storage { + return &Storage{ + client: client, + } +} + +// Health implements eventstore.Pusher. 
+func (s *Storage) Health(ctx context.Context) error { + return s.client.PingContext(ctx) +} diff --git a/internal/v2/eventstore/push.go b/internal/v2/eventstore/push.go new file mode 100644 index 0000000000..6260315b82 --- /dev/null +++ b/internal/v2/eventstore/push.go @@ -0,0 +1,190 @@ +package eventstore + +import ( + "context" + "database/sql" +) + +type Pusher interface { + healthier + // Push writes the intents to the storage + // if an intent implements [PushReducerIntent] [PushReducerIntent.Reduce] is called after + // the intent was stored + Push(ctx context.Context, intent *PushIntent) error +} + +func NewPushIntent(instance string, opts ...PushOpt) *PushIntent { + intent := &PushIntent{ + instance: instance, + } + + for _, opt := range opts { + opt(intent) + } + + return intent +} + +type PushIntent struct { + instance string + reducer Reducer + tx *sql.Tx + aggregates []*PushAggregate +} + +func (pi *PushIntent) Instance() string { + return pi.instance +} + +func (pi *PushIntent) Reduce(events ...*Event[StoragePayload]) error { + if pi.reducer == nil { + return nil + } + return pi.reducer.Reduce(events...) +} + +func (pi *PushIntent) Tx() *sql.Tx { + return pi.tx +} + +func (pi *PushIntent) Aggregates() []*PushAggregate { + return pi.aggregates +} + +type PushOpt func(pi *PushIntent) + +func PushReducer(reducer Reducer) PushOpt { + return func(pi *PushIntent) { + pi.reducer = reducer + } +} + +func PushTx(tx *sql.Tx) PushOpt { + return func(pi *PushIntent) { + pi.tx = tx + } +} + +func AppendAggregate(owner, typ, id string, opts ...PushAggregateOpt) PushOpt { + return AppendAggregates(NewPushAggregate(owner, typ, id, opts...)) +} + +func AppendAggregates(aggregates ...*PushAggregate) PushOpt { + return func(pi *PushIntent) { + for _, aggregate := range aggregates { + aggregate.parent = pi + } + pi.aggregates = append(pi.aggregates, aggregates...) 
+ } +} + +type PushAggregate struct { + parent *PushIntent + // typ of the aggregate + typ string + // id of the aggregate + id string + // owner of the aggregate + owner string + // Commands is an ordered list of changes on the aggregate + commands []Command + // CurrentSequence checks the current state of the aggregate. + // The following types match the current sequence of the aggregate as described: + // * nil or [SequenceIgnore]: Not relevant to add the commands + // * [SequenceMatches]: Must exactly match + // * [SequenceAtLeast]: Must be >= the given sequence + currentSequence CurrentSequence +} + +func NewPushAggregate(owner, typ, id string, opts ...PushAggregateOpt) *PushAggregate { + pa := &PushAggregate{ + typ: typ, + id: id, + owner: owner, + } + + for _, opt := range opts { + opt(pa) + } + + return pa +} + +func (pa *PushAggregate) Type() string { + return pa.typ +} + +func (pa *PushAggregate) ID() string { + return pa.id +} + +func (pa *PushAggregate) Owner() string { + return pa.owner +} + +func (pa *PushAggregate) Commands() []Command { + return pa.commands +} + +func (pa *PushAggregate) Aggregate() *Aggregate { + return &Aggregate{ + ID: pa.id, + Type: pa.typ, + Owner: pa.owner, + Instance: pa.parent.instance, + } +} + +func (pa *PushAggregate) CurrentSequence() CurrentSequence { + return pa.currentSequence +} + +type PushAggregateOpt func(pa *PushAggregate) + +func SetCurrentSequence(currentSequence CurrentSequence) PushAggregateOpt { + return func(pa *PushAggregate) { + pa.currentSequence = currentSequence + } +} + +func IgnoreCurrentSequence() PushAggregateOpt { + return func(pa *PushAggregate) { + pa.currentSequence = SequenceIgnore() + } +} + +func CurrentSequenceMatches(sequence uint32) PushAggregateOpt { + return func(pa *PushAggregate) { + pa.currentSequence = SequenceMatches(sequence) + } +} + +func CurrentSequenceAtLeast(sequence uint32) PushAggregateOpt { + return func(pa *PushAggregate) { + pa.currentSequence = SequenceAtLeast(sequence) 
+ } +} + +func AppendCommands(commands ...Command) PushAggregateOpt { + return func(pa *PushAggregate) { + pa.commands = append(pa.commands, commands...) + } +} + +type Command interface { + // Creator is the id of the user which created the action + Creator() string + // Type describes the action it's in the past (e.g. user.created) + Type() string + // Revision of the action + Revision() uint16 + // Payload returns the payload of the event. It represent the changed fields by the event + // valid types are: + // * nil: no payload + // * struct: which can be marshalled to json + // * pointer to struct: which can be marshalled to json + // * []byte: json marshalled data + Payload() any + // UniqueConstraints should be added for unique attributes of an event, if nil constraints will not be checked + UniqueConstraints() []*UniqueConstraint +} diff --git a/internal/v2/eventstore/query.go b/internal/v2/eventstore/query.go new file mode 100644 index 0000000000..0dd23ea898 --- /dev/null +++ b/internal/v2/eventstore/query.go @@ -0,0 +1,756 @@ +package eventstore + +import ( + "context" + "database/sql" + "errors" + "slices" + "time" + + "github.com/zitadel/zitadel/internal/v2/database" +) + +type Querier interface { + healthier + Query(ctx context.Context, query *Query) (eventCount int, err error) +} + +type Query struct { + instances *filter[[]string] + filters []*Filter + tx *sql.Tx + pagination *Pagination + reducer Reducer + // TODO: await push +} + +func (q *Query) Instance() database.Condition { + return q.instances.condition +} + +func (q *Query) Filters() []*Filter { + return q.filters +} + +func (q *Query) Tx() *sql.Tx { + return q.tx +} + +func (q *Query) Pagination() *Pagination { + q.ensurePagination() + return q.pagination +} + +func (q *Query) Reduce(events ...*Event[StoragePayload]) error { + return q.reducer.Reduce(events...) 
+} + +func NewQuery(instance string, reducer Reducer, opts ...QueryOpt) *Query { + query := &Query{ + reducer: reducer, + } + + for _, opt := range append([]QueryOpt{SetInstance(instance)}, opts...) { + opt(query) + } + + return query +} + +type QueryOpt func(q *Query) + +func SetInstance(instance string) QueryOpt { + return InstancesEqual(instance) +} + +func InstancesEqual(instances ...string) QueryOpt { + return func(q *Query) { + var cond database.Condition + switch len(instances) { + case 0: + return + case 1: + cond = database.NewTextEqual(instances[0]) + default: + cond = database.NewListEquals(instances...) + } + q.instances = &filter[[]string]{ + condition: cond, + value: &instances, + } + } +} + +func InstancesContains(instances ...string) QueryOpt { + return func(f *Query) { + var cond database.Condition + switch len(instances) { + case 0: + return + case 1: + cond = database.NewTextEqual(instances[0]) + default: + cond = database.NewListContains(instances...) + } + + f.instances = &filter[[]string]{ + condition: cond, + value: &instances, + } + } +} + +func InstancesNotContains(instances ...string) QueryOpt { + return func(f *Query) { + var cond database.Condition + switch len(instances) { + case 0: + return + case 1: + cond = database.NewTextUnequal(instances[0]) + default: + cond = database.NewListNotContains(instances...) 
+ } + f.instances = &filter[[]string]{ + condition: cond, + value: &instances, + } + } +} + +func SetQueryTx(tx *sql.Tx) QueryOpt { + return func(query *Query) { + query.tx = tx + } +} + +func QueryPagination(opts ...paginationOpt) QueryOpt { + return func(query *Query) { + query.ensurePagination() + + for _, opt := range opts { + opt(query.pagination) + } + } +} + +func (q *Query) ensurePagination() { + if q.pagination != nil { + return + } + q.pagination = new(Pagination) +} + +func AppendFilters(filters ...*Filter) QueryOpt { + return func(query *Query) { + for _, filter := range filters { + filter.parent = query + } + query.filters = append(query.filters, filters...) + } +} + +func SetFilters(filters ...*Filter) QueryOpt { + return func(query *Query) { + for _, filter := range filters { + filter.parent = query + } + query.filters = filters + } +} + +func AppendFilter(opts ...FilterOpt) QueryOpt { + return AppendFilters(NewFilter(opts...)) +} + +var ErrFilterMerge = errors.New("merge failed") + +type FilterCreator func() []*Filter + +func MergeFilters(filters ...[]*Filter) []*Filter { + // TODO: improve merge by checking fields of filters and merge filters if possible + // this will reduce cost of queries which do multiple filters + return slices.Concat(filters...) 
+} + +type Filter struct { + parent *Query + pagination *Pagination + + aggregateFilters []*AggregateFilter +} + +func (f *Filter) Parent() *Query { + return f.parent +} + +func (f *Filter) Pagination() *Pagination { + if f.pagination == nil { + return f.parent.Pagination() + } + return f.pagination +} + +func (f *Filter) AggregateFilters() []*AggregateFilter { + return f.aggregateFilters +} + +func NewFilter(opts ...FilterOpt) *Filter { + f := new(Filter) + + for _, opt := range opts { + opt(f) + } + + return f +} + +type FilterOpt func(f *Filter) + +func AppendAggregateFilter(typ string, opts ...AggregateFilterOpt) FilterOpt { + return AppendAggregateFilters(NewAggregateFilter(typ, opts...)) +} + +func AppendAggregateFilters(filters ...*AggregateFilter) FilterOpt { + return func(mf *Filter) { + mf.aggregateFilters = append(mf.aggregateFilters, filters...) + } +} + +func SetAggregateFilters(filters ...*AggregateFilter) FilterOpt { + return func(mf *Filter) { + mf.aggregateFilters = filters + } +} + +func FilterPagination(opts ...paginationOpt) FilterOpt { + return func(filter *Filter) { + filter.ensurePagination() + + for _, opt := range opts { + opt(filter.pagination) + } + } +} + +func (f *Filter) ensurePagination() { + if f.pagination != nil { + return + } + f.pagination = new(Pagination) +} + +func NewAggregateFilter(typ string, opts ...AggregateFilterOpt) *AggregateFilter { + filter := &AggregateFilter{ + typ: typ, + } + + for _, opt := range opts { + opt(filter) + } + + return filter +} + +type AggregateFilter struct { + typ string + ids []string + events []*EventFilter +} + +func (f *AggregateFilter) Type() *database.TextFilter[string] { + return database.NewTextEqual(f.typ) +} + +func (f *AggregateFilter) IDs() database.Condition { + if len(f.ids) == 0 { + return nil + } + if len(f.ids) == 1 { + return database.NewTextEqual(f.ids[0]) + } + + return database.NewListContains(f.ids...) 
+} + +func (f *AggregateFilter) Events() []*EventFilter { + return f.events +} + +type AggregateFilterOpt func(f *AggregateFilter) + +func SetAggregateID(id string) AggregateFilterOpt { + return func(filter *AggregateFilter) { + filter.ids = []string{id} + } +} + +func AppendAggregateIDs(ids ...string) AggregateFilterOpt { + return func(f *AggregateFilter) { + f.ids = append(f.ids, ids...) + } +} + +// AggregateIDs sets the given ids as search param +func AggregateIDs(ids ...string) AggregateFilterOpt { + return func(f *AggregateFilter) { + f.ids = ids + } +} + +func AppendEvent(opts ...EventFilterOpt) AggregateFilterOpt { + return AppendEvents(NewEventFilter(opts...)) +} + +func AppendEvents(events ...*EventFilter) AggregateFilterOpt { + return func(filter *AggregateFilter) { + filter.events = append(filter.events, events...) + } +} + +func SetEvents(events ...*EventFilter) AggregateFilterOpt { + return func(filter *AggregateFilter) { + filter.events = events + } +} + +func NewEventFilter(opts ...EventFilterOpt) *EventFilter { + filter := new(EventFilter) + + for _, opt := range opts { + opt(filter) + } + + return filter +} + +type EventFilter struct { + types []string + revision *filter[uint16] + createdAt *filter[time.Time] + sequence *filter[uint32] + creators *filter[[]string] +} + +type filter[T any] struct { + condition database.Condition + // the following fields are considered as one of + // you can either have value and max or value + min, max *T + value *T +} + +func (f *EventFilter) Types() database.Condition { + switch len(f.types) { + case 0: + return nil + case 1: + return database.NewTextEqual(f.types[0]) + default: + return database.NewListContains(f.types...) 
+ } +} + +func (f *EventFilter) Revision() database.Condition { + if f.revision == nil { + return nil + } + return f.revision.condition +} + +func (f *EventFilter) CreatedAt() database.Condition { + if f.createdAt == nil { + return nil + } + return f.createdAt.condition +} + +func (f *EventFilter) Sequence() database.Condition { + if f.sequence == nil { + return nil + } + return f.sequence.condition +} + +func (f *EventFilter) Creators() database.Condition { + if f.creators == nil { + return nil + } + return f.creators.condition +} + +type EventFilterOpt func(f *EventFilter) + +func SetEventType(typ string) EventFilterOpt { + return func(filter *EventFilter) { + filter.types = []string{typ} + } +} + +// SetEventTypes overwrites the currently set types +func SetEventTypes(types ...string) EventFilterOpt { + return func(filter *EventFilter) { + filter.types = types + } +} + +// AppendEventTypes appends the types to the currently set types +func AppendEventTypes(types ...string) EventFilterOpt { + return func(filter *EventFilter) { + filter.types = append(filter.types, types...)
+ } +} + +func EventRevisionEquals(revision uint16) EventFilterOpt { + return func(f *EventFilter) { + f.revision = &filter[uint16]{ + condition: database.NewNumberEquals(revision), + value: &revision, + } + } +} + +func EventRevisionAtLeast(revision uint16) EventFilterOpt { + return func(f *EventFilter) { + f.revision = &filter[uint16]{ + condition: database.NewNumberAtLeast(revision), + value: &revision, + } + } +} + +func EventRevisionGreater(revision uint16) EventFilterOpt { + return func(f *EventFilter) { + f.revision = &filter[uint16]{ + condition: database.NewNumberGreater(revision), + value: &revision, + } + } +} + +func EventRevisionAtMost(revision uint16) EventFilterOpt { + return func(f *EventFilter) { + f.revision = &filter[uint16]{ + condition: database.NewNumberAtMost(revision), + value: &revision, + } + } +} + +func EventRevisionLess(revision uint16) EventFilterOpt { + return func(f *EventFilter) { + f.revision = &filter[uint16]{ + condition: database.NewNumberLess(revision), + value: &revision, + } + } +} + +func EventRevisionBetween(min, max uint16) EventFilterOpt { + return func(f *EventFilter) { + f.revision = &filter[uint16]{ + condition: database.NewNumberBetween(min, max), + min: &min, + max: &max, + } + } +} + +func EventCreatedAtEquals(createdAt time.Time) EventFilterOpt { + return func(f *EventFilter) { + f.createdAt = &filter[time.Time]{ + condition: database.NewNumberEquals(createdAt), + value: &createdAt, + } + } +} + +func EventCreatedAtAtLeast(createdAt time.Time) EventFilterOpt { + return func(f *EventFilter) { + f.createdAt = &filter[time.Time]{ + condition: database.NewNumberAtLeast(createdAt), + value: &createdAt, + } + } +} + +func EventCreatedAtGreater(createdAt time.Time) EventFilterOpt { + return func(f *EventFilter) { + f.createdAt = &filter[time.Time]{ + condition: database.NewNumberGreater(createdAt), + value: &createdAt, + } + } +} + +func EventCreatedAtAtMost(createdAt time.Time) EventFilterOpt { + return func(f 
*EventFilter) { + f.createdAt = &filter[time.Time]{ + condition: database.NewNumberAtMost(createdAt), + value: &createdAt, + } + } +} + +func EventCreatedAtLess(createdAt time.Time) EventFilterOpt { + return func(f *EventFilter) { + f.createdAt = &filter[time.Time]{ + condition: database.NewNumberLess(createdAt), + value: &createdAt, + } + } +} + +func EventCreatedAtBetween(min, max time.Time) EventFilterOpt { + return func(f *EventFilter) { + f.createdAt = &filter[time.Time]{ + condition: database.NewNumberBetween(min, max), + min: &min, + max: &max, + } + } +} + +func EventSequenceEquals(sequence uint32) EventFilterOpt { + return func(f *EventFilter) { + f.sequence = &filter[uint32]{ + condition: database.NewNumberEquals(sequence), + value: &sequence, + } + } +} + +func EventSequenceAtLeast(sequence uint32) EventFilterOpt { + return func(f *EventFilter) { + f.sequence = &filter[uint32]{ + condition: database.NewNumberAtLeast(sequence), + value: &sequence, + } + } +} + +func EventSequenceGreater(sequence uint32) EventFilterOpt { + return func(f *EventFilter) { + f.sequence = &filter[uint32]{ + condition: database.NewNumberGreater(sequence), + value: &sequence, + } + } +} + +func EventSequenceAtMost(sequence uint32) EventFilterOpt { + return func(f *EventFilter) { + f.sequence = &filter[uint32]{ + condition: database.NewNumberAtMost(sequence), + value: &sequence, + } + } +} + +func EventSequenceLess(sequence uint32) EventFilterOpt { + return func(f *EventFilter) { + f.sequence = &filter[uint32]{ + condition: database.NewNumberLess(sequence), + value: &sequence, + } + } +} + +func EventSequenceBetween(min, max uint32) EventFilterOpt { + return func(f *EventFilter) { + f.sequence = &filter[uint32]{ + condition: database.NewNumberBetween(min, max), + min: &min, + max: &max, + } + } +} + +func EventCreatorsEqual(creators ...string) EventFilterOpt { + return func(f *EventFilter) { + var cond database.Condition + switch len(creators) { + case 0: + return + case 1: + cond 
= database.NewTextEqual(creators[0]) + default: + cond = database.NewListEquals(creators...) + } + f.creators = &filter[[]string]{ + condition: cond, + value: &creators, + } + } +} + +func EventCreatorsContains(creators ...string) EventFilterOpt { + return func(f *EventFilter) { + var cond database.Condition + switch len(creators) { + case 0: + return + case 1: + cond = database.NewTextEqual(creators[0]) + default: + cond = database.NewListContains(creators...) + } + + f.creators = &filter[[]string]{ + condition: cond, + value: &creators, + } + } +} + +func EventCreatorsNotContains(creators ...string) EventFilterOpt { + return func(f *EventFilter) { + var cond database.Condition + switch len(creators) { + case 0: + return + case 1: + cond = database.NewTextUnequal(creators[0]) + default: + cond = database.NewListNotContains(creators...) + } + f.creators = &filter[[]string]{ + condition: cond, + value: &creators, + } + } +} + +func Limit(limit uint32) paginationOpt { + return func(p *Pagination) { + p.ensurePagination() + + p.pagination.Limit = limit + } +} + +func Offset(offset uint32) paginationOpt { + return func(p *Pagination) { + p.ensurePagination() + + p.pagination.Offset = offset + } +} + +type PositionCondition struct { + min, max *GlobalPosition +} + +func (pc *PositionCondition) Max() *GlobalPosition { + if pc == nil || pc.max == nil { + return nil + } + max := *pc.max + return &max +} + +func (pc *PositionCondition) Min() *GlobalPosition { + if pc == nil || pc.min == nil { + return nil + } + min := *pc.min + return &min +} + +// PositionGreater prepares the condition as follows +// if inPositionOrder is set: position = AND in_tx_order > OR or position > +// if inPositionOrder is NOT set: position > +func PositionGreater(position float64, inPositionOrder uint32) paginationOpt { + return func(p *Pagination) { + p.ensurePosition() + p.position.min = &GlobalPosition{ + Position: position, + InPositionOrder: inPositionOrder, + } + } +} + +// 
GlobalPositionGreater prepares the condition as follows +// if inPositionOrder is set: position = AND in_tx_order > OR position > +// if inPositionOrder is NOT set: position > +func GlobalPositionGreater(position *GlobalPosition) paginationOpt { + return PositionGreater(position.Position, position.InPositionOrder) +} + +// PositionLess prepares the condition as follows +// if inPositionOrder is set: position = AND in_tx_order < OR position < +// if inPositionOrder is NOT set: position < +func PositionLess(position float64, inPositionOrder uint32) paginationOpt { + return func(p *Pagination) { + p.ensurePosition() + p.position.max = &GlobalPosition{ + Position: position, + InPositionOrder: inPositionOrder, + } + } +} + +func PositionBetween(min, max *GlobalPosition) paginationOpt { + return func(p *Pagination) { + GlobalPositionGreater(min)(p) + GlobalPositionLess(max)(p) + } +} + +// GlobalPositionLess prepares the condition as follows +// if inPositionOrder is set: position = AND in_tx_order < OR position < +// if inPositionOrder is NOT set: position < +func GlobalPositionLess(position *GlobalPosition) paginationOpt { + return PositionLess(position.Position, position.InPositionOrder) +} + +type Pagination struct { + pagination *database.Pagination + position *PositionCondition + + desc bool +} + +type paginationOpt func(*Pagination) + +func (p *Pagination) Pagination() *database.Pagination { + if p == nil { + return nil + } + return p.pagination +} + +func (p *Pagination) Position() *PositionCondition { + if p == nil { + return nil + } + return p.position +} + +func (p *Pagination) Desc() bool { + if p == nil { + return false + } + + return p.desc +} + +func (p *Pagination) ensurePagination() { + if p.pagination != nil { + return + } + p.pagination = new(database.Pagination) +} + +func (p *Pagination) ensurePosition() { + if p.position != nil { + return + } + p.position = new(PositionCondition) +} + +func Descending() paginationOpt { + return func(p 
*Pagination) { + p.desc = true + } +} diff --git a/internal/v2/eventstore/query_test.go b/internal/v2/eventstore/query_test.go new file mode 100644 index 0000000000..00c08914c1 --- /dev/null +++ b/internal/v2/eventstore/query_test.go @@ -0,0 +1,1063 @@ +package eventstore + +import ( + "database/sql" + "reflect" + "testing" + "time" + + "github.com/zitadel/zitadel/internal/v2/database" +) + +func TestPaginationOpt(t *testing.T) { + type args struct { + opts []paginationOpt + } + tests := []struct { + name string + args args + want *Pagination + }{ + { + name: "desc", + args: args{ + opts: []paginationOpt{ + Descending(), + }, + }, + want: &Pagination{ + desc: true, + }, + }, + { + name: "limit", + args: args{ + opts: []paginationOpt{ + Limit(10), + }, + }, + want: &Pagination{ + pagination: &database.Pagination{ + Limit: 10, + }, + }, + }, + { + name: "offset", + args: args{ + opts: []paginationOpt{ + Offset(10), + }, + }, + want: &Pagination{ + pagination: &database.Pagination{ + Offset: 10, + }, + }, + }, + { + name: "limit and offset", + args: args{ + opts: []paginationOpt{ + Limit(10), + Offset(20), + }, + }, + want: &Pagination{ + pagination: &database.Pagination{ + Limit: 10, + Offset: 20, + }, + }, + }, + { + name: "global position greater", + args: args{ + opts: []paginationOpt{ + GlobalPositionGreater(&GlobalPosition{Position: 10}), + }, + }, + want: &Pagination{ + position: &PositionCondition{ + min: &GlobalPosition{ + Position: 10, + InPositionOrder: 0, + }, + }, + }, + }, + { + name: "position greater", + args: args{ + opts: []paginationOpt{ + PositionGreater(10, 0), + }, + }, + want: &Pagination{ + position: &PositionCondition{ + min: &GlobalPosition{ + Position: 10, + InPositionOrder: 0, + }, + }, + desc: false, + }, + }, + { + name: "position less", + args: args{ + opts: []paginationOpt{ + PositionLess(10, 12), + }, + }, + want: &Pagination{ + position: &PositionCondition{ + max: &GlobalPosition{ + Position: 10, + InPositionOrder: 12, + }, + }, + }, 
+ }, + { + name: "global position less", + args: args{ + opts: []paginationOpt{ + GlobalPositionLess(&GlobalPosition{Position: 12, InPositionOrder: 24}), + }, + }, + want: &Pagination{ + position: &PositionCondition{ + max: &GlobalPosition{ + Position: 12, + InPositionOrder: 24, + }, + }, + }, + }, + { + name: "position between", + args: args{ + opts: []paginationOpt{ + PositionBetween( + &GlobalPosition{10, 12}, + &GlobalPosition{20, 0}, + ), + }, + }, + want: &Pagination{ + position: &PositionCondition{ + min: &GlobalPosition{ + Position: 10, + InPositionOrder: 12, + }, + max: &GlobalPosition{ + Position: 20, + InPositionOrder: 0, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(Pagination) + for _, opt := range tt.args.opts { + opt(got) + } + + if tt.want.Desc() != got.Desc() { + t.Errorf("unexpected desc %v, want: %v", got.desc, tt.want.desc) + } + if !reflect.DeepEqual(tt.want.Pagination(), got.Pagination()) { + t.Errorf("unexpected pagination %v, want: %v", got.pagination, tt.want.pagination) + } + if !reflect.DeepEqual(tt.want.Position(), got.Position()) { + t.Errorf("unexpected position %v, want: %v", got.position, tt.want.position) + } + if !reflect.DeepEqual(tt.want.Position().Max(), got.Position().Max()) { + t.Errorf("unexpected position.max %v, want: %v", got.Position().max, tt.want.Position().max) + } + if !reflect.DeepEqual(tt.want.Position().Min(), got.Position().Min()) { + t.Errorf("unexpected position.min %v, want: %v", got.Position().min, tt.want.Position().min) + } + }) + } +} + +func TestEventFilterOpt(t *testing.T) { + type args struct { + opts []EventFilterOpt + } + now := time.Now() + tests := []struct { + name string + args args + want *EventFilter + }{ + { + name: "EventType", + args: args{ + opts: []EventFilterOpt{ + SetEventType("test"), + SetEventType("test2"), + }, + }, + want: &EventFilter{ + types: []string{"test2"}, + }, + }, + { + name: "EventTypes", + args: args{ + opts: 
[]EventFilterOpt{ + SetEventTypes("a", "s"), + SetEventTypes("d", "f"), + }, + }, + want: &EventFilter{ + types: []string{"d", "f"}, + }, + }, + { + name: "AppendEventTypes", + args: args{ + opts: []EventFilterOpt{ + AppendEventTypes("a", "s"), + AppendEventTypes("d", "f"), + }, + }, + want: &EventFilter{ + types: []string{"a", "s", "d", "f"}, + }, + }, + { + name: "EventRevisionEquals", + args: args{ + opts: []EventFilterOpt{ + EventRevisionEquals(12), + }, + }, + want: &EventFilter{ + revision: &filter[uint16]{ + condition: database.NewNumberEquals[uint16](12), + value: toPtr(uint16(12)), + }, + }, + }, + { + name: "EventRevisionAtLeast", + args: args{ + opts: []EventFilterOpt{ + EventRevisionAtLeast(12), + }, + }, + want: &EventFilter{ + revision: &filter[uint16]{ + condition: database.NewNumberAtLeast[uint16](12), + value: toPtr(uint16(12)), + }, + }, + }, + { + name: "EventRevisionGreater", + args: args{ + opts: []EventFilterOpt{ + EventRevisionGreater(12), + }, + }, + want: &EventFilter{ + revision: &filter[uint16]{ + condition: database.NewNumberGreater[uint16](12), + value: toPtr(uint16(12)), + }, + }, + }, + { + name: "EventRevisionAtMost", + args: args{ + opts: []EventFilterOpt{ + EventRevisionAtMost(12), + }, + }, + want: &EventFilter{ + revision: &filter[uint16]{ + condition: database.NewNumberAtMost[uint16](12), + value: toPtr(uint16(12)), + }, + }, + }, + { + name: "EventRevisionLess", + args: args{ + opts: []EventFilterOpt{ + EventRevisionLess(12), + }, + }, + want: &EventFilter{ + revision: &filter[uint16]{ + condition: database.NewNumberLess[uint16](12), + value: toPtr(uint16(12)), + }, + }, + }, + { + name: "EventRevisionBetween", + args: args{ + opts: []EventFilterOpt{ + EventRevisionBetween(12, 20), + }, + }, + want: &EventFilter{ + revision: &filter[uint16]{ + condition: database.NewNumberBetween[uint16](12, 20), + min: toPtr(uint16(12)), + max: toPtr(uint16(20)), + }, + }, + }, + { + name: "EventCreatedAtEquals", + args: args{ + opts: 
[]EventFilterOpt{ + EventCreatedAtEquals(now), + }, + }, + want: &EventFilter{ + createdAt: &filter[time.Time]{ + condition: database.NewNumberEquals(now), + value: toPtr(now), + }, + }, + }, + { + name: "EventCreatedAtAtLeast", + args: args{ + opts: []EventFilterOpt{ + EventCreatedAtAtLeast(now), + }, + }, + want: &EventFilter{ + createdAt: &filter[time.Time]{ + condition: database.NewNumberAtLeast(now), + value: toPtr(now), + }, + }, + }, + { + name: "EventCreatedAtGreater", + args: args{ + opts: []EventFilterOpt{ + EventCreatedAtGreater(now), + }, + }, + want: &EventFilter{ + createdAt: &filter[time.Time]{ + condition: database.NewNumberGreater(now), + value: toPtr(now), + }, + }, + }, + { + name: "EventCreatedAtAtMost", + args: args{ + opts: []EventFilterOpt{ + EventCreatedAtAtMost(now), + }, + }, + want: &EventFilter{ + createdAt: &filter[time.Time]{ + condition: database.NewNumberAtMost(now), + value: toPtr(now), + }, + }, + }, + { + name: "EventCreatedAtLess", + args: args{ + opts: []EventFilterOpt{ + EventCreatedAtLess(now), + }, + }, + want: &EventFilter{ + createdAt: &filter[time.Time]{ + condition: database.NewNumberLess(now), + value: toPtr(now), + }, + }, + }, + { + name: "EventCreatedAtBetween", + args: args{ + opts: []EventFilterOpt{ + EventCreatedAtBetween(now, now.Add(1*time.Second)), + }, + }, + want: &EventFilter{ + createdAt: &filter[time.Time]{ + condition: database.NewNumberBetween(now, now.Add(1*time.Second)), + min: toPtr(now), + max: toPtr(now.Add(1 * time.Second)), + }, + }, + }, + { + name: "EventSequenceEquals", + args: args{ + opts: []EventFilterOpt{ + EventSequenceEquals(12), + }, + }, + want: &EventFilter{ + sequence: &filter[uint32]{ + condition: database.NewNumberEquals[uint32](12), + value: toPtr(uint32(12)), + }, + }, + }, + { + name: "EventSequenceAtLeast", + args: args{ + opts: []EventFilterOpt{ + EventSequenceAtLeast(12), + }, + }, + want: &EventFilter{ + sequence: &filter[uint32]{ + condition: 
database.NewNumberAtLeast[uint32](12), + value: toPtr(uint32(12)), + }, + }, + }, + { + name: "EventSequenceGreater", + args: args{ + opts: []EventFilterOpt{ + EventSequenceGreater(12), + }, + }, + want: &EventFilter{ + sequence: &filter[uint32]{ + condition: database.NewNumberGreater[uint32](12), + value: toPtr(uint32(12)), + }, + }, + }, + { + name: "EventSequenceAtMost", + args: args{ + opts: []EventFilterOpt{ + EventSequenceAtMost(12), + }, + }, + want: &EventFilter{ + sequence: &filter[uint32]{ + condition: database.NewNumberAtMost[uint32](12), + value: toPtr(uint32(12)), + }, + }, + }, + { + name: "EventSequenceLess", + args: args{ + opts: []EventFilterOpt{ + EventSequenceLess(12), + }, + }, + want: &EventFilter{ + sequence: &filter[uint32]{ + condition: database.NewNumberLess[uint32](12), + value: toPtr(uint32(12)), + }, + }, + }, + { + name: "EventSequenceBetween", + args: args{ + opts: []EventFilterOpt{ + EventSequenceBetween(12, 24), + }, + }, + want: &EventFilter{ + sequence: &filter[uint32]{ + condition: database.NewNumberBetween[uint32](12, 24), + min: toPtr(uint32(12)), + max: toPtr(uint32(24)), + }, + }, + }, + { + name: "EventCreatorsEqual", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsEqual("cr", "ea", "tor"), + }, + }, + want: &EventFilter{ + creators: &filter[[]string]{ + condition: database.NewListEquals("cr", "ea", "tor"), + value: toPtr([]string{"cr", "ea", "tor"}), + }, + }, + }, + { + name: "EventCreatorsEqual no params", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsEqual(), + }, + }, + want: &EventFilter{}, + }, + { + name: "EventCreatorsEqual one params", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsEqual("asdf"), + }, + }, + want: &EventFilter{ + creators: &filter[[]string]{ + condition: database.NewTextEqual("asdf"), + value: toPtr([]string{"asdf"}), + }, + }, + }, + { + name: "EventCreatorsContains", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsContains("cr", "ea", "tor"), + }, + }, + want: 
&EventFilter{ + creators: &filter[[]string]{ + condition: database.NewListContains("cr", "ea", "tor"), + value: toPtr([]string{"cr", "ea", "tor"}), + }, + }, + }, + { + name: "EventCreatorsContains no params", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsContains(), + }, + }, + want: &EventFilter{}, + }, + { + name: "EventCreatorsContains one params", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsContains("asdf"), + }, + }, + want: &EventFilter{ + creators: &filter[[]string]{ + condition: database.NewTextEqual("asdf"), + value: toPtr([]string{"asdf"}), + }, + }, + }, + { + name: "EventCreatorsNotContains", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsNotContains("cr", "ea", "tor"), + }, + }, + want: &EventFilter{ + creators: &filter[[]string]{ + condition: database.NewListNotContains("cr", "ea", "tor"), + value: toPtr([]string{"cr", "ea", "tor"}), + }, + }, + }, + { + name: "EventCreatorsNotContains no params", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsNotContains(), + }, + }, + want: &EventFilter{}, + }, + { + name: "EventCreatorsNotContains one params", + args: args{ + opts: []EventFilterOpt{ + EventCreatorsNotContains("asdf"), + }, + }, + want: &EventFilter{ + creators: &filter[[]string]{ + condition: database.NewTextUnequal("asdf"), + value: toPtr([]string{"asdf"}), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewEventFilter(tt.args.opts...) 
+ + if !reflect.DeepEqual(tt.want.Types(), got.Types()) { + t.Errorf("unexpected types %v, want: %v", got.types, tt.want.types) + } + if !reflect.DeepEqual(tt.want.Revision(), got.Revision()) { + t.Errorf("unexpected revision %v, want: %v", got.revision, tt.want.revision) + } + if !reflect.DeepEqual(tt.want.CreatedAt(), got.CreatedAt()) { + t.Errorf("unexpected createdAt %v, want: %v", got.createdAt, tt.want.createdAt) + } + if !reflect.DeepEqual(tt.want.Sequence(), got.Sequence()) { + t.Errorf("unexpected sequence %v, want: %v", got.sequence, tt.want.sequence) + } + if !reflect.DeepEqual(tt.want.Creators(), got.Creators()) { + t.Errorf("unexpected creators %v, want: %v", got.creators, tt.want.creators) + } + }) + } +} + +func TestAggregateFilter(t *testing.T) { + type args struct { + opts []AggregateFilterOpt + } + tests := []struct { + name string + args args + want *AggregateFilter + }{ + { + name: "AggregateID", + args: args{ + opts: []AggregateFilterOpt{ + SetAggregateID("asdf"), + }, + }, + want: &AggregateFilter{ + ids: []string{"asdf"}, + }, + }, + { + name: "AggregateIDs", + args: args{ + opts: []AggregateFilterOpt{ + AggregateIDs("a", "s"), + AggregateIDs("d", "f"), + }, + }, + want: &AggregateFilter{ + ids: []string{"d", "f"}, + }, + }, + { + name: "AggregateIDs", + args: args{ + opts: []AggregateFilterOpt{ + AppendAggregateIDs("a", "s"), + AppendAggregateIDs("d", "f"), + }, + }, + want: &AggregateFilter{ + ids: []string{"a", "s", "d", "f"}, + }, + }, + { + name: "AppendEvent", + args: args{ + opts: []AggregateFilterOpt{ + AppendEvent(AppendEventTypes("asdf")), + AppendEvent(AppendEventTypes("asdf")), + }, + }, + want: &AggregateFilter{ + events: make([]*EventFilter, 2), + }, + }, + { + name: "AppendEvents", + args: args{ + opts: []AggregateFilterOpt{ + AppendEvents(NewEventFilter()), + AppendEvents(NewEventFilter()), + }, + }, + want: &AggregateFilter{ + events: make([]*EventFilter, 2), + }, + }, + { + name: "Events", + args: args{ + opts: 
[]AggregateFilterOpt{ + SetEvents(NewEventFilter()), + SetEvents(NewEventFilter()), + }, + }, + want: &AggregateFilter{ + events: make([]*EventFilter, 1), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewAggregateFilter("", tt.args.opts...) + + if tt.want.typ != got.typ { + t.Errorf("unexpected typ %v, want: %v", got.typ, tt.want.typ) + } + if !reflect.DeepEqual(tt.want.Type(), got.Type()) { + t.Errorf("unexpected typ %v, want: %v", got.typ, tt.want.typ) + } + if !reflect.DeepEqual(tt.want.IDs(), got.IDs()) { + t.Errorf("unexpected ids %v, want: %v", got.ids, tt.want.ids) + } + if len(tt.want.Events()) != len(got.Events()) { + t.Errorf("unexpected length of events %v, want: %v", len(got.events), len(tt.want.events)) + } + }) + } +} + +func TestFilterOpt(t *testing.T) { + type args struct { + opts []FilterOpt + } + tests := []struct { + name string + args args + want *Filter + }{ + { + name: "limit 1", + args: args{ + opts: []FilterOpt{ + FilterPagination(Limit(10)), + FilterPagination(Limit(1)), + }, + }, + want: &Filter{ + pagination: &Pagination{ + pagination: &database.Pagination{ + Limit: 1, + }, + }, + }, + }, + { + name: "AppendAggregateFilter", + args: args{ + opts: []FilterOpt{ + AppendAggregateFilter("typ"), + AppendAggregateFilter("typ2"), + }, + }, + want: &Filter{ + aggregateFilters: make([]*AggregateFilter, 2), + }, + }, + { + name: "AppendAggregateFilters", + args: args{ + opts: []FilterOpt{ + AppendAggregateFilters(NewAggregateFilter("typ")), + AppendAggregateFilters(NewAggregateFilter("typ2")), + }, + }, + want: &Filter{ + aggregateFilters: make([]*AggregateFilter, 2), + }, + }, + { + name: "AggregateFilters", + args: args{ + opts: []FilterOpt{ + SetAggregateFilters(NewAggregateFilter("typ")), + SetAggregateFilters(NewAggregateFilter("typ2")), + }, + }, + want: &Filter{ + aggregateFilters: make([]*AggregateFilter, 1), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := 
NewFilter(tt.args.opts...) + parent := NewQuery("instance", nil) + got.parent = parent + tt.want.parent = parent + + if !reflect.DeepEqual(tt.want.Pagination(), got.Pagination()) { + t.Errorf("unexpected pagination %v, want: %v", got.pagination, tt.want.pagination) + } + if len(tt.want.AggregateFilters()) != len(got.AggregateFilters()) { + t.Errorf("unexpected length of aggregateFilters %v, want: %v", len(got.aggregateFilters), len(tt.want.aggregateFilters)) + } + }) + } +} + +func TestQueryOpt(t *testing.T) { + type args struct { + opts []QueryOpt + } + var tx sql.Tx + tests := []struct { + name string + args args + want *Query + }{ + { + name: "limit 1", + args: args{ + opts: []QueryOpt{ + QueryPagination(Limit(10)), + QueryPagination(Limit(1)), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance"), + value: toPtr([]string{"instance"}), + }, + pagination: &Pagination{ + pagination: &database.Pagination{ + Limit: 1, + }, + }, + }, + }, + { + name: "with tx", + args: args{ + opts: []QueryOpt{ + SetQueryTx(&tx), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance"), + value: toPtr([]string{"instance"}), + }, + tx: &tx, + }, + }, + { + name: "instance", + args: args{ + opts: []QueryOpt{ + SetInstance("instance2"), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance2"), + value: toPtr([]string{"instance2"}), + }, + }, + }, + { + name: "InstanceEqual no param", + args: args{ + opts: []QueryOpt{ + InstancesEqual(), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance"), + value: toPtr([]string{"instance"}), + }, + }, + }, + { + name: "InstanceEqual 1 param", + args: args{ + opts: []QueryOpt{ + InstancesEqual("instance2"), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance2"), + value: toPtr([]string{"instance2"}), + }, 
+ }, + }, + { + name: "InstanceEqual 2 params", + args: args{ + opts: []QueryOpt{ + InstancesEqual("instance2"), + InstancesEqual("inst", "ancestor"), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewListEquals("inst", "ancestor"), + value: toPtr([]string{"inst", "ancestor"}), + }, + }, + }, + { + name: "InstancesContains no param", + args: args{ + opts: []QueryOpt{ + InstancesContains(), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance"), + value: toPtr([]string{"instance"}), + }, + }, + }, + { + name: "InstancesContains 1 param", + args: args{ + opts: []QueryOpt{ + InstancesContains("instance2"), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance2"), + value: toPtr([]string{"instance2"}), + }, + }, + }, + { + name: "InstancesContains 2 params", + args: args{ + opts: []QueryOpt{ + InstancesContains("instance2"), + InstancesContains("inst", "ancestor"), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewListContains("inst", "ancestor"), + value: toPtr([]string{"inst", "ancestor"}), + }, + }, + }, + { + name: "InstancesNotContains no param", + args: args{ + opts: []QueryOpt{ + InstancesNotContains(), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance"), + value: toPtr([]string{"instance"}), + }, + }, + }, + { + name: "InstancesNotContains 1 param", + args: args{ + opts: []QueryOpt{ + InstancesNotContains("instance2"), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextUnequal("instance2"), + value: toPtr([]string{"instance2"}), + }, + }, + }, + { + name: "InstancesNotContains 2 params", + args: args{ + opts: []QueryOpt{ + InstancesNotContains("instance2"), + InstancesNotContains("inst", "ancestor"), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewListNotContains("inst", 
"ancestor"), + value: toPtr([]string{"inst", "ancestor"}), + }, + }, + }, + { + name: "AppendFilters", + args: args{ + opts: []QueryOpt{ + AppendFilters(NewFilter()), + AppendFilters(NewFilter()), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance"), + value: toPtr([]string{"instance"}), + }, + filters: make([]*Filter, 2), + }, + }, + { + name: "AppendFilter", + args: args{ + opts: []QueryOpt{ + AppendFilter(), + AppendFilter(), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance"), + value: toPtr([]string{"instance"}), + }, + filters: make([]*Filter, 2), + }, + }, + { + name: "Filter", + args: args{ + opts: []QueryOpt{ + SetFilters(NewFilter()), + SetFilters(NewFilter()), + }, + }, + want: &Query{ + instances: &filter[[]string]{ + condition: database.NewTextEqual("instance"), + value: toPtr([]string{"instance"}), + }, + filters: make([]*Filter, 1), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewQuery("instance", nil, tt.args.opts...) 
+ + if !reflect.DeepEqual(tt.want.Instance(), got.Instance()) { + t.Errorf("unexpected instances %v, want: %v", got.instances, tt.want.instances) + } + if len(tt.want.Filters()) != len(got.Filters()) { + t.Errorf("unexpected length of filters %v, want: %v", len(got.filters), len(tt.want.filters)) + } + if !reflect.DeepEqual(tt.want.Tx(), got.Tx()) { + t.Errorf("unexpected tx %v, want: %v", got.tx, tt.want.tx) + } + if !reflect.DeepEqual(tt.want.Pagination(), got.Pagination()) { + t.Errorf("unexpected pagination %v, want: %v", got.pagination, tt.want.pagination) + } + }) + } +} + +func toPtr[T any](value T) *T { + return &value +} diff --git a/internal/v2/eventstore/unique_constraint.go b/internal/v2/eventstore/unique_constraint.go new file mode 100644 index 0000000000..4486e19e5d --- /dev/null +++ b/internal/v2/eventstore/unique_constraint.go @@ -0,0 +1,80 @@ +package eventstore + +type UniqueConstraint struct { + // UniqueType is the table name for the unique constraint + UniqueType string + // UniqueField is the unique key + UniqueField string + // Action defines if unique constraint should be added or removed + Action UniqueConstraintAction + // ErrorMessage defines the translation file key for the error message + ErrorMessage string + // IsGlobal defines if the unique constraint is globally unique or just within a single instance + IsGlobal bool +} + +type UniqueConstraintAction int8 + +const ( + UniqueConstraintAdd UniqueConstraintAction = iota + UniqueConstraintRemove + UniqueConstraintInstanceRemove + + uniqueConstraintActionCount +) + +func (f UniqueConstraintAction) Valid() bool { + return f >= 0 && f < uniqueConstraintActionCount +} + +func NewAddEventUniqueConstraint( + uniqueType, + uniqueField, + errMessage string) *UniqueConstraint { + return &UniqueConstraint{ + UniqueType: uniqueType, + UniqueField: uniqueField, + ErrorMessage: errMessage, + Action: UniqueConstraintAdd, + } +} + +func NewRemoveUniqueConstraint( + uniqueType, + uniqueField string) 
*UniqueConstraint { + return &UniqueConstraint{ + UniqueType: uniqueType, + UniqueField: uniqueField, + Action: UniqueConstraintRemove, + } +} + +func NewRemoveInstanceUniqueConstraints() *UniqueConstraint { + return &UniqueConstraint{ + Action: UniqueConstraintInstanceRemove, + } +} + +func NewAddGlobalUniqueConstraint( + uniqueType, + uniqueField, + errMessage string) *UniqueConstraint { + return &UniqueConstraint{ + UniqueType: uniqueType, + UniqueField: uniqueField, + ErrorMessage: errMessage, + IsGlobal: true, + Action: UniqueConstraintAdd, + } +} + +func NewRemoveGlobalUniqueConstraint( + uniqueType, + uniqueField string) *UniqueConstraint { + return &UniqueConstraint{ + UniqueType: uniqueType, + UniqueField: uniqueField, + IsGlobal: true, + Action: UniqueConstraintRemove, + } +}