diff --git a/internal/eventstore/v2/repository/search_query.go b/internal/eventstore/v2/repository/search_query.go index 931e6234fa..3c09cd5008 100644 --- a/internal/eventstore/v2/repository/search_query.go +++ b/internal/eventstore/v2/repository/search_query.go @@ -19,9 +19,9 @@ const ( ) type Filter struct { - field Field - value interface{} - operation Operation + Field Field + Value interface{} + Operation Operation } type Operation int32 @@ -48,33 +48,33 @@ const ( //NewFilter is used in tests. Use searchQuery.*Filter() instead func NewFilter(field Field, value interface{}, operation Operation) *Filter { return &Filter{ - field: field, - value: value, - operation: operation, + Field: field, + Value: value, + Operation: operation, } } -func (f *Filter) Field() Field { - return f.field -} -func (f *Filter) Operation() Operation { - return f.operation -} -func (f *Filter) Value() interface{} { - return f.value -} +// func (f *Filter) Field() Field { +// return f.field +// } +// func (f *Filter) Operation() Operation { +// return f.operation +// } +// func (f *Filter) Value() interface{} { +// return f.value +// } func (f *Filter) Validate() error { if f == nil { return errors.ThrowPreconditionFailed(nil, "REPO-z6KcG", "filter is nil") } - if f.field <= 0 { + if f.Field <= 0 { return errors.ThrowPreconditionFailed(nil, "REPO-zw62U", "field not definded") } - if f.value == nil { + if f.Value == nil { return errors.ThrowPreconditionFailed(nil, "REPO-GJ9ct", "no value definded") } - if f.operation <= 0 { + if f.Operation <= 0 { return errors.ThrowPreconditionFailed(nil, "REPO-RrQTy", "operation not definded") } return nil diff --git a/internal/eventstore/v2/repository/search_query_test.go b/internal/eventstore/v2/repository/search_query_test.go index e35d5d4299..e83533d5fc 100644 --- a/internal/eventstore/v2/repository/search_query_test.go +++ b/internal/eventstore/v2/repository/search_query_test.go @@ -23,7 +23,7 @@ func TestNewFilter(t *testing.T) { value: "hodor", operation: Operation_Equals, }, - want: &Filter{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"}, + want: &Filter{Field: Field_AggregateID, Operation: Operation_Equals, Value: "hodor"}, }, } for _, tt := range tests { @@ -91,9 +91,9 @@ func TestFilter_Validate(t *testing.T) { var f *Filter if !tt.fields.isNil { f = &Filter{ - field: tt.fields.field, - value: tt.fields.value, - operation: tt.fields.operation, + Field: tt.fields.field, + Value: tt.fields.value, + Operation: tt.fields.operation, } } if err := f.Validate(); (err != nil) != tt.wantErr { diff --git a/internal/eventstore/v2/repository/sql/crdb.go b/internal/eventstore/v2/repository/sql/crdb.go index 3a75825784..7e8bfc75cc 100644 --- a/internal/eventstore/v2/repository/sql/crdb.go +++ b/internal/eventstore/v2/repository/sql/crdb.go @@ -158,58 +158,46 @@ func (db *CRDB) Push(ctx context.Context, events ...*repository.Event) error { // Filter returns all events matching the given search query func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) { - rows, rowScanner, err := db.query(searchQuery) + events = []*repository.Event{} + err = db.query(searchQuery, &events) if err != nil { return nil, err } - defer rows.Close() - - for rows.Next() { - event := new(repository.Event) - err := rowScanner(rows.Scan, event) - if err != nil { - return nil, err - } - - events = append(events, event) - } return events, nil } //LatestSequence returns the latests sequence found by the the search query func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) { - rows, rowScanner, err := db.query(searchQuery) - if err != nil { - return 0, err - } - defer rows.Close() - - if !rows.Next() { - return 0, caos_errs.ThrowNotFound(nil, "SQL-cAEzS", "latest sequence not found") - } - var seq Sequence - err = rowScanner(rows.Scan, &seq) + err := db.query(searchQuery, &seq) if err != nil { return 0, err } - return uint64(seq), nil } -func (db *CRDB) query(searchQuery *repository.SearchQuery) (*sql.Rows, rowScan, error) { +func (db *CRDB) query(searchQuery *repository.SearchQuery, data interface{}) error { query, values, rowScanner := buildQuery(db, searchQuery) if query == "" { - return nil, nil, caos_errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") + return caos_errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") } rows, err := db.client.Query(query, values...) if err != nil { logging.Log("SQL-HP3Uk").WithError(err).Info("query failed") - return nil, nil, caos_errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events") + return caos_errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events") } - return rows, rowScanner, nil + defer rows.Close() + + for rows.Next() { + err = rowScanner(rows.Scan, nil) + if err != nil { + return err + } + } + + return nil } func (db *CRDB) eventQuery() string { diff --git a/internal/eventstore/v2/repository/sql/crdb_test.go b/internal/eventstore/v2/repository/sql/crdb_test.go new file mode 100644 index 0000000000..3d6007a3ce --- /dev/null +++ b/internal/eventstore/v2/repository/sql/crdb_test.go @@ -0,0 +1,262 @@ +package sql + +import ( + "testing" + + "github.com/caos/zitadel/internal/eventstore/v2/repository" + _ "github.com/lib/pq" +) + +func TestCRDB_placeholder(t *testing.T) { + type args struct { + query string + } + type res struct { + query string + } + tests := []struct { + name string + args args + res res + }{ + { + name: "no placeholders", + args: args{ + query: "SELECT * FROM eventstore.events", + }, + res: res{ + query: "SELECT * FROM eventstore.events", + }, + }, + { + name: "one placeholder", + args: args{ + query: "SELECT * FROM eventstore.events WHERE aggregate_type = ?", + }, + res: res{ + query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1", + }, + }, + { + name: "multiple placeholders", + args: args{ + query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND aggregate_id = ? LIMIT ?", + }, + res: res{ + query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND aggregate_id = $2 LIMIT $3", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &CRDB{} + if query := db.placeholder(tt.args.query); query != tt.res.query { + t.Errorf("CRDB.placeholder() = %v, want %v", query, tt.res.query) + } + }) + } +} + +func TestCRDB_operation(t *testing.T) { + type res struct { + op string + } + type args struct { + operation repository.Operation + } + tests := []struct { + name string + args args + res res + }{ + { + name: "no op", + args: args{ + operation: repository.Operation(-1), + }, + res: res{ + op: "", + }, + }, + { + name: "greater", + args: args{ + operation: repository.Operation_Greater, + }, + res: res{ + op: ">", + }, + }, + { + name: "less", + args: args{ + operation: repository.Operation_Less, + }, + res: res{ + op: "<", + }, + }, + { + name: "equals", + args: args{ + operation: repository.Operation_Equals, + }, + res: res{ + op: "=", + }, + }, + { + name: "in", + args: args{ + operation: repository.Operation_In, + }, + res: res{ + op: "=", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &CRDB{} + if got := db.operation(tt.args.operation); got != tt.res.op { + t.Errorf("CRDB.operation() = %v, want %v", got, tt.res.op) + } + }) + } +} + +func TestCRDB_conditionFormat(t *testing.T) { + type res struct { + format string + } + type args struct { + operation repository.Operation + } + tests := []struct { + name string + args args + res res + }{ + { + name: "default", + args: args{ + operation: repository.Operation_Equals, + }, + res: res{ + format: "%s %s ?", + }, + }, + { + name: "in", + args: args{ + operation: repository.Operation_In, + }, + res: res{ + format: "%s %s ANY(?)", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &CRDB{} + if got := db.conditionFormat(tt.args.operation); got != tt.res.format { + t.Errorf("CRDB.conditionFormat() = %v, want %v", got, tt.res.format) + } + }) + } +} + +func TestCRDB_columnName(t *testing.T) { + type res struct { + name string + } + type args struct { + field repository.Field + } + tests := []struct { + name string + args args + res res + }{ + { + name: "invalid field", + args: args{ + field: repository.Field(-1), + }, + res: res{ + name: "", + }, + }, + { + name: "aggregate id", + args: args{ + field: repository.Field_AggregateID, + }, + res: res{ + name: "aggregate_id", + }, + }, + { + name: "aggregate type", + args: args{ + field: repository.Field_AggregateType, + }, + res: res{ + name: "aggregate_type", + }, + }, + { + name: "editor service", + args: args{ + field: repository.Field_EditorService, + }, + res: res{ + name: "editor_service", + }, + }, + { + name: "editor user", + args: args{ + field: repository.Field_EditorUser, + }, + res: res{ + name: "editor_user", + }, + }, + { + name: "event type", + args: args{ + field: repository.Field_EventType, + }, + res: res{ + name: "event_type", + }, + }, + { + name: "latest sequence", + args: args{ + field: repository.Field_LatestSequence, + }, + res: res{ + name: "event_sequence", + }, + }, + { + name: "resource owner", + args: args{ + field: repository.Field_ResourceOwner, + }, + res: res{ + name: "resource_owner", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &CRDB{} + if got := db.columnName(tt.args.field); got != tt.res.name { + t.Errorf("CRDB.operation() = %v, want %v", got, tt.res.name) + } + }) + } +} diff --git a/internal/eventstore/v2/repository/sql/local_crdb_test.go b/internal/eventstore/v2/repository/sql/local_crdb_test.go index 33bfd82a30..1ee47f7b50 100644 --- a/internal/eventstore/v2/repository/sql/local_crdb_test.go +++ b/internal/eventstore/v2/repository/sql/local_crdb_test.go @@ -1,163 +1,164 @@ package sql -import ( - "database/sql" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "sort" - "strconv" - "strings" - "testing" - "time" +// import ( +// "database/sql" +// "fmt" +// "io/ioutil" +// "os" +// "path/filepath" +// "sort" +// "strconv" +// "strings" +// "testing" +// "time" - "github.com/caos/logging" - "github.com/cockroachdb/cockroach-go/v2/testserver" -) +// "github.com/caos/logging" +// "github.com/caos/zitadel/internal/eventstore/v2/repository" +// "github.com/cockroachdb/cockroach-go/v2/testserver" +// ) -var ( - migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach") - db *sql.DB -) +// var ( +// migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach") +// db *sql.DB +// ) -func TestMain(m *testing.M) { - ts, err := testserver.NewTestServer() - if err != nil { - logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db") - } +// func TestMain(m *testing.M) { +// ts, err := testserver.NewTestServer() +// if err != nil { +// logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db") +// } - db, err = sql.Open("postgres", ts.PGURL().String()) - if err != nil { - logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db") - } +// db, err = sql.Open("postgres", ts.PGURL().String()) +// if err != nil { +// logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db") +// } - defer func() { - db.Close() - ts.Stop() - }() +// defer func() { +// db.Close() +// ts.Stop() +// }() - if err = executeMigrations(); err != nil { - logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed") - } +// if err = executeMigrations(); err != nil { +// logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed") +// } - os.Exit(m.Run()) -} +// os.Exit(m.Run()) +// } -func TestInsert(t *testing.T) { - tx, _ := db.Begin() +// func TestInsert(t *testing.T) { +// tx, _ := db.Begin() - var seq Sequence - var d time.Time +// var seq Sequence +// var d time.Time - row := tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false) - err := row.Scan(&seq, &d) +// row := tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false) +// err := row.Scan(&seq, &d) - row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(1), true) - err = row.Scan(&seq, &d) +// row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(1), true) +// err = row.Scan(&seq, &d) - row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false) - err = row.Scan(&seq, &d) +// row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false) +// err = row.Scan(&seq, &d) - tx.Commit() +// tx.Commit() - rows, err := db.Query("select * from eventstore.events order by event_sequence") - defer rows.Close() - fmt.Println(err) +// rows, err := db.Query("select * from eventstore.events order by event_sequence") +// defer rows.Close() +// fmt.Println(err) - fmt.Println(rows.Columns()) - for rows.Next() { - i := make([]interface{}, 12) - var id string - rows.Scan(&id, &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7], &i[8], &i[9], &i[10], &i[11]) - i[0] = id +// fmt.Println(rows.Columns()) +// for rows.Next() { +// i := make([]interface{}, 12) +// var id string +// rows.Scan(&id, &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7], &i[8], &i[9], &i[10], &i[11]) +// i[0] = id - fmt.Println(i) - } +// fmt.Println(i) +// } - t.Fail() -} +// t.Fail() +// } -func executeMigrations() error { - files, err := migrationFilePaths() - if err != nil { - return err - } - sort.Sort(files) - for _, file := range files { - migration, err := ioutil.ReadFile(string(file)) - if err != nil { - return err - } - transactionInMigration := strings.Contains(string(migration), "BEGIN;") - exec := db.Exec - var tx *sql.Tx - if !transactionInMigration { - tx, err = db.Begin() - if err != nil { - return fmt.Errorf("begin file: %v || err: %w", file, err) - } - exec = tx.Exec - } - if _, err = exec(string(migration)); err != nil { - return fmt.Errorf("exec file: %v || err: %w", file, err) - } - if !transactionInMigration { - if err = tx.Commit(); err != nil { - return fmt.Errorf("commit file: %v || err: %w", file, err) - } - } - } - return nil -} +// func executeMigrations() error { +// files, err := migrationFilePaths() +// if err != nil { +// return err +// } +// sort.Sort(files) +// for _, file := range files { +// migration, err := ioutil.ReadFile(string(file)) +// if err != nil { +// return err +// } +// transactionInMigration := strings.Contains(string(migration), "BEGIN;") +// exec := db.Exec +// var tx *sql.Tx +// if !transactionInMigration { +// tx, err = db.Begin() +// if err != nil { +// return fmt.Errorf("begin file: %v || err: %w", file, err) +// } +// exec = tx.Exec +// } +// if _, err = exec(string(migration)); err != nil { +// return fmt.Errorf("exec file: %v || err: %w", file, err) +// } +// if !transactionInMigration { +// if err = tx.Commit(); err != nil { +// return fmt.Errorf("commit file: %v || err: %w", file, err) +// } +// } +// } +// return nil +// } -type migrationPaths []string +// type migrationPaths []string -type version struct { - major int - minor int -} +// type version struct { +// major int +// minor int +// } -func versionFromPath(s string) version { - v := s[strings.Index(s, "/V")+2 : strings.Index(s, "__")] - splitted := strings.Split(v, ".") - res := version{} - var err error - if len(splitted) >= 1 { - res.major, err = strconv.Atoi(splitted[0]) - if err != nil { - panic(err) - } - } +// func versionFromPath(s string) version { +// v := s[strings.Index(s, "/V")+2 : strings.Index(s, "__")] +// splitted := strings.Split(v, ".") +// res := version{} +// var err error +// if len(splitted) >= 1 { +// res.major, err = strconv.Atoi(splitted[0]) +// if err != nil { +// panic(err) +// } +// } - if len(splitted) >= 2 { - res.minor, err = strconv.Atoi(splitted[1]) - if err != nil { - panic(err) - } - } +// if len(splitted) >= 2 { +// res.minor, err = strconv.Atoi(splitted[1]) +// if err != nil { +// panic(err) +// } +// } - return res -} +// return res +// } -func (a migrationPaths) Len() int { return len(a) } -func (a migrationPaths) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a migrationPaths) Less(i, j int) bool { - versionI := versionFromPath(a[i]) - versionJ := versionFromPath(a[j]) +// func (a migrationPaths) Len() int { return len(a) } +// func (a migrationPaths) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +// func (a migrationPaths) Less(i, j int) bool { +// versionI := versionFromPath(a[i]) +// versionJ := versionFromPath(a[j]) - return versionI.major < versionJ.major || - (versionI.major == versionJ.major && versionI.minor < versionJ.minor) -} +// return versionI.major < versionJ.major || +// (versionI.major == versionJ.major && versionI.minor < versionJ.minor) +// } -func migrationFilePaths() (migrationPaths, error) { - files := make(migrationPaths, 0) - err := filepath.Walk(migrationsPath, func(path string, info os.FileInfo, err error) error { - if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".sql") { - return err - } - files = append(files, path) - return nil - }) - return files, err -} +// func migrationFilePaths() (migrationPaths, error) { +// files := make(migrationPaths, 0) +// err := filepath.Walk(migrationsPath, func(path string, info os.FileInfo, err error) error { +// if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".sql") { +// return err +// } +// files = append(files, path) +// return nil +// }) +// return files, err +// } diff --git a/internal/eventstore/v2/repository/sql/query.go b/internal/eventstore/v2/repository/sql/query.go index e362a4adb5..b54880eff5 100644 --- a/internal/eventstore/v2/repository/sql/query.go +++ b/internal/eventstore/v2/repository/sql/query.go @@ -52,15 +52,15 @@ func buildQuery(criteria criteriaer, searchQuery *repository.SearchQuery) (query func prepareColumns(criteria criteriaer, columns repository.Columns) (string, func(s scan, dest interface{}) error) { switch columns { case repository.Columns_Max_Sequence: - return criteria.maxSequenceQuery(), maxSequenceRowScanner + return criteria.maxSequenceQuery(), maxSequenceScanner case repository.Columns_Event: - return criteria.eventQuery(), eventRowScanner + return criteria.eventQuery(), eventsScanner default: return "", nil } } -func maxSequenceRowScanner(row scan, dest interface{}) (err error) { +func maxSequenceScanner(row scan, dest interface{}) (err error) { sequence, ok := dest.(*Sequence) if !ok { return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence") @@ -72,15 +72,16 @@ func maxSequenceRowScanner(row scan, dest interface{}) (err error) { return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong") } -func eventRowScanner(row scan, dest interface{}) (err error) { - event, ok := dest.(*repository.Event) +func eventsScanner(scanner scan, dest interface{}) (err error) { + events, ok := dest.(*[]*repository.Event) if !ok { return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event") } var previousSequence Sequence data := make(Data, 0) + event := new(repository.Event) - err = row( + err = scanner( &event.CreationDate, &event.Type, &event.Sequence, @@ -104,6 +105,8 @@ func eventRowScanner(row scan, dest interface{}) (err error) { event.Data = make([]byte, len(data)) copy(event.Data, data) + *events = append(*events, event) + return nil } @@ -115,7 +118,7 @@ func prepareCondition(criteria criteriaer, filters []*repository.Filter) (clause return clause, values } for i, filter := range filters { - value := filter.Value() + value := filter.Value switch value.(type) { case []bool, []float64, []int64, []string, []repository.AggregateType, []repository.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]repository.AggregateType, *[]repository.EventType: value = pq.Array(value) @@ -131,12 +134,12 @@ func prepareCondition(criteria criteriaer, filters []*repository.Filter) (clause } func getCondition(cond criteriaer, filter *repository.Filter) (condition string) { - field := cond.columnName(filter.Field()) - operation := cond.operation(filter.Operation()) + field := cond.columnName(filter.Field) + operation := cond.operation(filter.Operation) if field == "" || operation == "" { return "" } - format := cond.conditionFormat(filter.Operation()) + format := cond.conditionFormat(filter.Operation) return fmt.Sprintf(format, field, operation) } diff --git a/internal/eventstore/v2/repository/sql/query_test.go b/internal/eventstore/v2/repository/sql/query_test.go index ba01e66ac3..25ac030102 100644 --- a/internal/eventstore/v2/repository/sql/query_test.go +++ b/internal/eventstore/v2/repository/sql/query_test.go @@ -11,120 +11,6 @@ import ( "github.com/lib/pq" ) -func Test_numberPlaceholder(t *testing.T) { - type args struct { - query string - old string - new string - } - type res struct { - query string - } - tests := []struct { - name string - args args - res res - }{ - { - name: "no replaces", - args: args{ - new: "$", - old: "?", - query: "SELECT * FROM eventstore.events", - }, - res: res{ - query: "SELECT * FROM eventstore.events", - }, - }, - { - name: "two replaces", - args: args{ - new: "$", - old: "?", - query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND LIMIT = ?", - }, - res: res{ - query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND LIMIT = $2", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := numberPlaceholder(tt.args.query, tt.args.old, tt.args.new); got != tt.res.query { - t.Errorf("numberPlaceholder() = %v, want %v", got, tt.res.query) - } - }) - } -} - -func Test_getOperation(t *testing.T) { - t.Run("all ops", func(t *testing.T) { - for op, expected := range map[repository.Operation]string{ - repository.Operation_Equals: "=", - repository.Operation_In: "=", - repository.Operation_Greater: ">", - repository.Operation_Less: "<", - repository.Operation(-1): "", - } { - if got := getOperation(op); got != expected { - t.Errorf("getOperation() = %v, want %v", got, expected) - } - } - }) -} - -func Test_getField(t *testing.T) { - t.Run("all fields", func(t *testing.T) { - for field, expected := range map[repository.Field]string{ - repository.Field_AggregateType: "aggregate_type", - repository.Field_AggregateID: "aggregate_id", - repository.Field_LatestSequence: "event_sequence", - repository.Field_ResourceOwner: "resource_owner", - repository.Field_EditorService: "editor_service", - repository.Field_EditorUser: "editor_user", - repository.Field_EventType: "event_type", - repository.Field(-1): "", - } { - if got := getField(field); got != expected { - t.Errorf("getField() = %v, want %v", got, expected) - } - } - }) -} - -func Test_getConditionFormat(t *testing.T) { - type args struct { - operation repository.Operation - } - tests := []struct { - name string - args args - want string - }{ - { - name: "no in operation", - args: args{ - operation: repository.Operation_Equals, - }, - want: "%s %s ?", - }, - { - name: "in operation", - args: args{ - operation: repository.Operation_In, - }, - want: "%s %s ANY(?)", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getConditionFormat(tt.args.operation); got != tt.want { - t.Errorf("prepareConditionFormat() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_getCondition(t *testing.T) { type args struct { filter *repository.Filter @@ -172,7 +58,8 @@ func Test_getCondition(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := getCondition(tt.args.filter); got != tt.want { + db := &CRDB{} + if got := getCondition(db, tt.args.filter); got != tt.want { t.Errorf("getCondition() = %v, want %v", got, tt.want) } }) @@ -180,6 +67,9 @@ func Test_getCondition(t *testing.T) { } func Test_prepareColumns(t *testing.T) { + type fields struct { + dbRow []interface{} + } type args struct { columns repository.Columns dest interface{} @@ -187,14 +77,14 @@ func Test_prepareColumns(t *testing.T) { } type res struct { query string - dbRow []interface{} expected interface{} dbErr func(error) bool } tests := []struct { - name string - args args - res res + name string + args args + res res + fields fields }{ { name: "invalid columns", @@ -212,9 +102,11 @@ func Test_prepareColumns(t *testing.T) { }, res: res{ query: "SELECT MAX(event_sequence) FROM eventstore.events", - dbRow: []interface{}{Sequence(5)}, expected: Sequence(5), }, + fields: fields{ + dbRow: []interface{}{Sequence(5)}, + }, }, { name: "max sequence wrong dest type", @@ -228,22 +120,26 @@ func Test_prepareColumns(t *testing.T) { }, }, { - name: "event", + name: "events", args: args{ columns: repository.Columns_Event, - dest: new(repository.Event), + dest: &[]*repository.Event{}, }, res: res{ - query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", - dbRow: []interface{}{time.Time{}, repository.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", repository.AggregateType("user"), "hodor", repository.Version("")}, - expected: repository.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)}, + query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", + expected: []*repository.Event{ + {AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)}, + }, + }, + fields: fields{ + dbRow: []interface{}{time.Time{}, repository.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", repository.AggregateType("user"), "hodor", repository.Version("")}, }, }, { - name: "event wrong dest type", + name: "events wrong dest type", args: args{ columns: repository.Columns_Event, - dest: new(uint64), + dest: []*repository.Event{}, }, res: res{ query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", @@ -254,7 +150,7 @@ func Test_prepareColumns(t *testing.T) { name: "event query error", args: args{ columns: repository.Columns_Event, - dest: new(repository.Event), + dest: &[]*repository.Event{}, dbErr: sql.ErrConnDone, }, res: res{ @@ -265,9 +161,10 @@ func Test_prepareColumns(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - query, rowScanner := prepareColumns(tt.args.columns) + crdb := &CRDB{} + query, rowScanner := prepareColumns(crdb, tt.args.columns) if query != tt.res.query { - t.Errorf("prepareColumns() got = %v, want %v", query, tt.res.query) + t.Errorf("prepareColumns() got = %s, want %s", query, tt.res.query) } if tt.res.query == "" && rowScanner != nil { t.Errorf("row scanner should be nil") @@ -275,15 +172,16 @@ func Test_prepareColumns(t *testing.T) { if rowScanner == nil { return } - err := rowScanner(prepareTestScan(tt.args.dbErr, tt.res.dbRow), tt.args.dest) - if tt.res.dbErr != nil { - if !tt.res.dbErr(err) { - t.Errorf("wrong error type in rowScanner got: %v", err) - } - } else { - if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) { - t.Errorf("unexpected result from rowScanner want: %v got: %v", tt.res.dbRow, tt.args.dest) - } + err := rowScanner(prepareTestScan(tt.args.dbErr, tt.fields.dbRow), tt.args.dest) + if err != nil && tt.res.dbErr == nil || err != nil && !tt.res.dbErr(err) || err == nil && tt.res.dbErr != nil { + t.Errorf("wrong error type in rowScanner got: %v", err) + return + } + if tt.res.dbErr != nil && tt.res.dbErr(err) { + return + } + if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) { + t.Errorf("unexpected result from rowScanner \nwant: %+v \ngot: %+v", tt.fields.dbRow, reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface()) } }) } @@ -379,7 +277,8 @@ func Test_prepareCondition(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotClause, gotValues := prepareCondition(tt.args.filters) + crdb := &CRDB{} + gotClause, gotValues := prepareCondition(crdb, tt.args.filters) if gotClause != tt.res.clause { t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause) } @@ -398,11 +297,10 @@ func Test_prepareCondition(t *testing.T) { func Test_buildQuery(t *testing.T) { type args struct { - queryFactory *repository.SearchQuery + query *repository.SearchQuery } type res struct { query string - limit uint64 values []interface{} rowScanner bool } @@ -411,23 +309,34 @@ func Test_buildQuery(t *testing.T) { args args res res }{ - { - name: "invalid query factory", - args: args{ - queryFactory: nil, - }, - res: res{ - query: "", - limit: 0, - rowScanner: false, - values: nil, - }, - }, + + // { + //removed because it's no valid test case + // name: "no query", + // args: args{ + // query: nil, + // }, + // res: res{ + // query: "", + // rowScanner: false, + // values: nil, + // }, + // }, { name: "with order by desc", args: args{ // NewSearchQueryFactory("user").OrderDesc() - queryFactory: &repository.SearchQuery{Desc: true}, + query: &repository.SearchQuery{ + Columns: repository.Columns_Event, + Desc: true, + Filters: []*repository.Filter{ + { + Field: repository.Field_AggregateType, + Value: repository.AggregateType("user"), + Operation: repository.Operation_Equals, + }, + }, + }, }, res: res{ query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC", @@ -438,45 +347,61 @@ func Test_buildQuery(t *testing.T) { { name: "with limit", args: args{ - queryFactory: repository.NewSearchQueryFactory("user").Limit(5), + query: &repository.SearchQuery{ + Columns: repository.Columns_Event, + Desc: false, + Limit: 5, + Filters: []*repository.Filter{ + { + Field: repository.Field_AggregateType, + Value: repository.AggregateType("user"), + Operation: repository.Operation_Equals, + }, + }, + }, }, res: res{ query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2", rowScanner: true, values: []interface{}{repository.AggregateType("user"), uint64(5)}, - limit: 5, }, }, { name: "with limit and order by desc", args: args{ - queryFactory: repository.NewSearchQueryFactory("user").Limit(5).OrderDesc(), + query: &repository.SearchQuery{ + Columns: repository.Columns_Event, + Desc: true, + Limit: 5, + Filters: []*repository.Filter{ + { + Field: repository.Field_AggregateType, + Value: repository.AggregateType("user"), + Operation: repository.Operation_Equals, + }, + }, + }, }, res: res{ query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2", rowScanner: true, values: []interface{}{repository.AggregateType("user"), uint64(5)}, - limit: 5, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(tt.args.queryFactory) + crdb := &CRDB{} + gotQuery, gotValues, gotRowScanner := buildQuery(crdb, tt.args.query) if gotQuery != tt.res.query { t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query) } - if gotLimit != tt.res.limit { - t.Errorf("buildQuery() gotLimit = %v, want %v", gotLimit, tt.res.limit) - } if len(gotValues) != len(tt.res.values) { t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values)) return } - for i, value := range gotValues { - if !reflect.DeepEqual(value, tt.res.values[i]) { - t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values) - } + if !reflect.DeepEqual(gotValues, tt.res.values) { + t.Errorf("prepareCondition() gotValues = %T: %v, want %T: %v", gotValues, gotValues, tt.res.values, tt.res.values) } if (tt.res.rowScanner && gotRowScanner == nil) || (!tt.res.rowScanner && gotRowScanner != nil) { t.Errorf("rowScanner should be nil==%v got nil==%v", tt.res.rowScanner, gotRowScanner == nil) @@ -484,5 +409,3 @@ func Test_buildQuery(t *testing.T) { }) } } - -// func buildQuery(t *testing.T, factory *reposear)