test(eventstore): sql unit tests

This commit is contained in:
adlerhurst 2020-10-05 20:39:36 +02:00
parent 120a8bae85
commit 64a0859d76
7 changed files with 541 additions and 364 deletions

View File

@ -19,9 +19,9 @@ const (
) )
type Filter struct { type Filter struct {
field Field Field Field
value interface{} Value interface{}
operation Operation Operation Operation
} }
type Operation int32 type Operation int32
@ -48,33 +48,33 @@ const (
//NewFilter is used in tests. Use searchQuery.*Filter() instead //NewFilter is used in tests. Use searchQuery.*Filter() instead
func NewFilter(field Field, value interface{}, operation Operation) *Filter { func NewFilter(field Field, value interface{}, operation Operation) *Filter {
return &Filter{ return &Filter{
field: field, Field: field,
value: value, Value: value,
operation: operation, Operation: operation,
} }
} }
func (f *Filter) Field() Field { // func (f *Filter) Field() Field {
return f.field // return f.field
} // }
func (f *Filter) Operation() Operation { // func (f *Filter) Operation() Operation {
return f.operation // return f.operation
} // }
func (f *Filter) Value() interface{} { // func (f *Filter) Value() interface{} {
return f.value // return f.value
} // }
func (f *Filter) Validate() error { func (f *Filter) Validate() error {
if f == nil { if f == nil {
return errors.ThrowPreconditionFailed(nil, "REPO-z6KcG", "filter is nil") return errors.ThrowPreconditionFailed(nil, "REPO-z6KcG", "filter is nil")
} }
if f.field <= 0 { if f.Field <= 0 {
return errors.ThrowPreconditionFailed(nil, "REPO-zw62U", "field not definded") return errors.ThrowPreconditionFailed(nil, "REPO-zw62U", "field not definded")
} }
if f.value == nil { if f.Value == nil {
return errors.ThrowPreconditionFailed(nil, "REPO-GJ9ct", "no value definded") return errors.ThrowPreconditionFailed(nil, "REPO-GJ9ct", "no value definded")
} }
if f.operation <= 0 { if f.Operation <= 0 {
return errors.ThrowPreconditionFailed(nil, "REPO-RrQTy", "operation not definded") return errors.ThrowPreconditionFailed(nil, "REPO-RrQTy", "operation not definded")
} }
return nil return nil

View File

@ -23,7 +23,7 @@ func TestNewFilter(t *testing.T) {
value: "hodor", value: "hodor",
operation: Operation_Equals, operation: Operation_Equals,
}, },
want: &Filter{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"}, want: &Filter{Field: Field_AggregateID, Operation: Operation_Equals, Value: "hodor"},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -91,9 +91,9 @@ func TestFilter_Validate(t *testing.T) {
var f *Filter var f *Filter
if !tt.fields.isNil { if !tt.fields.isNil {
f = &Filter{ f = &Filter{
field: tt.fields.field, Field: tt.fields.field,
value: tt.fields.value, Value: tt.fields.value,
operation: tt.fields.operation, Operation: tt.fields.operation,
} }
} }
if err := f.Validate(); (err != nil) != tt.wantErr { if err := f.Validate(); (err != nil) != tt.wantErr {

View File

@ -158,58 +158,46 @@ func (db *CRDB) Push(ctx context.Context, events ...*repository.Event) error {
// Filter returns all events matching the given search query // Filter returns all events matching the given search query
func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) { func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) {
rows, rowScanner, err := db.query(searchQuery) events = []*repository.Event{}
err = db.query(searchQuery, &events)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
for rows.Next() {
event := new(repository.Event)
err := rowScanner(rows.Scan, event)
if err != nil {
return nil, err
}
events = append(events, event)
}
return events, nil return events, nil
} }
//LatestSequence returns the latests sequence found by the the search query //LatestSequence returns the latests sequence found by the the search query
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) { func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) {
rows, rowScanner, err := db.query(searchQuery)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, caos_errs.ThrowNotFound(nil, "SQL-cAEzS", "latest sequence not found")
}
var seq Sequence var seq Sequence
err = rowScanner(rows.Scan, &seq) err := db.query(searchQuery, &seq)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return uint64(seq), nil return uint64(seq), nil
} }
func (db *CRDB) query(searchQuery *repository.SearchQuery) (*sql.Rows, rowScan, error) { func (db *CRDB) query(searchQuery *repository.SearchQuery, data interface{}) error {
query, values, rowScanner := buildQuery(db, searchQuery) query, values, rowScanner := buildQuery(db, searchQuery)
if query == "" { if query == "" {
return nil, nil, caos_errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") return caos_errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
} }
rows, err := db.client.Query(query, values...) rows, err := db.client.Query(query, values...)
if err != nil { if err != nil {
logging.Log("SQL-HP3Uk").WithError(err).Info("query failed") logging.Log("SQL-HP3Uk").WithError(err).Info("query failed")
return nil, nil, caos_errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events") return caos_errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
} }
return rows, rowScanner, nil defer rows.Close()
for rows.Next() {
err = rowScanner(rows.Scan, nil)
if err != nil {
return err
}
}
return nil
} }
func (db *CRDB) eventQuery() string { func (db *CRDB) eventQuery() string {

View File

@ -0,0 +1,262 @@
package sql
import (
"testing"
"github.com/caos/zitadel/internal/eventstore/v2/repository"
_ "github.com/lib/pq"
)
func TestCRDB_placeholder(t *testing.T) {
type args struct {
query string
}
type res struct {
query string
}
tests := []struct {
name string
args args
res res
}{
{
name: "no placeholders",
args: args{
query: "SELECT * FROM eventstore.events",
},
res: res{
query: "SELECT * FROM eventstore.events",
},
},
{
name: "one placeholder",
args: args{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = ?",
},
res: res{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1",
},
},
{
name: "multiple placeholders",
args: args{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND aggregate_id = ? LIMIT ?",
},
res: res{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND aggregate_id = $2 LIMIT $3",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if query := db.placeholder(tt.args.query); query != tt.res.query {
t.Errorf("CRDB.placeholder() = %v, want %v", query, tt.res.query)
}
})
}
}
func TestCRDB_operation(t *testing.T) {
type res struct {
op string
}
type args struct {
operation repository.Operation
}
tests := []struct {
name string
args args
res res
}{
{
name: "no op",
args: args{
operation: repository.Operation(-1),
},
res: res{
op: "",
},
},
{
name: "greater",
args: args{
operation: repository.Operation_Greater,
},
res: res{
op: ">",
},
},
{
name: "less",
args: args{
operation: repository.Operation_Less,
},
res: res{
op: "<",
},
},
{
name: "equals",
args: args{
operation: repository.Operation_Equals,
},
res: res{
op: "=",
},
},
{
name: "in",
args: args{
operation: repository.Operation_In,
},
res: res{
op: "=",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if got := db.operation(tt.args.operation); got != tt.res.op {
t.Errorf("CRDB.operation() = %v, want %v", got, tt.res.op)
}
})
}
}
func TestCRDB_conditionFormat(t *testing.T) {
type res struct {
format string
}
type args struct {
operation repository.Operation
}
tests := []struct {
name string
args args
res res
}{
{
name: "default",
args: args{
operation: repository.Operation_Equals,
},
res: res{
format: "%s %s ?",
},
},
{
name: "in",
args: args{
operation: repository.Operation_In,
},
res: res{
format: "%s %s ANY(?)",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if got := db.conditionFormat(tt.args.operation); got != tt.res.format {
t.Errorf("CRDB.conditionFormat() = %v, want %v", got, tt.res.format)
}
})
}
}
func TestCRDB_columnName(t *testing.T) {
type res struct {
name string
}
type args struct {
field repository.Field
}
tests := []struct {
name string
args args
res res
}{
{
name: "invalid field",
args: args{
field: repository.Field(-1),
},
res: res{
name: "",
},
},
{
name: "aggregate id",
args: args{
field: repository.Field_AggregateID,
},
res: res{
name: "aggregate_id",
},
},
{
name: "aggregate type",
args: args{
field: repository.Field_AggregateType,
},
res: res{
name: "aggregate_type",
},
},
{
name: "editor service",
args: args{
field: repository.Field_EditorService,
},
res: res{
name: "editor_service",
},
},
{
name: "editor user",
args: args{
field: repository.Field_EditorUser,
},
res: res{
name: "editor_user",
},
},
{
name: "event type",
args: args{
field: repository.Field_EventType,
},
res: res{
name: "event_type",
},
},
{
name: "latest sequence",
args: args{
field: repository.Field_LatestSequence,
},
res: res{
name: "event_sequence",
},
},
{
name: "resource owner",
args: args{
field: repository.Field_ResourceOwner,
},
res: res{
name: "resource_owner",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &CRDB{}
if got := db.columnName(tt.args.field); got != tt.res.name {
t.Errorf("CRDB.operation() = %v, want %v", got, tt.res.name)
}
})
}
}

View File

@ -1,163 +1,164 @@
package sql package sql
import ( // import (
"database/sql" // "database/sql"
"fmt" // "fmt"
"io/ioutil" // "io/ioutil"
"os" // "os"
"path/filepath" // "path/filepath"
"sort" // "sort"
"strconv" // "strconv"
"strings" // "strings"
"testing" // "testing"
"time" // "time"
"github.com/caos/logging" // "github.com/caos/logging"
"github.com/cockroachdb/cockroach-go/v2/testserver" // "github.com/caos/zitadel/internal/eventstore/v2/repository"
) // "github.com/cockroachdb/cockroach-go/v2/testserver"
// )
var ( // var (
migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach") // migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach")
db *sql.DB // db *sql.DB
) // )
func TestMain(m *testing.M) { // func TestMain(m *testing.M) {
ts, err := testserver.NewTestServer() // ts, err := testserver.NewTestServer()
if err != nil { // if err != nil {
logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db") // logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db")
} // }
db, err = sql.Open("postgres", ts.PGURL().String()) // db, err = sql.Open("postgres", ts.PGURL().String())
if err != nil { // if err != nil {
logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db") // logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db")
} // }
defer func() { // defer func() {
db.Close() // db.Close()
ts.Stop() // ts.Stop()
}() // }()
if err = executeMigrations(); err != nil { // if err = executeMigrations(); err != nil {
logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed") // logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed")
} // }
os.Exit(m.Run()) // os.Exit(m.Run())
} // }
func TestInsert(t *testing.T) { // func TestInsert(t *testing.T) {
tx, _ := db.Begin() // tx, _ := db.Begin()
var seq Sequence // var seq Sequence
var d time.Time // var d time.Time
row := tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false) // row := tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
err := row.Scan(&seq, &d) // err := row.Scan(&seq, &d)
row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(1), true) // row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(1), true)
err = row.Scan(&seq, &d) // err = row.Scan(&seq, &d)
row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false) // row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", repository.Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
err = row.Scan(&seq, &d) // err = row.Scan(&seq, &d)
tx.Commit() // tx.Commit()
rows, err := db.Query("select * from eventstore.events order by event_sequence") // rows, err := db.Query("select * from eventstore.events order by event_sequence")
defer rows.Close() // defer rows.Close()
fmt.Println(err) // fmt.Println(err)
fmt.Println(rows.Columns()) // fmt.Println(rows.Columns())
for rows.Next() { // for rows.Next() {
i := make([]interface{}, 12) // i := make([]interface{}, 12)
var id string // var id string
rows.Scan(&id, &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7], &i[8], &i[9], &i[10], &i[11]) // rows.Scan(&id, &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7], &i[8], &i[9], &i[10], &i[11])
i[0] = id // i[0] = id
fmt.Println(i) // fmt.Println(i)
} // }
t.Fail() // t.Fail()
} // }
func executeMigrations() error { // func executeMigrations() error {
files, err := migrationFilePaths() // files, err := migrationFilePaths()
if err != nil { // if err != nil {
return err // return err
} // }
sort.Sort(files) // sort.Sort(files)
for _, file := range files { // for _, file := range files {
migration, err := ioutil.ReadFile(string(file)) // migration, err := ioutil.ReadFile(string(file))
if err != nil { // if err != nil {
return err // return err
} // }
transactionInMigration := strings.Contains(string(migration), "BEGIN;") // transactionInMigration := strings.Contains(string(migration), "BEGIN;")
exec := db.Exec // exec := db.Exec
var tx *sql.Tx // var tx *sql.Tx
if !transactionInMigration { // if !transactionInMigration {
tx, err = db.Begin() // tx, err = db.Begin()
if err != nil { // if err != nil {
return fmt.Errorf("begin file: %v || err: %w", file, err) // return fmt.Errorf("begin file: %v || err: %w", file, err)
} // }
exec = tx.Exec // exec = tx.Exec
} // }
if _, err = exec(string(migration)); err != nil { // if _, err = exec(string(migration)); err != nil {
return fmt.Errorf("exec file: %v || err: %w", file, err) // return fmt.Errorf("exec file: %v || err: %w", file, err)
} // }
if !transactionInMigration { // if !transactionInMigration {
if err = tx.Commit(); err != nil { // if err = tx.Commit(); err != nil {
return fmt.Errorf("commit file: %v || err: %w", file, err) // return fmt.Errorf("commit file: %v || err: %w", file, err)
} // }
} // }
} // }
return nil // return nil
} // }
type migrationPaths []string // type migrationPaths []string
type version struct { // type version struct {
major int // major int
minor int // minor int
} // }
func versionFromPath(s string) version { // func versionFromPath(s string) version {
v := s[strings.Index(s, "/V")+2 : strings.Index(s, "__")] // v := s[strings.Index(s, "/V")+2 : strings.Index(s, "__")]
splitted := strings.Split(v, ".") // splitted := strings.Split(v, ".")
res := version{} // res := version{}
var err error // var err error
if len(splitted) >= 1 { // if len(splitted) >= 1 {
res.major, err = strconv.Atoi(splitted[0]) // res.major, err = strconv.Atoi(splitted[0])
if err != nil { // if err != nil {
panic(err) // panic(err)
} // }
} // }
if len(splitted) >= 2 { // if len(splitted) >= 2 {
res.minor, err = strconv.Atoi(splitted[1]) // res.minor, err = strconv.Atoi(splitted[1])
if err != nil { // if err != nil {
panic(err) // panic(err)
} // }
} // }
return res // return res
} // }
func (a migrationPaths) Len() int { return len(a) } // func (a migrationPaths) Len() int { return len(a) }
func (a migrationPaths) Swap(i, j int) { a[i], a[j] = a[j], a[i] } // func (a migrationPaths) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a migrationPaths) Less(i, j int) bool { // func (a migrationPaths) Less(i, j int) bool {
versionI := versionFromPath(a[i]) // versionI := versionFromPath(a[i])
versionJ := versionFromPath(a[j]) // versionJ := versionFromPath(a[j])
return versionI.major < versionJ.major || // return versionI.major < versionJ.major ||
(versionI.major == versionJ.major && versionI.minor < versionJ.minor) // (versionI.major == versionJ.major && versionI.minor < versionJ.minor)
} // }
func migrationFilePaths() (migrationPaths, error) { // func migrationFilePaths() (migrationPaths, error) {
files := make(migrationPaths, 0) // files := make(migrationPaths, 0)
err := filepath.Walk(migrationsPath, func(path string, info os.FileInfo, err error) error { // err := filepath.Walk(migrationsPath, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".sql") { // if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".sql") {
return err // return err
} // }
files = append(files, path) // files = append(files, path)
return nil // return nil
}) // })
return files, err // return files, err
} // }

