test(eventstore): sql unit tests

This commit is contained in:
adlerhurst 2020-10-05 20:39:36 +02:00
parent 120a8bae85
commit 64a0859d76
7 changed files with 541 additions and 364 deletions

View File

@ -19,9 +19,9 @@ const (
)
type Filter struct {
field Field
value interface{}
operation Operation
Field Field
Value interface{}
Operation Operation
}
type Operation int32
@ -48,33 +48,33 @@ const (
//NewFilter is used in tests. Use searchQuery.*Filter() instead
func NewFilter(field Field, value interface{}, operation Operation) *Filter {
return &Filter{
field: field,
value: value,
operation: operation,
Field: field,
Value: value,
Operation: operation,
}
}
func (f *Filter) Field() Field {
return f.field
}
func (f *Filter) Operation() Operation {
return f.operation
}
func (f *Filter) Value() interface{} {
return f.value
}
// func (f *Filter) Field() Field {
// return f.field
// }
// func (f *Filter) Operation() Operation {
// return f.operation
// }
// func (f *Filter) Value() interface{} {
// return f.value
// }
func (f *Filter) Validate() error {
if f == nil {
return errors.ThrowPreconditionFailed(nil, "REPO-z6KcG", "filter is nil")
}
if f.field <= 0 {
if f.Field <= 0 {
return errors.ThrowPreconditionFailed(nil, "REPO-zw62U", "field not definded")
}
if f.value == nil {
if f.Value == nil {
return errors.ThrowPreconditionFailed(nil, "REPO-GJ9ct", "no value definded")
}
if f.operation <= 0 {
if f.Operation <= 0 {
return errors.ThrowPreconditionFailed(nil, "REPO-RrQTy", "operation not definded")
}
return nil

View File

@ -23,7 +23,7 @@ func TestNewFilter(t *testing.T) {
value: "hodor",
operation: Operation_Equals,
},
want: &Filter{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
want: &Filter{Field: Field_AggregateID, Operation: Operation_Equals, Value: "hodor"},
},
}
for _, tt := range tests {
@ -91,9 +91,9 @@ func TestFilter_Validate(t *testing.T) {
var f *Filter
if !tt.fields.isNil {
f = &Filter{
field: tt.fields.field,
value: tt.fields.value,
operation: tt.fields.operation,
Field: tt.fields.field,
Value: tt.fields.value,
Operation: tt.fields.operation,
}
}
if err := f.Validate(); (err != nil) != tt.wantErr {

View File

@ -158,58 +158,46 @@ func (db *CRDB) Push(ctx context.Context, events ...*repository.Event) error {
// Filter returns all events matching the given search query
func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) {
rows, rowScanner, err := db.query(searchQuery)
events = []*repository.Event{}
err = db.query(searchQuery, &events)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
event := new(repository.Event)
err := rowScanner(rows.Scan, event)
if err != nil {
return nil, err
}
events = append(events, event)
}
return events, nil
}
//LatestSequence returns the latests sequence found by the the search query
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) {
rows, rowScanner, err := db.query(searchQuery)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, caos_errs.ThrowNotFound(nil, "SQL-cAEzS", "latest sequence not found")
}
var seq Sequence
err = rowScanner(rows.Scan, &seq)
err := db.query(searchQuery, &seq)
if err != nil {
return 0, err
}
return uint64(seq), nil
}
func (db *CRDB) query(searchQuery *repository.SearchQuery) (*sql.Rows, rowScan, error) {
func (db *CRDB) query(searchQuery *repository.SearchQuery, data interface{}) error {
query, values, rowScanner := buildQuery(db, searchQuery)
if query == "" {
return nil, nil, caos_errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
return caos_errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
}
rows, err := db.client.Query(query, values...)
if err != nil {
logging.Log("SQL-HP3Uk").WithError(err).Info("query failed")
return nil, nil, caos_errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
return caos_errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
}
return rows, rowScanner, nil
defer rows.Close()
for rows.Next() {
err = rowScanner(rows.Scan, nil)
if err != nil {
return err
}
}
return nil
}
func (db *CRDB) eventQuery() string {

View File

@ -0,0 +1,262 @@
package sql
import (
"testing"
"github.com/caos/zitadel/internal/eventstore/v2/repository"
_ "github.com/lib/pq"
)
func TestCRDB_placeholder(t *testing.T) {
type args struct {
query string
}
type res struct {
query string
}
tests := []struct {
name string
args args
res res
}{
{
name: "no placeholders",
args: args{
query: "SELECT * FROM eventstore.events",
},
res: res{
query: "SELECT * FROM eventstore.events",
},
},
{
name: "one placeholder",
args: args{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = ?",
},
res: res{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1",
},
},
{
name: "multiple placeholders",
args: args{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND aggregate_id = ? LIMIT ?",
},
res: res{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND aggregate_id = $2 LIMIT $3",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if query := db.placeholder(tt.args.query); query != tt.res.query {
t.Errorf("CRDB.placeholder() = %v, want %v", query, tt.res.query)
}
})
}
}
func TestCRDB_operation(t *testing.T) {
type res struct {
op string
}
type args struct {
operation repository.Operation
}
tests := []struct {
name string
args args
res res
}{
{
name: "no op",
args: args{
operation: repository.Operation(-1),
},
res: res{
op: "",
},
},
{
name: "greater",
args: args{
operation: repository.Operation_Greater,
},
res: res{
op: ">",
},
},
{
name: "less",
args: args{
operation: repository.Operation_Less,
},
res: res{
op: "<",
},
},
{
name: "equals",
args: args{
operation: repository.Operation_Equals,
},
res: res{
op: "=",
},
},
{
name: "in",
args: args{
operation: repository.Operation_In,
},
res: res{
op: "=",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if got := db.operation(tt.args.operation); got != tt.res.op {
t.Errorf("CRDB.operation() = %v, want %v", got, tt.res.op)
}
})
}
}
func TestCRDB_conditionFormat(t *testing.T) {
type res struct {
format string
}
type args struct {
operation repository.Operation
}
tests := []struct {
name string
args args
res res
}{
{
name: "default",
args: args{
operation: repository.Operation_Equals,
},
res: res{
format: "%s %s ?",
},
},
{
name: "in",
args: args{
operation: repository.Operation_In,
},
res: res{
format: "%s %s ANY(?)",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if got := db.conditionFormat(tt.args.operation); got != tt.res.format {
t.Errorf("CRDB.conditionFormat() = %v, want %v", got, tt.res.format)
}
})
}
}
func TestCRDB_columnName(t *testing.T) {
type res struct {
name string
}
type args struct {
field repository.Field
}
tests := []struct {
name string
args args
res res
}{
{
name: "invalid field",
args: args{
field: repository.Field(-1),
},
res: res{
name: "",
},
},
{
name: "aggregate id",
args: args{
field: repository.Field_AggregateID,
},
res: res{
name: "aggregate_id",
},
},
{
name: "aggregate type",
args: args{
field: repository.Field_AggregateType,
},
res: res{
name: "aggregate_type",
},
},
{
name: "editor service",
args: args{
field: repository.Field_EditorService,
},
res: res{
name: "editor_service",
},
},
{
name: "editor user",
args: args{
field: repository.Field_EditorUser,
},
res: res{
name: "editor_user",
},
},
{
name: "event type",
args: args{
field: repository.Field_EventType,
},
res: res{
name: "event_type",
},
},
{
name: "latest sequence",
args: args{
field: repository.Field_LatestSequence,
},
res: res{
name: "event_sequence",
},
},
{
name: "resource owner",
args: args{
field: repository.Field_ResourceOwner,
},
res: res{
name: "resource_owner",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if got := db.columnName(tt.args.field); got != tt.res.name {
t.Errorf("CRDB.operation() = %v, want %v", got, tt.res.name)
}
})
}
}

View File

@ -1,163 +1,164 @@
package sql
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"testing"
"time"
// import (
// "database/sql"
// "fmt"
// "io/ioutil"
// "os"
// "path/filepath"
// "sort"
// "strconv"
// "strings"
// "testing"
// "time"
"github.com/caos/logging"
"github.com/cockroachdb/cockroach-go/v2/testserver"
)
// "github.com/caos/logging"
// "github.com/caos/zitadel/internal/eventstore/v2/repository"
// "github.com/cockroachdb/cockroach-go/v2/testserver"
// )
var (
migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach")
db *sql.DB
)
// var (
// migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach")
// db *sql.DB
// )
func TestMain(m *testing.M) {
ts, err := testserver.NewTestServer()
if err != nil {
logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db")
}
// func TestMain(m *testing.M) {
// ts, err := testserver.NewTestServer()
// if err != nil {
// logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db")
// }
db, err = sql.Open("postgres", ts.PGURL().String())
if err != nil {
logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db")
}
// db, err = sql.Open("postgres", ts.PGURL().String())
// if err != nil {
// logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db")
// }
defer func() {
db.Close()
ts.Stop()
}()
// defer func() {
// db.Close()
// ts.Stop()
// }()
if err = executeMigrations(); err != nil {
logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed")
}
// if err = executeMigrations(); err != nil {
// logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed")
// }
os.Exit(m.Run())
}
// os.Exit(m.Run())
// }
func TestInsert(t *testing.T) {
tx, _ := db.Begin()
// func TestInsert(t *testing.T) {
// tx, _ := db.Begin()
var seq Sequence
var d time.Time
// var seq Sequence
// var d time.Time
row := tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
err := row.Scan(&seq, &d)
// row := tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
// err := row.Scan(&seq, &d)
row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(1), true)
err = row.Scan(&seq, &d)
// row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(1), true)
// err = row.Scan(&seq, &d)
row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
err = row.Scan(&seq, &d)
// row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
// err = row.Scan(&seq, &d)
tx.Commit()
// tx.Commit()
rows, err := db.Query("select * from eventstore.events order by event_sequence")
defer rows.Close()
fmt.Println(err)
// rows, err := db.Query("select * from eventstore.events order by event_sequence")
// defer rows.Close()
// fmt.Println(err)
fmt.Println(rows.Columns())
for rows.Next() {
i := make([]interface{}, 12)
var id string
rows.Scan(&id, &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7], &i[8], &i[9], &i[10], &i[11])
i[0] = id
// fmt.Println(rows.Columns())
// for rows.Next() {
// i := make([]interface{}, 12)
// var id string
// rows.Scan(&id, &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7], &i[8], &i[9], &i[10], &i[11])
// i[0] = id
fmt.Println(i)
}
// fmt.Println(i)
// }
t.Fail()
}
// t.Fail()
// }
func executeMigrations() error {
files, err := migrationFilePaths()
if err != nil {
return err
}
sort.Sort(files)
for _, file := range files {
migration, err := ioutil.ReadFile(string(file))
if err != nil {
return err
}
transactionInMigration := strings.Contains(string(migration), "BEGIN;")
exec := db.Exec
var tx *sql.Tx
if !transactionInMigration {
tx, err = db.Begin()
if err != nil {
return fmt.Errorf("begin file: %v || err: %w", file, err)
}
exec = tx.Exec
}
if _, err = exec(string(migration)); err != nil {
return fmt.Errorf("exec file: %v || err: %w", file, err)
}
if !transactionInMigration {
if err = tx.Commit(); err != nil {
return fmt.Errorf("commit file: %v || err: %w", file, err)
}
}
}
return nil
}
// func executeMigrations() error {
// files, err := migrationFilePaths()
// if err != nil {
// return err
// }
// sort.Sort(files)
// for _, file := range files {
// migration, err := ioutil.ReadFile(string(file))
// if err != nil {
// return err
// }
// transactionInMigration := strings.Contains(string(migration), "BEGIN;")
// exec := db.Exec
// var tx *sql.Tx
// if !transactionInMigration {
// tx, err = db.Begin()
// if err != nil {
// return fmt.Errorf("begin file: %v || err: %w", file, err)
// }
// exec = tx.Exec
// }
// if _, err = exec(string(migration)); err != nil {
// return fmt.Errorf("exec file: %v || err: %w", file, err)
// }
// if !transactionInMigration {
// if err = tx.Commit(); err != nil {
// return fmt.Errorf("commit file: %v || err: %w", file, err)
// }
// }
// }
// return nil
// }
type migrationPaths []string
// type migrationPaths []string
type version struct {
major int
minor int
}
// type version struct {
// major int
// minor int
// }
func versionFromPath(s string) version {
v := s[strings.Index(s, "/V")+2 : strings.Index(s, "__")]
splitted := strings.Split(v, ".")
res := version{}
var err error
if len(splitted) >= 1 {
res.major, err = strconv.Atoi(splitted[0])
if err != nil {
panic(err)
}
}
// func versionFromPath(s string) version {
// v := s[strings.Index(s, "/V")+2 : strings.Index(s, "__")]
// splitted := strings.Split(v, ".")
// res := version{}
// var err error
// if len(splitted) >= 1 {
// res.major, err = strconv.Atoi(splitted[0])
// if err != nil {
// panic(err)
// }
// }
if len(splitted) >= 2 {
res.minor, err = strconv.Atoi(splitted[1])
if err != nil {
panic(err)
}
}
// if len(splitted) >= 2 {
// res.minor, err = strconv.Atoi(splitted[1])
// if err != nil {
// panic(err)
// }
// }
return res
}
// return res
// }
func (a migrationPaths) Len() int { return len(a) }
func (a migrationPaths) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a migrationPaths) Less(i, j int) bool {
versionI := versionFromPath(a[i])
versionJ := versionFromPath(a[j])
// func (a migrationPaths) Len() int { return len(a) }
// func (a migrationPaths) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a migrationPaths) Less(i, j int) bool {
// versionI := versionFromPath(a[i])
// versionJ := versionFromPath(a[j])
return versionI.major < versionJ.major ||
(versionI.major == versionJ.major && versionI.minor < versionJ.minor)
}
// return versionI.major < versionJ.major ||
// (versionI.major == versionJ.major && versionI.minor < versionJ.minor)
// }
func migrationFilePaths() (migrationPaths, error) {
files := make(migrationPaths, 0)
err := filepath.Walk(migrationsPath, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".sql") {
return err
}
files = append(files, path)
return nil
})
return files, err
}
// func migrationFilePaths() (migrationPaths, error) {
// files := make(migrationPaths, 0)
// err := filepath.Walk(migrationsPath, func(path string, info os.FileInfo, err error) error {
// if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".sql") {
// return err
// }
// files = append(files, path)
// return nil
// })
// return files, err
// }

