mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:37:31 +00:00
testing with local cockroach started for tests and migrations
This commit is contained in:
@@ -87,20 +87,20 @@ func (es *Eventstore) aggregatesToEvents(aggregates []aggregater) ([]*repository
|
||||
for _, aggregate := range aggregates {
|
||||
var previousEvent *repository.Event
|
||||
for _, event := range aggregate.Events() {
|
||||
//TODO: map event.Data() into json
|
||||
var data []byte
|
||||
events = append(events, &repository.Event{
|
||||
AggregateID: aggregate.ID(),
|
||||
AggregateType: repository.AggregateType(aggregate.Type()),
|
||||
ResourceOwner: aggregate.ResourceOwner(),
|
||||
EditorService: event.EditorService(),
|
||||
EditorUser: event.EditorUser(),
|
||||
Type: repository.EventType(event.Type()),
|
||||
Version: repository.Version(aggregate.Version()),
|
||||
PreviousEvent: previousEvent,
|
||||
Data: event.Data(),
|
||||
AggregateID: aggregate.ID(),
|
||||
AggregateType: repository.AggregateType(aggregate.Type()),
|
||||
ResourceOwner: aggregate.ResourceOwner(),
|
||||
EditorService: event.EditorService(),
|
||||
EditorUser: event.EditorUser(),
|
||||
Type: repository.EventType(event.Type()),
|
||||
Version: repository.Version(aggregate.Version()),
|
||||
PreviousEvent: previousEvent,
|
||||
Data: data,
|
||||
PreviousSequence: event.PreviousSequence(),
|
||||
})
|
||||
if previousEvent != nil && event.CheckPrevious() {
|
||||
events[len(events)-1].PreviousSequence = event.PreviousSequence()
|
||||
}
|
||||
previousEvent = events[len(events)-1]
|
||||
}
|
||||
}
|
||||
|
@@ -1,17 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/cockroach/pkg/base"
|
||||
"github.com/cockroachdb/cockroach/pkg/server"
|
||||
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
|
||||
)
|
||||
|
||||
func TestBlub(t *testing.T) {
|
||||
s, db, kvDB := serverutils.StartServer(t, base.TestServerArgs{})
|
||||
defer s.Stopper().Stop()
|
||||
// If really needed, in tests that can depend on server, downcast to
|
||||
// server.TestServer:
|
||||
ts := s.(*server.TestServer)
|
||||
}
|
227
internal/eventstore/v2/repository/crdb.go
Normal file
227
internal/eventstore/v2/repository/crdb.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/caos/logging"
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
"github.com/cockroachdb/cockroach-go/v2/crdb"
|
||||
|
||||
//sql import for cockroach
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
crdbInsert = "WITH input_event ( " +
|
||||
" event_type, " +
|
||||
" aggregate_type, " +
|
||||
" aggregate_id, " +
|
||||
" aggregate_version, " +
|
||||
" creation_date, " +
|
||||
" event_data, " +
|
||||
" editor_user, " +
|
||||
" editor_service, " +
|
||||
" resource_owner, " +
|
||||
" previous_sequence, " +
|
||||
" check_previous, " +
|
||||
// variables below are calculated
|
||||
" max_event_seq, " +
|
||||
" event_count " +
|
||||
") " +
|
||||
" AS( " +
|
||||
" SELECT " +
|
||||
" $1::VARCHAR," +
|
||||
" $2::VARCHAR," +
|
||||
" $3::VARCHAR," +
|
||||
" $4::VARCHAR," +
|
||||
" COALESCE($5::TIMESTAMPTZ, NOW()), " +
|
||||
" $6::JSONB, " +
|
||||
" $7::VARCHAR, " +
|
||||
" $8::VARCHAR, " +
|
||||
" $9::VARCHAR, " +
|
||||
" $10::BIGINT, " +
|
||||
" $11::BOOLEAN," +
|
||||
" MAX(event_sequence) AS max_event_seq, " +
|
||||
" COUNT(*) AS event_count " +
|
||||
" FROM eventstore.events " +
|
||||
" WHERE " +
|
||||
" aggregate_type = $2::VARCHAR " +
|
||||
" AND aggregate_id = $3::VARCHAR " +
|
||||
") " +
|
||||
"INSERT INTO eventstore.events " +
|
||||
" ( " +
|
||||
" event_type, " +
|
||||
" aggregate_type," +
|
||||
" aggregate_id, " +
|
||||
" aggregate_version, " +
|
||||
" creation_date, " +
|
||||
" event_data, " +
|
||||
" editor_user, " +
|
||||
" editor_service, " +
|
||||
" resource_owner, " +
|
||||
" previous_sequence " +
|
||||
" ) " +
|
||||
" ( " +
|
||||
" SELECT " +
|
||||
" event_type, " +
|
||||
" aggregate_type," +
|
||||
" aggregate_id, " +
|
||||
" aggregate_version, " +
|
||||
" COALESCE(creation_date, NOW()), " +
|
||||
" event_data, " +
|
||||
" editor_user, " +
|
||||
" editor_service, " +
|
||||
" resource_owner, " +
|
||||
" ( " +
|
||||
" SELECT " +
|
||||
" CASE " +
|
||||
" WHEN NOT check_previous THEN " +
|
||||
" max_event_seq " +
|
||||
" ELSE " +
|
||||
" previous_sequence " +
|
||||
" END" +
|
||||
" ) " +
|
||||
" FROM input_event " +
|
||||
" WHERE EXISTS ( " +
|
||||
" SELECT " +
|
||||
" CASE " +
|
||||
" WHEN NOT check_previous THEN 1 " +
|
||||
" ELSE ( " +
|
||||
" SELECT 1 FROM input_event " +
|
||||
" WHERE max_event_seq = previous_sequence OR (previous_sequence IS NULL AND event_count = 0) " +
|
||||
" ) " +
|
||||
" END " +
|
||||
" ) " +
|
||||
" ) " +
|
||||
"RETURNING event_sequence, creation_date "
|
||||
)
|
||||
|
||||
type CRDB struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (db *CRDB) Health(ctx context.Context) error { return db.db.Ping() }
|
||||
|
||||
// Push adds all events to the eventstreams of the aggregates.
|
||||
// This call is transaction save. The transaction will be rolled back if one event fails
|
||||
func (db *CRDB) Push(ctx context.Context, events ...*Event) error {
|
||||
err := crdb.ExecuteTx(ctx, db.db, nil, func(tx *sql.Tx) error {
|
||||
stmt, err := tx.PrepareContext(ctx, crdbInsert)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
logging.Log("SQL-3to5p").WithError(err).Warn("prepare failed")
|
||||
return caos_errs.ThrowInternal(err, "SQL-OdXRE", "prepare failed")
|
||||
}
|
||||
for _, event := range events {
|
||||
previousSequence := event.PreviousSequence
|
||||
if event.PreviousEvent != nil {
|
||||
previousSequence = event.PreviousSequence
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx,
|
||||
event.Type,
|
||||
event.AggregateType,
|
||||
event.AggregateID,
|
||||
event.Version,
|
||||
event.CreationDate,
|
||||
event.Data,
|
||||
event.EditorUser,
|
||||
event.EditorService,
|
||||
event.ResourceOwner,
|
||||
previousSequence,
|
||||
event.CheckPreviousSequence,
|
||||
).Scan(&event.Sequence, &event.CreationDate)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
||||
logging.LogWithFields("SQL-IP3js",
|
||||
"aggregate", event.AggregateType,
|
||||
"aggregateId", event.AggregateID,
|
||||
"aggregateType", event.AggregateType,
|
||||
"eventType", event.Type).WithError(err).Info("query failed")
|
||||
return caos_errs.ThrowInternal(err, "SQL-SBP37", "unable to create event")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, &caos_errs.CaosError{}) {
|
||||
err = caos_errs.ThrowInternal(err, "SQL-DjgtG", "unable to store events")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Filter returns all events matching the given search query
|
||||
// func (db *CRDB) Filter(ctx context.Context, searchQuery *SearchQuery) (events []*Event, err error) {
|
||||
|
||||
// return events, nil
|
||||
// }
|
||||
|
||||
//LatestSequence returns the latests sequence found by the the search query
|
||||
func (db *CRDB) LatestSequence(ctx context.Context, queryFactory *SearchQuery) (uint64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (db *CRDB) prepareQuery(columns Columns) (string, func(s scanner, dest interface{}) error) {
|
||||
switch columns {
|
||||
case Columns_Max_Sequence:
|
||||
return "SELECT MAX(event_sequence) FROM eventstore.events", func(scan scanner, dest interface{}) (err error) {
|
||||
sequence, ok := dest.(*Sequence)
|
||||
if !ok {
|
||||
return caos_errs.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
|
||||
}
|
||||
err = scan(sequence)
|
||||
if err == nil || errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
return caos_errs.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
|
||||
}
|
||||
case Columns_Event:
|
||||
return selectStmt, func(row scanner, dest interface{}) (err error) {
|
||||
event, ok := dest.(*Event)
|
||||
if !ok {
|
||||
return caos_errs.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
|
||||
}
|
||||
var previousSequence Sequence
|
||||
data := make(Data, 0)
|
||||
|
||||
err = row(
|
||||
&event.CreationDate,
|
||||
&event.Type,
|
||||
&event.Sequence,
|
||||
&previousSequence,
|
||||
&data,
|
||||
&event.EditorService,
|
||||
&event.EditorUser,
|
||||
&event.ResourceOwner,
|
||||
&event.AggregateType,
|
||||
&event.AggregateID,
|
||||
&event.Version,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logging.Log("SQL-kn1Sw").WithError(err).Warn("unable to scan row")
|
||||
return caos_errs.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
|
||||
}
|
||||
|
||||
event.PreviousSequence = uint64(previousSequence)
|
||||
|
||||
event.Data = make([]byte, len(data))
|
||||
copy(event.Data, data)
|
||||
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func (db *CRDB) prepareFilter(filters []*Filter) string {
|
||||
filter := ""
|
||||
// for _, f := range filters{
|
||||
// f.
|
||||
// }
|
||||
return filter
|
||||
}
|
@@ -1,6 +1,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -11,6 +12,7 @@ type Event struct {
|
||||
|
||||
//Sequence is the sequence of the event
|
||||
Sequence uint64
|
||||
|
||||
//PreviousSequence is the sequence of the previous sequence
|
||||
// if it's 0 then it's the first event of this aggregate
|
||||
PreviousSequence uint64
|
||||
@@ -19,6 +21,10 @@ type Event struct {
|
||||
// it implements a linked list
|
||||
PreviousEvent *Event
|
||||
|
||||
//CheckPreviousSequence decides if the event can only be written
|
||||
// if event.PreviousSequence == max(event_sequence) of this aggregate
|
||||
CheckPreviousSequence bool
|
||||
|
||||
//CreationDate is the time the event is created
|
||||
// it's used for human readability.
|
||||
// Don't use it for event ordering,
|
||||
@@ -31,7 +37,7 @@ type Event struct {
|
||||
|
||||
//Data describe the changed fields (e.g. userName = "hodor")
|
||||
// data must always a pointer to a struct, a struct or a byte array containing json bytes
|
||||
Data interface{}
|
||||
Data []byte
|
||||
|
||||
//EditorService should be a unique identifier for the service which created the event
|
||||
// it's meant for maintainability
|
||||
@@ -54,6 +60,8 @@ type Event struct {
|
||||
// an aggregate can only be managed by one organisation
|
||||
// use the ID of the org
|
||||
ResourceOwner string
|
||||
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
//EventType is the description of the change
|
||||
|
62
internal/eventstore/v2/repository/filter.go
Normal file
62
internal/eventstore/v2/repository/filter.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/caos/logging"
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
)
|
||||
|
||||
type Querier interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// Filter returns all events matching the given search query
|
||||
func (db *CRDB) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFactory) (events []*Event, err error) {
|
||||
return filter(db.db, searchQuery)
|
||||
}
|
||||
|
||||
func filter(querier Querier, searchQuery *es_models.SearchQueryFactory) (events []*Event, err error) {
|
||||
query, limit, values, rowScanner := buildQuery(searchQuery)
|
||||
if query == "" {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
|
||||
}
|
||||
|
||||
rows, err := querier.Query(query, values...)
|
||||
if err != nil {
|
||||
logging.Log("SQL-HP3Uk").WithError(err).Info("query failed")
|
||||
return nil, errors.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
events = make([]*Event, 0, limit)
|
||||
|
||||
for rows.Next() {
|
||||
event := new(Event)
|
||||
err := rowScanner(rows.Scan, event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) {
|
||||
// query, _, values, rowScanner := buildQuery(queryFactory)
|
||||
// if query == "" {
|
||||
// return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
|
||||
// }
|
||||
// row := db.client.QueryRow(query, values...)
|
||||
// sequence := new(Sequence)
|
||||
// err := rowScanner(row.Scan, sequence)
|
||||
// if err != nil {
|
||||
// logging.Log("SQL-WsxTg").WithError(err).Info("query failed")
|
||||
// return 0, errors.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence")
|
||||
// }
|
||||
// return uint64(*sequence), nil
|
||||
// }
|
@@ -1,82 +0,0 @@
|
||||
package repository
|
||||
|
||||
import "context"
|
||||
|
||||
type InMemory struct {
|
||||
events []*Event
|
||||
}
|
||||
|
||||
func (repo *InMemory) Health(ctx context.Context) error { return nil }
|
||||
|
||||
// PushEvents adds all events of the given aggregates to the eventstreams of the aggregates.
|
||||
// This call is transaction save. The transaction will be rolled back if one event fails
|
||||
func (repo *InMemory) Push(ctx context.Context, events ...*Event) error {
|
||||
repo.events = append(repo.events, events...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter returns all events matching the given search query
|
||||
func (repo *InMemory) Filter(ctx context.Context, searchQuery *SearchQuery) (events []*Event, err error) {
|
||||
indexes := repo.filter(searchQuery)
|
||||
events = make([]*Event, len(indexes))
|
||||
for i, index := range indexes {
|
||||
events[i] = repo.events[index]
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (repo *InMemory) filter(query *SearchQuery) []int {
|
||||
foundIndex := make([]int, 0, query.Limit)
|
||||
events:
|
||||
for i, event := range repo.events {
|
||||
if query.Limit > 0 && uint64(len(foundIndex)) < query.Limit {
|
||||
return foundIndex
|
||||
}
|
||||
for _, filter := range query.Filters {
|
||||
var value interface{}
|
||||
switch filter.field {
|
||||
case Field_AggregateID:
|
||||
value = event.AggregateID
|
||||
case Field_EditorService:
|
||||
value = event.EditorService
|
||||
case Field_EventType:
|
||||
value = event.Type
|
||||
case Field_AggregateType:
|
||||
value = event.AggregateType
|
||||
case Field_EditorUser:
|
||||
value = event.EditorUser
|
||||
case Field_ResourceOwner:
|
||||
value = event.ResourceOwner
|
||||
case Field_LatestSequence:
|
||||
value = event.Sequence
|
||||
}
|
||||
switch filter.operation {
|
||||
case Operation_Equals:
|
||||
if filter.value == value {
|
||||
foundIndex = append(foundIndex, i)
|
||||
}
|
||||
case Operation_Greater:
|
||||
fallthrough
|
||||
case Operation_Less:
|
||||
|
||||
return nil
|
||||
case Operation_In:
|
||||
values := filter.Value().([]interface{})
|
||||
for _, val := range values {
|
||||
if val == value {
|
||||
foundIndex = append(foundIndex, i)
|
||||
continue events
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return foundIndex
|
||||
}
|
||||
|
||||
//LatestSequence returns the latests sequence found by the the search query
|
||||
func (repo *InMemory) LatestSequence(ctx context.Context, queryFactory *SearchQuery) (uint64, error) {
|
||||
return 0, nil
|
||||
}
|
163
internal/eventstore/v2/repository/local_crdb_test.go
Normal file
163
internal/eventstore/v2/repository/local_crdb_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caos/logging"
|
||||
"github.com/cockroachdb/cockroach-go/v2/testserver"
|
||||
)
|
||||
|
||||
var (
|
||||
migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach")
|
||||
db *sql.DB
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
ts, err := testserver.NewTestServer()
|
||||
if err != nil {
|
||||
logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db")
|
||||
}
|
||||
|
||||
db, err = sql.Open("postgres", ts.PGURL().String())
|
||||
if err != nil {
|
||||
logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
db.Close()
|
||||
ts.Stop()
|
||||
}()
|
||||
|
||||
if err = executeMigrations(); err != nil {
|
||||
logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed")
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestInsert(t *testing.T) {
|
||||
tx, _ := db.Begin()
|
||||
|
||||
var seq Sequence
|
||||
var d time.Time
|
||||
|
||||
row := tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
|
||||
err := row.Scan(&seq, &d)
|
||||
|
||||
row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(1), true)
|
||||
err = row.Scan(&seq, &d)
|
||||
|
||||
row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
|
||||
err = row.Scan(&seq, &d)
|
||||
|
||||
tx.Commit()
|
||||
|
||||
rows, err := db.Query("select * from eventstore.events order by event_sequence")
|
||||
defer rows.Close()
|
||||
fmt.Println(err)
|
||||
|
||||
fmt.Println(rows.Columns())
|
||||
for rows.Next() {
|
||||
i := make([]interface{}, 12)
|
||||
var id string
|
||||
rows.Scan(&id, &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7], &i[8], &i[9], &i[10], &i[11])
|
||||
i[0] = id
|
||||
|
||||
fmt.Println(i)
|
||||
}
|
||||
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
func executeMigrations() error {
|
||||
files, err := migrationFilePaths()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sort.Sort(files)
|
||||
for _, file := range files {
|
||||
migration, err := ioutil.ReadFile(string(file))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transactionInMigration := strings.Contains(string(migration), "BEGIN;")
|
||||
exec := db.Exec
|
||||
var tx *sql.Tx
|
||||
if !transactionInMigration {
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin file: %v || err: %w", file, err)
|
||||
}
|
||||
exec = tx.Exec
|
||||
}
|
||||
if _, err = exec(string(migration)); err != nil {
|
||||
return fmt.Errorf("exec file: %v || err: %w", file, err)
|
||||
}
|
||||
if !transactionInMigration {
|
||||
if err = tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit file: %v || err: %w", file, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type migrationPaths []string
|
||||
|
||||
type version struct {
|
||||
major int
|
||||
minor int
|
||||
}
|
||||
|
||||
func versionFromPath(s string) version {
|
||||
v := s[strings.Index(s, "/V")+2 : strings.Index(s, "__")]
|
||||
splitted := strings.Split(v, ".")
|
||||
res := version{}
|
||||
var err error
|
||||
if len(splitted) >= 1 {
|
||||
res.major, err = strconv.Atoi(splitted[0])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(splitted) >= 2 {
|
||||
res.minor, err = strconv.Atoi(splitted[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (a migrationPaths) Len() int { return len(a) }
|
||||
func (a migrationPaths) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a migrationPaths) Less(i, j int) bool {
|
||||
versionI := versionFromPath(a[i])
|
||||
versionJ := versionFromPath(a[j])
|
||||
|
||||
return versionI.major < versionJ.major ||
|
||||
(versionI.major == versionJ.major && versionI.minor < versionJ.minor)
|
||||
}
|
||||
|
||||
func migrationFilePaths() (migrationPaths, error) {
|
||||
files := make(migrationPaths, 0)
|
||||
err := filepath.Walk(migrationsPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".sql") {
|
||||
return err
|
||||
}
|
||||
files = append(files, path)
|
||||
return nil
|
||||
})
|
||||
return files, err
|
||||
}
|
199
internal/eventstore/v2/repository/query.go
Normal file
199
internal/eventstore/v2/repository/query.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/caos/logging"
|
||||
z_errors "github.com/caos/zitadel/internal/errors"
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
selectStmt = "SELECT" +
|
||||
" creation_date" +
|
||||
", event_type" +
|
||||
", event_sequence" +
|
||||
", previous_sequence" +
|
||||
", event_data" +
|
||||
", editor_service" +
|
||||
", editor_user" +
|
||||
", resource_owner" +
|
||||
", aggregate_type" +
|
||||
", aggregate_id" +
|
||||
", aggregate_version" +
|
||||
" FROM eventstore.events"
|
||||
)
|
||||
|
||||
func buildQuery(queryFactory *models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scanner, dest interface{}) error) {
|
||||
searchQuery, err := queryFactory.Build()
|
||||
if err != nil {
|
||||
logging.Log("SQL-cshKu").WithError(err).Warn("search query factory invalid")
|
||||
return "", 0, nil, nil
|
||||
}
|
||||
query, rowScanner = prepareColumns(searchQuery.Columns)
|
||||
where, values := prepareCondition(searchQuery.Filters)
|
||||
if where == "" || query == "" {
|
||||
return "", 0, nil, nil
|
||||
}
|
||||
query += where
|
||||
|
||||
if searchQuery.Columns != models.Columns_Max_Sequence {
|
||||
query += " ORDER BY event_sequence"
|
||||
if searchQuery.Desc {
|
||||
query += " DESC"
|
||||
}
|
||||
}
|
||||
|
||||
if searchQuery.Limit > 0 {
|
||||
values = append(values, searchQuery.Limit)
|
||||
query += " LIMIT ?"
|
||||
}
|
||||
|
||||
query = numberPlaceholder(query, "?", "$")
|
||||
|
||||
return query, searchQuery.Limit, values, rowScanner
|
||||
}
|
||||
|
||||
func prepareCondition(filters []*models.Filter) (clause string, values []interface{}) {
|
||||
values = make([]interface{}, len(filters))
|
||||
clauses := make([]string, len(filters))
|
||||
|
||||
if len(filters) == 0 {
|
||||
return clause, values
|
||||
}
|
||||
for i, filter := range filters {
|
||||
value := filter.GetValue()
|
||||
switch value.(type) {
|
||||
case []bool, []float64, []int64, []string, []models.AggregateType, []models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]models.AggregateType, *[]models.EventType:
|
||||
value = pq.Array(value)
|
||||
}
|
||||
|
||||
clauses[i] = getCondition(filter)
|
||||
if clauses[i] == "" {
|
||||
return "", nil
|
||||
}
|
||||
values[i] = value
|
||||
}
|
||||
return " WHERE " + strings.Join(clauses, " AND "), values
|
||||
}
|
||||
|
||||
type scanner func(dest ...interface{}) error
|
||||
|
||||
func prepareColumns(columns models.Columns) (string, func(s scanner, dest interface{}) error) {
|
||||
switch columns {
|
||||
case models.Columns_Max_Sequence:
|
||||
return "SELECT MAX(event_sequence) FROM eventstore.events", func(row scanner, dest interface{}) (err error) {
|
||||
sequence, ok := dest.(*Sequence)
|
||||
if !ok {
|
||||
return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
|
||||
}
|
||||
err = row(sequence)
|
||||
if err == nil || errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
|
||||
}
|
||||
case models.Columns_Event:
|
||||
return selectStmt, func(row scanner, dest interface{}) (err error) {
|
||||
event, ok := dest.(*models.Event)
|
||||
if !ok {
|
||||
return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
|
||||
}
|
||||
var previousSequence Sequence
|
||||
data := make(Data, 0)
|
||||
|
||||
err = row(
|
||||
&event.CreationDate,
|
||||
&event.Type,
|
||||
&event.Sequence,
|
||||
&previousSequence,
|
||||
&data,
|
||||
&event.EditorService,
|
||||
&event.EditorUser,
|
||||
&event.ResourceOwner,
|
||||
&event.AggregateType,
|
||||
&event.AggregateID,
|
||||
&event.AggregateVersion,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logging.Log("SQL-kn1Sw").WithError(err).Warn("unable to scan row")
|
||||
return z_errors.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
|
||||
}
|
||||
|
||||
event.PreviousSequence = uint64(previousSequence)
|
||||
|
||||
event.Data = make([]byte, len(data))
|
||||
copy(event.Data, data)
|
||||
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func numberPlaceholder(query, old, new string) string {
|
||||
for i, hasChanged := 1, true; hasChanged; i++ {
|
||||
newQuery := strings.Replace(query, old, new+strconv.Itoa(i), 1)
|
||||
hasChanged = query != newQuery
|
||||
query = newQuery
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func getCondition(filter *es_models.Filter) (condition string) {
|
||||
field := getField(filter.GetField())
|
||||
operation := getOperation(filter.GetOperation())
|
||||
if field == "" || operation == "" {
|
||||
return ""
|
||||
}
|
||||
format := getConditionFormat(filter.GetOperation())
|
||||
|
||||
return fmt.Sprintf(format, field, operation)
|
||||
}
|
||||
|
||||
func getConditionFormat(operation es_models.Operation) string {
|
||||
if operation == es_models.Operation_In {
|
||||
return "%s %s ANY(?)"
|
||||
}
|
||||
return "%s %s ?"
|
||||
}
|
||||
|
||||
func getField(field es_models.Field) string {
|
||||
switch field {
|
||||
case es_models.Field_AggregateID:
|
||||
return "aggregate_id"
|
||||
case es_models.Field_AggregateType:
|
||||
return "aggregate_type"
|
||||
case es_models.Field_LatestSequence:
|
||||
return "event_sequence"
|
||||
case es_models.Field_ResourceOwner:
|
||||
return "resource_owner"
|
||||
case es_models.Field_EditorService:
|
||||
return "editor_service"
|
||||
case es_models.Field_EditorUser:
|
||||
return "editor_user"
|
||||
case es_models.Field_EventType:
|
||||
return "event_type"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getOperation(operation es_models.Operation) string {
|
||||
switch operation {
|
||||
case es_models.Operation_Equals, es_models.Operation_In:
|
||||
return "="
|
||||
case es_models.Operation_Greater:
|
||||
return ">"
|
||||
case es_models.Operation_Less:
|
||||
return "<"
|
||||
}
|
||||
return ""
|
||||
}
|
486
internal/eventstore/v2/repository/query_test.go
Normal file
486
internal/eventstore/v2/repository/query_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func Test_numberPlaceholder(t *testing.T) {
|
||||
type args struct {
|
||||
query string
|
||||
old string
|
||||
new string
|
||||
}
|
||||
type res struct {
|
||||
query string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "no replaces",
|
||||
args: args{
|
||||
new: "$",
|
||||
old: "?",
|
||||
query: "SELECT * FROM eventstore.events",
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT * FROM eventstore.events",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two replaces",
|
||||
args: args{
|
||||
new: "$",
|
||||
old: "?",
|
||||
query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND LIMIT = ?",
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND LIMIT = $2",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := numberPlaceholder(tt.args.query, tt.args.old, tt.args.new); got != tt.res.query {
|
||||
t.Errorf("numberPlaceholder() = %v, want %v", got, tt.res.query)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getOperation(t *testing.T) {
|
||||
t.Run("all ops", func(t *testing.T) {
|
||||
for op, expected := range map[es_models.Operation]string{
|
||||
es_models.Operation_Equals: "=",
|
||||
es_models.Operation_In: "=",
|
||||
es_models.Operation_Greater: ">",
|
||||
es_models.Operation_Less: "<",
|
||||
es_models.Operation(-1): "",
|
||||
} {
|
||||
if got := getOperation(op); got != expected {
|
||||
t.Errorf("getOperation() = %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getField(t *testing.T) {
|
||||
t.Run("all fields", func(t *testing.T) {
|
||||
for field, expected := range map[es_models.Field]string{
|
||||
es_models.Field_AggregateType: "aggregate_type",
|
||||
es_models.Field_AggregateID: "aggregate_id",
|
||||
es_models.Field_LatestSequence: "event_sequence",
|
||||
es_models.Field_ResourceOwner: "resource_owner",
|
||||
es_models.Field_EditorService: "editor_service",
|
||||
es_models.Field_EditorUser: "editor_user",
|
||||
es_models.Field_EventType: "event_type",
|
||||
es_models.Field(-1): "",
|
||||
} {
|
||||
if got := getField(field); got != expected {
|
||||
t.Errorf("getField() = %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getConditionFormat(t *testing.T) {
|
||||
type args struct {
|
||||
operation es_models.Operation
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no in operation",
|
||||
args: args{
|
||||
operation: es_models.Operation_Equals,
|
||||
},
|
||||
want: "%s %s ?",
|
||||
},
|
||||
{
|
||||
name: "in operation",
|
||||
args: args{
|
||||
operation: es_models.Operation_In,
|
||||
},
|
||||
want: "%s %s ANY(?)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getConditionFormat(tt.args.operation); got != tt.want {
|
||||
t.Errorf("prepareConditionFormat() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getCondition(t *testing.T) {
|
||||
type args struct {
|
||||
filter *es_models.Filter
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "equals",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_AggregateID, "", es_models.Operation_Equals)},
|
||||
want: "aggregate_id = ?",
|
||||
},
|
||||
{
|
||||
name: "greater",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 0, es_models.Operation_Greater)},
|
||||
want: "event_sequence > ?",
|
||||
},
|
||||
{
|
||||
name: "less",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 5000, es_models.Operation_Less)},
|
||||
want: "event_sequence < ?",
|
||||
},
|
||||
{
|
||||
name: "in list",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation_In)},
|
||||
want: "aggregate_type = ANY(?)",
|
||||
},
|
||||
{
|
||||
name: "invalid operation",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid field",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation_Equals)},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid field and operation",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getCondition(tt.args.filter); got != tt.want {
|
||||
t.Errorf("getCondition() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prepareColumns(t *testing.T) {
|
||||
type args struct {
|
||||
columns models.Columns
|
||||
dest interface{}
|
||||
dbErr error
|
||||
}
|
||||
type res struct {
|
||||
query string
|
||||
dbRow []interface{}
|
||||
expected interface{}
|
||||
dbErr func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "invalid columns",
|
||||
args: args{columns: es_models.Columns(-1)},
|
||||
res: res{
|
||||
query: "",
|
||||
dbErr: func(err error) bool { return err == nil },
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "max column",
|
||||
args: args{
|
||||
columns: es_models.Columns_Max_Sequence,
|
||||
dest: new(Sequence),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT MAX(event_sequence) FROM eventstore.events",
|
||||
dbRow: []interface{}{Sequence(5)},
|
||||
expected: Sequence(5),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "max sequence wrong dest type",
|
||||
args: args{
|
||||
columns: es_models.Columns_Max_Sequence,
|
||||
dest: new(uint64),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT MAX(event_sequence) FROM eventstore.events",
|
||||
dbErr: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "event",
|
||||
args: args{
|
||||
columns: es_models.Columns_Event,
|
||||
dest: new(models.Event),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
|
||||
dbRow: []interface{}{time.Time{}, models.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", models.AggregateType("user"), "hodor", models.Version("")},
|
||||
expected: models.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "event wrong dest type",
|
||||
args: args{
|
||||
columns: es_models.Columns_Event,
|
||||
dest: new(uint64),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
|
||||
dbErr: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "event query error",
|
||||
args: args{
|
||||
columns: es_models.Columns_Event,
|
||||
dest: new(models.Event),
|
||||
dbErr: sql.ErrConnDone,
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
|
||||
dbErr: errors.IsInternal,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query, rowScanner := prepareColumns(tt.args.columns)
|
||||
if query != tt.res.query {
|
||||
t.Errorf("prepareColumns() got = %v, want %v", query, tt.res.query)
|
||||
}
|
||||
if tt.res.query == "" && rowScanner != nil {
|
||||
t.Errorf("row scanner should be nil")
|
||||
}
|
||||
if rowScanner == nil {
|
||||
return
|
||||
}
|
||||
err := rowScanner(prepareTestScan(tt.args.dbErr, tt.res.dbRow), tt.args.dest)
|
||||
if tt.res.dbErr != nil {
|
||||
if !tt.res.dbErr(err) {
|
||||
t.Errorf("wrong error type in rowScanner got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) {
|
||||
t.Errorf("unexpected result from rowScanner want: %v got: %v", tt.res.dbRow, tt.args.dest)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func prepareTestScan(err error, res []interface{}) scanner {
|
||||
return func(dests ...interface{}) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(dests) != len(res) {
|
||||
return errors.ThrowInvalidArgumentf(nil, "SQL-NML1q", "expected len %d got %d", len(res), len(dests))
|
||||
}
|
||||
for i, r := range res {
|
||||
reflect.ValueOf(dests[i]).Elem().Set(reflect.ValueOf(r))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prepareCondition(t *testing.T) {
|
||||
type args struct {
|
||||
filters []*models.Filter
|
||||
}
|
||||
type res struct {
|
||||
clause string
|
||||
values []interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "nil filters",
|
||||
args: args{
|
||||
filters: nil,
|
||||
},
|
||||
res: res{
|
||||
clause: "",
|
||||
values: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty filters",
|
||||
args: args{
|
||||
filters: []*es_models.Filter{},
|
||||
},
|
||||
res: res{
|
||||
clause: "",
|
||||
values: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid condition",
|
||||
args: args{
|
||||
filters: []*es_models.Filter{
|
||||
es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
clause: "",
|
||||
values: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array as condition value",
|
||||
args: args{
|
||||
filters: []*es_models.Filter{
|
||||
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
clause: " WHERE aggregate_type = ANY(?)",
|
||||
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple filters",
|
||||
args: args{
|
||||
filters: []*es_models.Filter{
|
||||
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
|
||||
es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
|
||||
es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?)",
|
||||
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotClause, gotValues := prepareCondition(tt.args.filters)
|
||||
if gotClause != tt.res.clause {
|
||||
t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause)
|
||||
}
|
||||
if len(gotValues) != len(tt.res.values) {
|
||||
t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
|
||||
return
|
||||
}
|
||||
for i, value := range gotValues {
|
||||
if !reflect.DeepEqual(value, tt.res.values[i]) {
|
||||
t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildQuery(t *testing.T) {
|
||||
type args struct {
|
||||
queryFactory *models.SearchQueryFactory
|
||||
}
|
||||
type res struct {
|
||||
query string
|
||||
limit uint64
|
||||
values []interface{}
|
||||
rowScanner bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "invalid query factory",
|
||||
args: args{
|
||||
queryFactory: nil,
|
||||
},
|
||||
res: res{
|
||||
query: "",
|
||||
limit: 0,
|
||||
rowScanner: false,
|
||||
values: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with order by desc",
|
||||
args: args{
|
||||
queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
|
||||
rowScanner: true,
|
||||
values: []interface{}{es_models.AggregateType("user")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with limit",
|
||||
args: args{
|
||||
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
|
||||
rowScanner: true,
|
||||
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
|
||||
limit: 5,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with limit and order by desc",
|
||||
args: args{
|
||||
queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
|
||||
rowScanner: true,
|
||||
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
|
||||
limit: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(tt.args.queryFactory)
|
||||
if gotQuery != tt.res.query {
|
||||
t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query)
|
||||
}
|
||||
if gotLimit != tt.res.limit {
|
||||
t.Errorf("buildQuery() gotLimit = %v, want %v", gotLimit, tt.res.limit)
|
||||
}
|
||||
if len(gotValues) != len(tt.res.values) {
|
||||
t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
|
||||
return
|
||||
}
|
||||
for i, value := range gotValues {
|
||||
if !reflect.DeepEqual(value, tt.res.values[i]) {
|
||||
t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
|
||||
}
|
||||
}
|
||||
if (tt.res.rowScanner && gotRowScanner == nil) || (!tt.res.rowScanner && gotRowScanner != nil) {
|
||||
t.Errorf("rowScanner should be nil==%v got nil==%v", tt.res.rowScanner, gotRowScanner == nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
47
internal/eventstore/v2/repository/sql_types.go
Normal file
47
internal/eventstore/v2/repository/sql_types.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package repository
|
||||
|
||||
import "database/sql/driver"
|
||||
|
||||
// Data represents a byte array that may be null.
|
||||
// Data implements the sql.Scanner interface
|
||||
type Data []byte
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (data *Data) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*data = nil
|
||||
return nil
|
||||
}
|
||||
*data = Data(value.([]byte))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (data Data) Value() (driver.Value, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []byte(data), nil
|
||||
}
|
||||
|
||||
// Sequence represents a number that may be null.
|
||||
// Sequence implements the sql.Scanner interface
|
||||
type Sequence uint64
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (seq *Sequence) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*seq = 0
|
||||
return nil
|
||||
}
|
||||
*seq = Sequence(value.(int64))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (seq Sequence) Value() (driver.Value, error) {
|
||||
if seq == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return int64(seq), nil
|
||||
}
|
Reference in New Issue
Block a user