View File

@ -52,15 +52,15 @@ func buildQuery(criteria criteriaer, searchQuery *repository.SearchQuery) (query
func prepareColumns(criteria criteriaer, columns repository.Columns) (string, func(s scan, dest interface{}) error) { func prepareColumns(criteria criteriaer, columns repository.Columns) (string, func(s scan, dest interface{}) error) {
switch columns { switch columns {
case repository.Columns_Max_Sequence: case repository.Columns_Max_Sequence:
return criteria.maxSequenceQuery(), maxSequenceRowScanner return criteria.maxSequenceQuery(), maxSequenceScanner
case repository.Columns_Event: case repository.Columns_Event:
return criteria.eventQuery(), eventRowScanner return criteria.eventQuery(), eventsScanner
default: default:
return "", nil return "", nil
} }
} }
func maxSequenceRowScanner(row scan, dest interface{}) (err error) { func maxSequenceScanner(row scan, dest interface{}) (err error) {
sequence, ok := dest.(*Sequence) sequence, ok := dest.(*Sequence)
if !ok { if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence") return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
@ -72,15 +72,16 @@ func maxSequenceRowScanner(row scan, dest interface{}) (err error) {
return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong") return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
} }
func eventRowScanner(row scan, dest interface{}) (err error) { func eventsScanner(scanner scan, dest interface{}) (err error) {
event, ok := dest.(*repository.Event) events, ok := dest.(*[]*repository.Event)
if !ok { if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event") return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
} }
var previousSequence Sequence var previousSequence Sequence
data := make(Data, 0) data := make(Data, 0)
event := new(repository.Event)
err = row( err = scanner(
&event.CreationDate, &event.CreationDate,
&event.Type, &event.Type,
&event.Sequence, &event.Sequence,
@ -104,6 +105,8 @@ func eventRowScanner(row scan, dest interface{}) (err error) {
event.Data = make([]byte, len(data)) event.Data = make([]byte, len(data))
copy(event.Data, data) copy(event.Data, data)
*events = append(*events, event)
return nil return nil
} }
@ -115,7 +118,7 @@ func prepareCondition(criteria criteriaer, filters []*repository.Filter) (clause
return clause, values return clause, values
} }
for i, filter := range filters { for i, filter := range filters {
value := filter.Value() value := filter.Value
switch value.(type) { switch value.(type) {
case []bool, []float64, []int64, []string, []repository.AggregateType, []repository.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]repository.AggregateType, *[]repository.EventType: case []bool, []float64, []int64, []string, []repository.AggregateType, []repository.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]repository.AggregateType, *[]repository.EventType:
value = pq.Array(value) value = pq.Array(value)
@ -131,12 +134,12 @@ func prepareCondition(criteria criteriaer, filters []*repository.Filter) (clause
} }
func getCondition(cond criteriaer, filter *repository.Filter) (condition string) { func getCondition(cond criteriaer, filter *repository.Filter) (condition string) {
field := cond.columnName(filter.Field()) field := cond.columnName(filter.Field)
operation := cond.operation(filter.Operation()) operation := cond.operation(filter.Operation)
if field == "" || operation == "" { if field == "" || operation == "" {
return "" return ""
} }
format := cond.conditionFormat(filter.Operation()) format := cond.conditionFormat(filter.Operation)
return fmt.Sprintf(format, field, operation) return fmt.Sprintf(format, field, operation)
} }

View File

@ -11,120 +11,6 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
func Test_numberPlaceholder(t *testing.T) {
type args struct {
query string
old string
new string
}
type res struct {
query string
}
tests := []struct {
name string
args args
res res
}{
{
name: "no replaces",
args: args{
new: "$",
old: "?",
query: "SELECT * FROM eventstore.events",
},
res: res{
query: "SELECT * FROM eventstore.events",
},
},
{
name: "two replaces",
args: args{
new: "$",
old: "?",
query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND LIMIT = ?",
},
res: res{
query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND LIMIT = $2",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := numberPlaceholder(tt.args.query, tt.args.old, tt.args.new); got != tt.res.query {
t.Errorf("numberPlaceholder() = %v, want %v", got, tt.res.query)
}
})
}
}
func Test_getOperation(t *testing.T) {
t.Run("all ops", func(t *testing.T) {
for op, expected := range map[repository.Operation]string{
repository.Operation_Equals: "=",
repository.Operation_In: "=",
repository.Operation_Greater: ">",
repository.Operation_Less: "<",
repository.Operation(-1): "",
} {
if got := getOperation(op); got != expected {
t.Errorf("getOperation() = %v, want %v", got, expected)
}
}
})
}
func Test_getField(t *testing.T) {
t.Run("all fields", func(t *testing.T) {
for field, expected := range map[repository.Field]string{
repository.Field_AggregateType: "aggregate_type",
repository.Field_AggregateID: "aggregate_id",
repository.Field_LatestSequence: "event_sequence",
repository.Field_ResourceOwner: "resource_owner",
repository.Field_EditorService: "editor_service",
repository.Field_EditorUser: "editor_user",
repository.Field_EventType: "event_type",
repository.Field(-1): "",
} {
if got := getField(field); got != expected {
t.Errorf("getField() = %v, want %v", got, expected)
}
}
})
}
func Test_getConditionFormat(t *testing.T) {
type args struct {
operation repository.Operation
}
tests := []struct {
name string
args args
want string
}{
{
name: "no in operation",
args: args{
operation: repository.Operation_Equals,
},
want: "%s %s ?",
},
{
name: "in operation",
args: args{
operation: repository.Operation_In,
},
want: "%s %s ANY(?)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getConditionFormat(tt.args.operation); got != tt.want {
t.Errorf("prepareConditionFormat() = %v, want %v", got, tt.want)
}
})
}
}
func Test_getCondition(t *testing.T) { func Test_getCondition(t *testing.T) {
type args struct { type args struct {
filter *repository.Filter filter *repository.Filter
@ -172,7 +58,8 @@ func Test_getCondition(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := getCondition(tt.args.filter); got != tt.want { db := &CRDB{}
if got := getCondition(db, tt.args.filter); got != tt.want {
t.Errorf("getCondition() = %v, want %v", got, tt.want) t.Errorf("getCondition() = %v, want %v", got, tt.want)
} }
}) })
@ -180,6 +67,9 @@ func Test_getCondition(t *testing.T) {
} }
func Test_prepareColumns(t *testing.T) { func Test_prepareColumns(t *testing.T) {
type fields struct {
dbRow []interface{}
}
type args struct { type args struct {
columns repository.Columns columns repository.Columns
dest interface{} dest interface{}
@ -187,7 +77,6 @@ func Test_prepareColumns(t *testing.T) {
} }
type res struct { type res struct {
query string query string
dbRow []interface{}
expected interface{} expected interface{}
dbErr func(error) bool dbErr func(error) bool
} }
@ -195,6 +84,7 @@ func Test_prepareColumns(t *testing.T) {
name string name string
args args args args
res res res res
fields fields
}{ }{
{ {
name: "invalid columns", name: "invalid columns",
@ -212,9 +102,11 @@ func Test_prepareColumns(t *testing.T) {
}, },
res: res{ res: res{
query: "SELECT MAX(event_sequence) FROM eventstore.events", query: "SELECT MAX(event_sequence) FROM eventstore.events",
dbRow: []interface{}{Sequence(5)},
expected: Sequence(5), expected: Sequence(5),
}, },
fields: fields{
dbRow: []interface{}{Sequence(5)},
},
}, },
{ {
name: "max sequence wrong dest type", name: "max sequence wrong dest type",
@ -228,22 +120,26 @@ func Test_prepareColumns(t *testing.T) {
}, },
}, },
{ {
name: "event", name: "events",
args: args{ args: args{
columns: repository.Columns_Event, columns: repository.Columns_Event,
dest: new(repository.Event), dest: &[]*repository.Event{},
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
expected: []*repository.Event{
{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
},
},
fields: fields{
dbRow: []interface{}{time.Time{}, repository.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", repository.AggregateType("user"), "hodor", repository.Version("")}, dbRow: []interface{}{time.Time{}, repository.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", repository.AggregateType("user"), "hodor", repository.Version("")},
expected: repository.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
}, },
}, },
{ {
name: "event wrong dest type", name: "events wrong dest type",
args: args{ args: args{
columns: repository.Columns_Event, columns: repository.Columns_Event,
dest: new(uint64), dest: []*repository.Event{},
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events", query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
@ -254,7 +150,7 @@ func Test_prepareColumns(t *testing.T) {
name: "event query error", name: "event query error",
args: args{ args: args{
columns: repository.Columns_Event, columns: repository.Columns_Event,
dest: new(repository.Event), dest: &[]*repository.Event{},
dbErr: sql.ErrConnDone, dbErr: sql.ErrConnDone,
}, },
res: res{ res: res{
@ -265,9 +161,10 @@ func Test_prepareColumns(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
query, rowScanner := prepareColumns(tt.args.columns) crdb := &CRDB{}
query, rowScanner := prepareColumns(crdb, tt.args.columns)
if query != tt.res.query { if query != tt.res.query {
t.Errorf("prepareColumns() got = %v, want %v", query, tt.res.query) t.Errorf("prepareColumns() got = %s, want %s", query, tt.res.query)
} }
if tt.res.query == "" && rowScanner != nil { if tt.res.query == "" && rowScanner != nil {
t.Errorf("row scanner should be nil") t.Errorf("row scanner should be nil")
@ -275,15 +172,16 @@ func Test_prepareColumns(t *testing.T) {
if rowScanner == nil { if rowScanner == nil {
return return
} }
err := rowScanner(prepareTestScan(tt.args.dbErr, tt.res.dbRow), tt.args.dest) err := rowScanner(prepareTestScan(tt.args.dbErr, tt.fields.dbRow), tt.args.dest)
if tt.res.dbErr != nil { if err != nil && tt.res.dbErr == nil || err != nil && !tt.res.dbErr(err) || err == nil && tt.res.dbErr != nil {
if !tt.res.dbErr(err) {
t.Errorf("wrong error type in rowScanner got: %v", err) t.Errorf("wrong error type in rowScanner got: %v", err)
return
}
if tt.res.dbErr != nil && tt.res.dbErr(err) {
return
} }
} else {
if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) { if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) {
t.Errorf("unexpected result from rowScanner want: %v got: %v", tt.res.dbRow, tt.args.dest) t.Errorf("unexpected result from rowScanner \nwant: %+v \ngot: %+v", tt.fields.dbRow, reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface())
}
} }
}) })
} }
@ -379,7 +277,8 @@ func Test_prepareCondition(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotClause, gotValues := prepareCondition(tt.args.filters) crdb := &CRDB{}
gotClause, gotValues := prepareCondition(crdb, tt.args.filters)
if gotClause != tt.res.clause { if gotClause != tt.res.clause {
t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause) t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause)
} }
@ -398,11 +297,10 @@ func Test_prepareCondition(t *testing.T) {
func Test_buildQuery(t *testing.T) { func Test_buildQuery(t *testing.T) {
type args struct { type args struct {
queryFactory *repository.SearchQuery query *repository.SearchQuery
} }
type res struct { type res struct {
query string query string
limit uint64
values []interface{} values []interface{}
rowScanner bool rowScanner bool
} }
@ -411,23 +309,34 @@ func Test_buildQuery(t *testing.T) {
args args args args
res res res res
}{ }{
{
name: "invalid query factory", // {
args: args{ //removed because it's no valid test case
queryFactory: nil, // name: "no query",
}, // args: args{
res: res{ // query: nil,
query: "", // },
limit: 0, // res: res{
rowScanner: false, // query: "",
values: nil, // rowScanner: false,
}, // values: nil,
}, // },
// },
{ {
name: "with order by desc", name: "with order by desc",
args: args{ args: args{
// NewSearchQueryFactory("user").OrderDesc() // NewSearchQueryFactory("user").OrderDesc()
queryFactory: &repository.SearchQuery{Desc: true}, query: &repository.SearchQuery{
Columns: repository.Columns_Event,
Desc: true,
Filters: []*repository.Filter{
{
Field: repository.Field_AggregateType,
Value: repository.AggregateType("user"),
Operation: repository.Operation_Equals,
},
},
},
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC", query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
@ -438,45 +347,61 @@ func Test_buildQuery(t *testing.T) {
{ {
name: "with limit", name: "with limit",
args: args{ args: args{
queryFactory: repository.NewSearchQueryFactory("user").Limit(5), query: &repository.SearchQuery{
Columns: repository.Columns_Event,
Desc: false,
Limit: 5,
Filters: []*repository.Filter{
{
Field: repository.Field_AggregateType,
Value: repository.AggregateType("user"),
Operation: repository.Operation_Equals,
},
},
},
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2", query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
rowScanner: true, rowScanner: true,
values: []interface{}{repository.AggregateType("user"), uint64(5)}, values: []interface{}{repository.AggregateType("user"), uint64(5)},
limit: 5,
}, },
}, },
{ {
name: "with limit and order by desc", name: "with limit and order by desc",
args: args{ args: args{
queryFactory: repository.NewSearchQueryFactory("user").Limit(5).OrderDesc(), query: &repository.SearchQuery{
Columns: repository.Columns_Event,
Desc: true,
Limit: 5,
Filters: []*repository.Filter{
{
Field: repository.Field_AggregateType,
Value: repository.AggregateType("user"),
Operation: repository.Operation_Equals,
},
},
},
}, },
res: res{ res: res{
query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2", query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
rowScanner: true, rowScanner: true,
values: []interface{}{repository.AggregateType("user"), uint64(5)}, values: []interface{}{repository.AggregateType("user"), uint64(5)},
limit: 5,
}, },
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(tt.args.queryFactory) crdb := &CRDB{}
gotQuery, gotValues, gotRowScanner := buildQuery(crdb, tt.args.query)
if gotQuery != tt.res.query { if gotQuery != tt.res.query {
t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query) t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query)
} }
if gotLimit != tt.res.limit {
t.Errorf("buildQuery() gotLimit = %v, want %v", gotLimit, tt.res.limit)
}
if len(gotValues) != len(tt.res.values) { if len(gotValues) != len(tt.res.values) {
t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values)) t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
return return
} }
for i, value := range gotValues { if !reflect.DeepEqual(gotValues, tt.res.values) {
if !reflect.DeepEqual(value, tt.res.values[i]) { t.Errorf("prepareCondition() gotValues = %T: %v, want %T: %v", gotValues, gotValues, tt.res.values, tt.res.values)
t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
}
} }
if (tt.res.rowScanner && gotRowScanner == nil) || (!tt.res.rowScanner && gotRowScanner != nil) { if (tt.res.rowScanner && gotRowScanner == nil) || (!tt.res.rowScanner && gotRowScanner != nil) {
t.Errorf("rowScanner should be nil==%v got nil==%v", tt.res.rowScanner, gotRowScanner == nil) t.Errorf("rowScanner should be nil==%v got nil==%v", tt.res.rowScanner, gotRowScanner == nil)
@ -484,5 +409,3 @@ func Test_buildQuery(t *testing.T) {
}) })
} }
} }
// func buildQuery(t *testing.T, factory *reposear)