View File

@ -52,15 +52,15 @@ func buildQuery(criteria criteriaer, searchQuery *repository.SearchQuery) (query
func prepareColumns(criteria criteriaer, columns repository.Columns) (string, func(s scan, dest interface{}) error) {
switch columns {
case repository.Columns_Max_Sequence:
return criteria.maxSequenceQuery(), maxSequenceRowScanner
return criteria.maxSequenceQuery(), maxSequenceScanner
case repository.Columns_Event:
return criteria.eventQuery(), eventRowScanner
return criteria.eventQuery(), eventsScanner
default:
return "", nil
}
}
func maxSequenceRowScanner(row scan, dest interface{}) (err error) {
func maxSequenceScanner(row scan, dest interface{}) (err error) {
sequence, ok := dest.(*Sequence)
if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
@ -72,15 +72,16 @@ func maxSequenceRowScanner(row scan, dest interface{}) (err error) {
return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
}
func eventRowScanner(row scan, dest interface{}) (err error) {
event, ok := dest.(*repository.Event)
func eventsScanner(scanner scan, dest interface{}) (err error) {
events, ok := dest.(*[]*repository.Event)
if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
}
var previousSequence Sequence
data := make(Data, 0)
event := new(repository.Event)
err = row(
err = scanner(
&event.CreationDate,
&event.Type,
&event.Sequence,
@ -104,6 +105,8 @@ func eventRowScanner(row scan, dest interface{}) (err error) {
event.Data = make([]byte, len(data))
copy(event.Data, data)
*events = append(*events, event)
return nil
}
@ -115,7 +118,7 @@ func prepareCondition(criteria criteriaer, filters []*repository.Filter) (clause
return clause, values
}
for i, filter := range filters {
value := filter.Value()
value := filter.Value
switch value.(type) {
case []bool, []float64, []int64, []string, []repository.AggregateType, []repository.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]repository.AggregateType, *[]repository.EventType:
value = pq.Array(value)
@ -131,12 +134,12 @@ func prepareCondition(criteria criteriaer, filters []*repository.Filter) (clause
}
func getCondition(cond criteriaer, filter *repository.Filter) (condition string) {
field := cond.columnName(filter.Field())
operation := cond.operation(filter.Operation())
field := cond.columnName(filter.Field)
operation := cond.operation(filter.Operation)
if field == "" || operation == "" {
return ""
}
format := cond.conditionFormat(filter.Operation())
format := cond.conditionFormat(filter.Operation)
return fmt.Sprintf(format, field, operation)
}

View File

@ -11,120 +11,6 @@ import (
"github.com/lib/pq"
)
func Test_numberPlaceholder(t *testing.T) {
type args struct {
query string
old string
new string
}
type res struct {
query string
}
tests := []struct {
name string
args args
res res
}{
{
name: "no replaces",
args: args{
new: "$",
old: "?",
query: "SELECT * FROM eventstore.events",
},
res: res{
query: "SELECT * FROM eventstore.events",
},
},
{
name: "two replaces",
args: args{
new: "$",
old: "?",
query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND LIMIT = ?",
},
res: res{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND LIMIT = $2",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := numberPlaceholder(tt.args.query, tt.args.old, tt.args.new); got != tt.res.query {
t.Errorf("numberPlaceholder() = %v, want %v", got, tt.res.query)
}
})
}
}
func Test_getOperation(t *testing.T) {
t.Run("all ops", func(t *testing.T) {
for op, expected := range map[repository.Operation]string{
repository.Operation_Equals: "=",
repository.Operation_In: "=",
repository.Operation_Greater: ">",
repository.Operation_Less: "<",
repository.Operation(-1): "",
} {
if got := getOperation(op); got != expected {
t.Errorf("getOperation() = %v, want %v", got, expected)
}
}
})
}
func Test_getField(t *testing.T) {
t.Run("all fields", func(t *testing.T) {
for field, expected := range map[repository.Field]string{
repository.Field_AggregateType: "aggregate_type",
repository.Field_AggregateID: "aggregate_id",
repository.Field_LatestSequence: "event_sequence",
repository.Field_ResourceOwner: "resource_owner",
repository.Field_EditorService: "editor_service",
repository.Field_EditorUser: "editor_user",
repository.Field_EventType: "event_type",
repository.Field(-1): "",
} {
if got := getField(field); got != expected {
t.Errorf("getField() = %v, want %v", got, expected)
}
}
})
}
func Test_getConditionFormat(t *testing.T) {
type args struct {
operation repository.Operation
}
tests := []struct {
name string
args args
want string
}{
{
name: "no in operation",
args: args{
operation: repository.Operation_Equals,
},
want: "%s %s ?",
},
{
name: "in operation",
args: args{
operation: repository.Operation_In,
},
want: "%s %s ANY(?)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getConditionFormat(tt.args.operation); got != tt.want {
t.Errorf("prepareConditionFormat() = %v, want %v", got, tt.want)
}
})
}
}
func Test_getCondition(t *testing.T) {
type args struct {
filter *repository.Filter
@ -172,7 +58,8 @@ func Test_getCondition(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getCondition(tt.args.filter); got != tt.want {
db := &CRDB{}
if got := getCondition(db, tt.args.filter); got != tt.want {
t.Errorf("getCondition() = %v, want %v", got, tt.want)
}
})
@ -180,6 +67,9 @@ func Test_getCondition(t *testing.T) {
}
func Test_prepareColumns(t *testing.T) {
type fields struct {
dbRow []interface{}
}
type args struct {
columns repository.Columns
dest interface{}
@ -187,14 +77,14 @@ func Test_prepareColumns(t *testing.T) {
}
type res struct {
query string
dbRow []interface{}
expected interface{}
dbErr func(error) bool
}
tests := []struct {
name string
args args
res res
name string
args args
res res
fields fields
}{
{
name: "invalid columns",
@ -212,9 +102,11 @@ func Test_prepareColumns(t *testing.T) {
},
res: res{
query: "SELECT MAX(event_sequence) FROM eventstore.events",
dbRow: []interface{}{Sequence(5)},
expected: Sequence(5),
},
fields: fields{
dbRow: []interface{}{Sequence(5)},
},
},
{
name: "max sequence wrong dest type",
@ -228,22 +120,26 @@ func Test_prepareColumns(t *testing.T) {
},
},
{
name: "event",
name: "events",
args: args{
columns: repository.Columns_Event,
dest: new(repository.Event),
dest: &[]*repository.Event{},
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
dbRow: []interface{}{time.Time{}, repository.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", repository.AggregateType("user"), "hodor", repository.Version("")},
expected: repository.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
expected: []*repository.Event{
{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
},
},
fields: fields{
dbRow: []interface{}{time.Time{}, repository.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", repository.AggregateType("user"), "hodor", repository.Version("")},
},
},
{
name: "event wrong dest type",
name: "events wrong dest type",
args: args{
columns: repository.Columns_Event,
dest: new(uint64),
dest: []*repository.Event{},
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
@ -254,7 +150,7 @@ func Test_prepareColumns(t *testing.T) {
name: "event query error",
args: args{
columns: repository.Columns_Event,
dest: new(repository.Event),
dest: &[]*repository.Event{},
dbErr: sql.ErrConnDone,
},
res: res{
@ -265,9 +161,10 @@ func Test_prepareColumns(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query, rowScanner := prepareColumns(tt.args.columns)
crdb := &CRDB{}
query, rowScanner := prepareColumns(crdb, tt.args.columns)
if query != tt.res.query {
t.Errorf("prepareColumns() got = %v, want %v", query, tt.res.query)
t.Errorf("prepareColumns() got = %s, want %s", query, tt.res.query)
}
if tt.res.query == "" && rowScanner != nil {
t.Errorf("row scanner should be nil")
@ -275,15 +172,16 @@ func Test_prepareColumns(t *testing.T) {
if rowScanner == nil {
return
}
err := rowScanner(prepareTestScan(tt.args.dbErr, tt.res.dbRow), tt.args.dest)
if tt.res.dbErr != nil {
if !tt.res.dbErr(err) {
t.Errorf("wrong error type in rowScanner got: %v", err)
}
} else {
if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) {
t.Errorf("unexpected result from rowScanner want: %v got: %v", tt.res.dbRow, tt.args.dest)
}
err := rowScanner(prepareTestScan(tt.args.dbErr, tt.fields.dbRow), tt.args.dest)
if err != nil && tt.res.dbErr == nil || err != nil && !tt.res.dbErr(err) || err == nil && tt.res.dbErr != nil {
t.Errorf("wrong error type in rowScanner got: %v", err)
return
}
if tt.res.dbErr != nil && tt.res.dbErr(err) {
return
}
if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) {
t.Errorf("unexpected result from rowScanner \nwant: %+v \ngot: %+v", tt.fields.dbRow, reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface())
}
})
}
@ -379,7 +277,8 @@ func Test_prepareCondition(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClause, gotValues := prepareCondition(tt.args.filters)
crdb := &CRDB{}
gotClause, gotValues := prepareCondition(crdb, tt.args.filters)
if gotClause != tt.res.clause {
t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause)
}
@ -398,11 +297,10 @@ func Test_prepareCondition(t *testing.T) {
func Test_buildQuery(t *testing.T) {
type args struct {
queryFactory *repository.SearchQuery
query *repository.SearchQuery
}
type res struct {
query string
limit uint64
values []interface{}
rowScanner bool
}
@ -411,23 +309,34 @@ func Test_buildQuery(t *testing.T) {
args args
res res
}{
{
name: "invalid query factory",
args: args{
queryFactory: nil,
},
res: res{
query: "",
limit: 0,
rowScanner: false,
values: nil,
},
},
// {
//removed because it's no valid test case
// name: "no query",
// args: args{
// query: nil,
// },
// res: res{
// query: "",
// rowScanner: false,
// values: nil,
// },
// },
{
name: "with order by desc",
args: args{
// NewSearchQueryFactory("user").OrderDesc()
queryFactory: &repository.SearchQuery{Desc: true},
query: &repository.SearchQuery{
Columns: repository.Columns_Event,
Desc: true,
Filters: []*repository.Filter{
{
Field: repository.Field_AggregateType,
Value: repository.AggregateType("user"),
Operation: repository.Operation_Equals,
},
},
},
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
@ -438,45 +347,61 @@ func Test_buildQuery(t *testing.T) {
{
name: "with limit",
args: args{
queryFactory: repository.NewSearchQueryFactory("user").Limit(5),
query: &repository.SearchQuery{
Columns: repository.Columns_Event,
Desc: false,
Limit: 5,
Filters: []*repository.Filter{
{
Field: repository.Field_AggregateType,
Value: repository.AggregateType("user"),
Operation: repository.Operation_Equals,
},
},
},
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
rowScanner: true,
values: []interface{}{repository.AggregateType("user"), uint64(5)},
limit: 5,
},
},
{
name: "with limit and order by desc",
args: args{
queryFactory: repository.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
query: &repository.SearchQuery{
Columns: repository.Columns_Event,
Desc: true,
Limit: 5,
Filters: []*repository.Filter{
{
Field: repository.Field_AggregateType,
Value: repository.AggregateType("user"),
Operation: repository.Operation_Equals,
},
},
},
},
res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
rowScanner: true,
values: []interface{}{repository.AggregateType("user"), uint64(5)},
limit: 5,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(tt.args.queryFactory)
crdb := &CRDB{}
gotQuery, gotValues, gotRowScanner := buildQuery(crdb, tt.args.query)
if gotQuery != tt.res.query {
t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query)
}
if gotLimit != tt.res.limit {
t.Errorf("buildQuery() gotLimit = %v, want %v", gotLimit, tt.res.limit)
}
if len(gotValues) != len(tt.res.values) {
t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
return
}
for i, value := range gotValues {
if !reflect.DeepEqual(value, tt.res.values[i]) {
t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
}
if !reflect.DeepEqual(gotValues, tt.res.values) {
t.Errorf("prepareCondition() gotValues = %T: %v, want %T: %v", gotValues, gotValues, tt.res.values, tt.res.values)
}
if (tt.res.rowScanner && gotRowScanner == nil) || (!tt.res.rowScanner && gotRowScanner != nil) {
t.Errorf("rowScanner should be nil==%v got nil==%v", tt.res.rowScanner, gotRowScanner == nil)
@ -484,5 +409,3 @@ func Test_buildQuery(t *testing.T) {
})
}
}
// func buildQuery(t *testing.T, factory *reposear)