diff --git a/internal/eventstore/v2/repository/sql/crdb.go b/internal/eventstore/v2/repository/sql/crdb.go index 0a383c7c39..51818be239 100644 --- a/internal/eventstore/v2/repository/sql/crdb.go +++ b/internal/eventstore/v2/repository/sql/crdb.go @@ -80,7 +80,9 @@ const ( " CASE " + " WHEN NOT check_previous " + " THEN max_event_seq " + - " ELSE previous_sequence " + + " WHEN previous_sequence > 0" + + " THEN previous_sequence " + + " ELSE previous_sequence " + " END" + " ) " + " FROM input_event " + diff --git a/internal/eventstore/v2/repository/sql/crdb_test.go b/internal/eventstore/v2/repository/sql/crdb_test.go index ce3d02cb3e..cf66d1b9d0 100644 --- a/internal/eventstore/v2/repository/sql/crdb_test.go +++ b/internal/eventstore/v2/repository/sql/crdb_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/caos/zitadel/internal/eventstore/v2/repository" + "github.com/lib/pq" _ "github.com/lib/pq" ) @@ -262,61 +263,171 @@ func TestCRDB_columnName(t *testing.T) { } } -func TestCRDB_Push(t *testing.T) { +func TestCRDB_Push_OneAggregate(t *testing.T) { type args struct { events []*repository.Event } + type eventsRes struct { + pushedEventsCount int + aggType repository.AggregateType + aggID []string + } + type res struct { + wantErr bool + eventsRes eventsRes + } tests := []struct { - name string - args args - wantErr bool + name string + args args + res res }{ { name: "push no events", args: args{ events: []*repository.Event{}, }, - wantErr: false, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 0, + aggID: []string{"0"}, + aggType: repository.AggregateType(t.Name()), + }, + }, }, { name: "push 1 event with check previous", args: args{ events: []*repository.Event{ - { - // AggregateID: t.Name(), - AggregateType: "test", - CheckPreviousSequence: true, - EditorService: "svc", - EditorUser: "user", - PreviousEvent: nil, - PreviousSequence: 0, - ResourceOwner: "ro", - Type: "test.created", - Version: "v1", - }, + generateEvent(t, "1", true, 0), }, }, - wantErr: false, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 1, + aggID: []string{"1"}, + aggType: repository.AggregateType(t.Name()), + }}, }, { - name: "push 1 event with check previous wrong sequence", + name: "fail push 1 event with check previous wrong sequence", args: args{ events: []*repository.Event{ - { - // AggregateID: t.Name(), - AggregateType: "test", - CheckPreviousSequence: true, - EditorService: "svc", - EditorUser: "user", - PreviousEvent: nil, - PreviousSequence: 5, - ResourceOwner: "ro", - Type: "test.created", - Version: "v1", - }, + generateEvent(t, "2", true, 5), + }, + }, + res: res{ + wantErr: true, + eventsRes: eventsRes{ + pushedEventsCount: 0, + aggID: []string{"2"}, + aggType: repository.AggregateType(t.Name()), + }, + }, + }, + { + name: "push 1 event without check previous", + args: args{ + events: []*repository.Event{ + generateEvent(t, "3", false, 0), + }, + }, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 1, + aggID: []string{"3"}, + aggType: repository.AggregateType(t.Name()), + }, + }, + }, + { + name: "push 1 event without check previous wrong sequence", + args: args{ + events: []*repository.Event{ + generateEvent(t, "4", false, 5), + }, + }, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 1, + aggID: []string{"4"}, + aggType: repository.AggregateType(t.Name()), + }, + }, + }, + { + name: "fail on push two events on agg without linking", + args: args{ + events: []*repository.Event{ + generateEvent(t, "5", true, 0), + generateEvent(t, "5", true, 0), + }, + }, + res: res{ + wantErr: true, + eventsRes: eventsRes{ + pushedEventsCount: 0, + aggID: []string{"5"}, + aggType: repository.AggregateType(t.Name()), + }, + }, + }, + { + name: "push two events on agg with linking", + args: args{ + events: linkEvents( + generateEvent(t, "6", true, 0), + generateEvent(t, "6", true, 0), + ), + }, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 2, + aggID: []string{"6"}, + aggType: repository.AggregateType(t.Name()), + }, + }, + }, + { + name: "push two events on agg with linking without check previous", + args: args{ + events: linkEvents( + generateEvent(t, "7", false, 0), + generateEvent(t, "7", false, 0), + ), + }, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 2, + aggID: []string{"7"}, + aggType: repository.AggregateType(t.Name()), + }, + }, + }, + { + name: "push two events on agg with linking mixed check previous", + args: args{ + events: linkEvents( + generateEvent(t, "8", false, 0), + generateEvent(t, "8", true, 0), + generateEvent(t, "8", false, 0), + generateEvent(t, "8", true, 0), + generateEvent(t, "8", true, 0), + ), + }, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 5, + aggID: []string{"8"}, + aggType: repository.AggregateType(t.Name()), }, }, - wantErr: true, }, } for _, tt := range tests { @@ -324,9 +435,146 @@ func TestCRDB_Push(t *testing.T) { db := &CRDB{ client: testCRDBClient, } - if err := db.Push(context.Background(), tt.args.events...); (err != nil) != tt.wantErr { - t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.wantErr) + if err := db.Push(context.Background(), tt.args.events...); (err != nil) != tt.res.wantErr { + t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr) + } + + countRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = $1 AND aggregate_id = ANY($2)", tt.res.eventsRes.aggType, pq.Array(tt.res.eventsRes.aggID)) + var count int + err := countRow.Scan(&count) + if err != nil { + t.Error("unable to query inserted rows: ", err) + return + } + if count != tt.res.eventsRes.pushedEventsCount { + t.Errorf("expected push count %d got %d", tt.res.eventsRes.pushedEventsCount, count) } }) } } + +func TestCRDB_Push_MultipleAggregate(t *testing.T) { + type args struct { + events []*repository.Event + } + type eventsRes struct { + pushedEventsCount int + aggType []repository.AggregateType + aggID []string + } + type res struct { + wantErr bool + eventsRes eventsRes + } + tests := []struct { + name string + args args + res res + }{ + { + name: "push two aggregates both check previous", + args: args{ + events: []*repository.Event{ + generateEvent(t, "100", true, 0), + generateEvent(t, "101", true, 0), + }, + }, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 2, + aggID: []string{"100", "101"}, + aggType: []repository.AggregateType{repository.AggregateType(t.Name())}, + }, + }, + }, + { + name: "push two aggregates both check previous multiple events", + args: args{ + events: combineEventLists( + linkEvents( + generateEvent(t, "102", true, 0), + generateEvent(t, "102", true, 0), + ), + linkEvents( + generateEvent(t, "103", true, 0), + generateEvent(t, "103", true, 0), + ), + ), + }, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 4, + aggID: []string{"102", "103"}, + aggType: []repository.AggregateType{repository.AggregateType(t.Name())}, + }, + }, + }, + { + name: "fail push linked events of different aggregates", + args: args{ + events: linkEvents( + generateEvent(t, "104", false, 0), + generateEvent(t, "104", false, 0), + ), + }, + res: res{ + wantErr: false, + eventsRes: eventsRes{ + pushedEventsCount: 0, + aggID: []string{"104"}, + aggType: []repository.AggregateType{repository.AggregateType(t.Name())}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &CRDB{ + client: testCRDBClient, + } + if err := db.Push(context.Background(), tt.args.events...); (err != nil) != tt.res.wantErr { + t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr) + } + + countRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2)", pq.Array(tt.res.eventsRes.aggType), pq.Array(tt.res.eventsRes.aggID)) + var count int + err := countRow.Scan(&count) + if err != nil { + t.Error("unable to query inserted rows: ", err) + return + } + if count != tt.res.eventsRes.pushedEventsCount { + t.Errorf("expected push count %d got %d", tt.res.eventsRes.pushedEventsCount, count) + } + }) + } +} + +func combineEventLists(firstList []*repository.Event, secondList []*repository.Event) []*repository.Event { + return append(firstList, secondList...) +} + +func linkEvents(events ...*repository.Event) []*repository.Event { + for i := 1; i < len(events); i++ { + events[i].PreviousEvent = events[i-1] + } + return events +} + +func generateEvent(t *testing.T, aggregateID string, checkPrevious bool, previousSeq uint64) *repository.Event { + t.Helper() + return &repository.Event{ + AggregateID: aggregateID, + AggregateType: repository.AggregateType(t.Name()), + CheckPreviousSequence: checkPrevious, + EditorService: "svc", + EditorUser: "user", + PreviousEvent: nil, + PreviousSequence: previousSeq, + ResourceOwner: "ro", + Type: "test.created", + Version: "v1", + } +}