Testing with a local CockroachDB instance started for tests and migrations

This commit is contained in:
adlerhurst
2020-10-02 16:21:51 +02:00
parent 169b1787df
commit eb51a429ff
33 changed files with 1277 additions and 1185 deletions

View File

@@ -87,20 +87,20 @@ func (es *Eventstore) aggregatesToEvents(aggregates []aggregater) ([]*repository
for _, aggregate := range aggregates {
var previousEvent *repository.Event
for _, event := range aggregate.Events() {
//TODO: map event.Data() into json
var data []byte
events = append(events, &repository.Event{
AggregateID: aggregate.ID(),
AggregateType: repository.AggregateType(aggregate.Type()),
ResourceOwner: aggregate.ResourceOwner(),
EditorService: event.EditorService(),
EditorUser: event.EditorUser(),
Type: repository.EventType(event.Type()),
Version: repository.Version(aggregate.Version()),
PreviousEvent: previousEvent,
Data: event.Data(),
AggregateID: aggregate.ID(),
AggregateType: repository.AggregateType(aggregate.Type()),
ResourceOwner: aggregate.ResourceOwner(),
EditorService: event.EditorService(),
EditorUser: event.EditorUser(),
Type: repository.EventType(event.Type()),
Version: repository.Version(aggregate.Version()),
PreviousEvent: previousEvent,
Data: data,
PreviousSequence: event.PreviousSequence(),
})
if previousEvent != nil && event.CheckPrevious() {
events[len(events)-1].PreviousSequence = event.PreviousSequence()
}
previousEvent = events[len(events)-1]
}
}

View File

@@ -1,17 +0,0 @@
package repository
import (
"testing"
"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/server"
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
)
// TestBlub spins up an in-process CockroachDB test server.
// NOTE(review): db, kvDB and ts are declared but never used, so this file
// cannot compile as-is — presumably scratch/example code kept for reference.
func TestBlub(t *testing.T) {
	s, db, kvDB := serverutils.StartServer(t, base.TestServerArgs{})
	defer s.Stopper().Stop()
	// If really needed, in tests that can depend on server, downcast to
	// server.TestServer:
	ts := s.(*server.TestServer)
}

View File

@@ -0,0 +1,227 @@
package repository
import (
"context"
"database/sql"
"errors"
"github.com/caos/logging"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/cockroachdb/cockroach-go/v2/crdb"
//sql import for cockroach
_ "github.com/lib/pq"
)
const (
	// crdbInsert inserts a single event and atomically computes its
	// previous_sequence inside the database:
	//   - the CTE ("input_event") carries the bind parameters plus the
	//     current MAX(event_sequence) and row count of the aggregate's stream
	//   - the INSERT only fires when either check_previous is false, or the
	//     caller-supplied previous_sequence matches the stream's current max
	//     (or the stream is empty and previous_sequence is NULL) — this is
	//     the optimistic-concurrency guard
	// It returns the generated event_sequence and creation_date.
	crdbInsert = "WITH input_event ( " +
		" event_type, " +
		" aggregate_type, " +
		" aggregate_id, " +
		" aggregate_version, " +
		" creation_date, " +
		" event_data, " +
		" editor_user, " +
		" editor_service, " +
		" resource_owner, " +
		" previous_sequence, " +
		" check_previous, " +
		// variables below are calculated
		" max_event_seq, " +
		" event_count " +
		") " +
		" AS( " +
		" SELECT " +
		" $1::VARCHAR," +
		" $2::VARCHAR," +
		" $3::VARCHAR," +
		" $4::VARCHAR," +
		// creation_date defaults to NOW() when the caller passes NULL
		" COALESCE($5::TIMESTAMPTZ, NOW()), " +
		" $6::JSONB, " +
		" $7::VARCHAR, " +
		" $8::VARCHAR, " +
		" $9::VARCHAR, " +
		" $10::BIGINT, " +
		" $11::BOOLEAN," +
		" MAX(event_sequence) AS max_event_seq, " +
		" COUNT(*) AS event_count " +
		" FROM eventstore.events " +
		" WHERE " +
		" aggregate_type = $2::VARCHAR " +
		" AND aggregate_id = $3::VARCHAR " +
		") " +
		"INSERT INTO eventstore.events " +
		" ( " +
		" event_type, " +
		" aggregate_type," +
		" aggregate_id, " +
		" aggregate_version, " +
		" creation_date, " +
		" event_data, " +
		" editor_user, " +
		" editor_service, " +
		" resource_owner, " +
		" previous_sequence " +
		" ) " +
		" ( " +
		" SELECT " +
		" event_type, " +
		" aggregate_type," +
		" aggregate_id, " +
		" aggregate_version, " +
		" COALESCE(creation_date, NOW()), " +
		" event_data, " +
		" editor_user, " +
		" editor_service, " +
		" resource_owner, " +
		// stored previous_sequence: the stream max when no check is requested,
		// otherwise the caller-supplied value (already validated below)
		" ( " +
		" SELECT " +
		" CASE " +
		" WHEN NOT check_previous THEN " +
		" max_event_seq " +
		" ELSE " +
		" previous_sequence " +
		" END" +
		" ) " +
		" FROM input_event " +
		// concurrency guard: insert nothing when check_previous is set and
		// the supplied previous_sequence is stale
		" WHERE EXISTS ( " +
		" SELECT " +
		" CASE " +
		" WHEN NOT check_previous THEN 1 " +
		" ELSE ( " +
		" SELECT 1 FROM input_event " +
		" WHERE max_event_seq = previous_sequence OR (previous_sequence IS NULL AND event_count = 0) " +
		" ) " +
		" END " +
		" ) " +
		" ) " +
		"RETURNING event_sequence, creation_date "
)
// CRDB is an eventstore repository backed by CockroachDB.
type CRDB struct {
	// db is the underlying connection pool.
	db *sql.DB
}
// Health reports whether the underlying database connection is reachable.
// It uses PingContext so the caller's ctx controls cancellation and timeout
// (the original called Ping and silently ignored ctx).
func (db *CRDB) Health(ctx context.Context) error { return db.db.PingContext(ctx) }
// Push adds all events to the eventstreams of the aggregates.
// This call is transactional: crdb.ExecuteTx rolls the transaction back
// (and retries on serialization failures) whenever the closure returns an
// error, so no explicit tx.Rollback is needed — the original's Rollback
// calls were redundant.
func (db *CRDB) Push(ctx context.Context, events ...*Event) error {
	err := crdb.ExecuteTx(ctx, db.db, nil, func(tx *sql.Tx) error {
		stmt, err := tx.PrepareContext(ctx, crdbInsert)
		if err != nil {
			logging.Log("SQL-3to5p").WithError(err).Warn("prepare failed")
			return caos_errs.ThrowInternal(err, "SQL-OdXRE", "prepare failed")
		}
		// release the prepared statement when the transaction closure ends
		defer stmt.Close()

		for _, event := range events {
			previousSequence := event.PreviousSequence
			if event.PreviousEvent != nil {
				// BUG FIX: take the sequence of the linked previous event.
				// The original re-assigned event.PreviousSequence here, which
				// was a no-op and ignored the linked list entirely.
				previousSequence = event.PreviousEvent.Sequence
			}
			err = stmt.QueryRowContext(ctx,
				event.Type,
				event.AggregateType,
				event.AggregateID,
				event.Version,
				event.CreationDate,
				event.Data,
				event.EditorUser,
				event.EditorService,
				event.ResourceOwner,
				previousSequence,
				event.CheckPreviousSequence,
			).Scan(&event.Sequence, &event.CreationDate)
			if err != nil {
				logging.LogWithFields("SQL-IP3js",
					"aggregate", event.AggregateType,
					"aggregateId", event.AggregateID,
					"aggregateType", event.AggregateType,
					"eventType", event.Type).WithError(err).Info("query failed")
				return caos_errs.ThrowInternal(err, "SQL-SBP37", "unable to create event")
			}
		}
		return nil
	})
	// wrap unexpected (non-CaosError) failures once, at this layer
	if err != nil && !errors.Is(err, &caos_errs.CaosError{}) {
		err = caos_errs.ThrowInternal(err, "SQL-DjgtG", "unable to store events")
	}
	return err
}
// Filter returns all events matching the given search query
// func (db *CRDB) Filter(ctx context.Context, searchQuery *SearchQuery) (events []*Event, err error) {
// 	return events, nil
// }

// LatestSequence returns the latest sequence found by the search query.
// NOTE(review): not implemented yet — always reports sequence 0.
func (db *CRDB) LatestSequence(ctx context.Context, queryFactory *SearchQuery) (uint64, error) {
	return 0, nil
}
// prepareQuery maps the requested column set to its SELECT statement and a
// row-scanner that copies one database row into the caller-provided dest.
// It returns ("", nil) for unknown column sets.
// NOTE(review): this largely duplicates prepareColumns (which scans into
// models.Event instead of Event) — consider unifying the two; TODO confirm.
func (db *CRDB) prepareQuery(columns Columns) (string, func(s scanner, dest interface{}) error) {
	switch columns {
	case Columns_Max_Sequence:
		// caller only wants the highest sequence of the matching events
		return "SELECT MAX(event_sequence) FROM eventstore.events", func(scan scanner, dest interface{}) (err error) {
			sequence, ok := dest.(*Sequence)
			if !ok {
				return caos_errs.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
			}
			err = scan(sequence)
			// sql.ErrNoRows means an empty store: report sequence 0, not an error
			if err == nil || errors.Is(err, sql.ErrNoRows) {
				return nil
			}
			return caos_errs.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
		}
	case Columns_Event:
		return selectStmt, func(row scanner, dest interface{}) (err error) {
			event, ok := dest.(*Event)
			if !ok {
				return caos_errs.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
			}
			// previous_sequence and event_data may be SQL NULL, so scan them
			// through the nullable Sequence / Data wrapper types first
			var previousSequence Sequence
			data := make(Data, 0)
			err = row(
				&event.CreationDate,
				&event.Type,
				&event.Sequence,
				&previousSequence,
				&data,
				&event.EditorService,
				&event.EditorUser,
				&event.ResourceOwner,
				&event.AggregateType,
				&event.AggregateID,
				&event.Version,
			)
			if err != nil {
				logging.Log("SQL-kn1Sw").WithError(err).Warn("unable to scan row")
				return caos_errs.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
			}
			event.PreviousSequence = uint64(previousSequence)
			// detach the payload from the scratch scan buffer
			event.Data = make([]byte, len(data))
			copy(event.Data, data)
			return nil
		}
	default:
		return "", nil
	}
}
// prepareFilter builds the WHERE clause for the given filters.
// NOTE(review): not implemented yet — always returns the empty string.
func (db *CRDB) prepareFilter(filters []*Filter) string {
	filter := ""
	// for _, f := range filters{
	// f.
	// }
	return filter
}

View File

@@ -1,6 +1,7 @@
package repository
import (
"database/sql"
"time"
)
@@ -11,6 +12,7 @@ type Event struct {
//Sequence is the sequence of the event
Sequence uint64
//PreviousSequence is the sequence of the previous sequence
// if it's 0 then it's the first event of this aggregate
PreviousSequence uint64
@@ -19,6 +21,10 @@ type Event struct {
// it implements a linked list
PreviousEvent *Event
//CheckPreviousSequence decides if the event can only be written
// if event.PreviousSequence == max(event_sequence) of this aggregate
CheckPreviousSequence bool
//CreationDate is the time the event is created
// it's used for human readability.
// Don't use it for event ordering,
@@ -31,7 +37,7 @@ type Event struct {
//Data describe the changed fields (e.g. userName = "hodor")
// data must always a pointer to a struct, a struct or a byte array containing json bytes
Data interface{}
Data []byte
//EditorService should be a unique identifier for the service which created the event
// it's meant for maintainability
@@ -54,6 +60,8 @@ type Event struct {
// an aggregate can only be managed by one organisation
// use the ID of the org
ResourceOwner string
stmt *sql.Stmt
}
//EventType is the description of the change

View File

@@ -0,0 +1,62 @@
package repository
import (
"context"
"database/sql"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/models"
)
// Querier abstracts the Query method shared by *sql.DB and *sql.Tx so
// filter can run against either (and be stubbed in tests).
type Querier interface {
	Query(query string, args ...interface{}) (*sql.Rows, error)
}
// Filter returns all events matching the given search query.
// NOTE(review): ctx is accepted but currently unused — the query does not
// honor cancellation; verify whether QueryContext should be used downstream.
func (db *CRDB) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFactory) (events []*Event, err error) {
	return filter(db.db, searchQuery)
}
// filter renders the search query, runs it against the given querier and
// scans every matching row into an Event.
func filter(querier Querier, searchQuery *es_models.SearchQueryFactory) (events []*Event, err error) {
	query, limit, values, rowScanner := buildQuery(searchQuery)
	if query == "" {
		return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}
	rows, err := querier.Query(query, values...)
	if err != nil {
		logging.Log("SQL-HP3Uk").WithError(err).Info("query failed")
		return nil, errors.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
	}
	defer rows.Close()

	events = make([]*Event, 0, limit)
	for rows.Next() {
		event := new(Event)
		if err := rowScanner(rows.Scan, event); err != nil {
			return nil, err
		}
		events = append(events, event)
	}
	// rows.Next returns false on both exhaustion and failure; the original
	// never checked rows.Err and could silently return a truncated result.
	if err := rows.Err(); err != nil {
		logging.Log("SQL-XL9wu").WithError(err).Info("rows iteration failed")
		return nil, errors.ThrowInternal(err, "SQL-4KsDc", "unable to filter events")
	}
	return events, nil
}
// func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) {
// query, _, values, rowScanner := buildQuery(queryFactory)
// if query == "" {
// return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
// }
// row := db.client.QueryRow(query, values...)
// sequence := new(Sequence)
// err := rowScanner(row.Scan, sequence)
// if err != nil {
// logging.Log("SQL-WsxTg").WithError(err).Info("query failed")
// return 0, errors.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence")
// }
// return uint64(*sequence), nil
// }

View File

@@ -1,82 +0,0 @@
package repository
import "context"
// InMemory is a non-persistent eventstore repository, useful for tests.
type InMemory struct {
	// events holds all pushed events in insertion order.
	events []*Event
}
// Health always reports the in-memory repository as available.
func (repo *InMemory) Health(ctx context.Context) error {
	return nil
}
// Push appends every given event to the in-memory event stream.
// The call cannot fail; the error return only satisfies the repository
// interface.
func (repo *InMemory) Push(ctx context.Context, events ...*Event) error {
	for _, event := range events {
		repo.events = append(repo.events, event)
	}
	return nil
}
// Filter returns all stored events matching the given search query.
func (repo *InMemory) Filter(ctx context.Context, searchQuery *SearchQuery) (events []*Event, err error) {
	matches := repo.filter(searchQuery)
	events = make([]*Event, 0, len(matches))
	for _, idx := range matches {
		events = append(events, repo.events[idx])
	}
	return events, nil
}
// filter returns the indexes of all stored events matching every filter of
// the query, honoring query.Limit.
func (repo *InMemory) filter(query *SearchQuery) []int {
	foundIndex := make([]int, 0, query.Limit)
events:
	for i, event := range repo.events {
		// BUG FIX: stop once the limit is REACHED. The original tested
		// `len(foundIndex) < query.Limit` and therefore returned an empty
		// result on the very first iteration whenever a limit was set.
		if query.Limit > 0 && uint64(len(foundIndex)) >= query.Limit {
			return foundIndex
		}
		for _, filter := range query.Filters {
			var value interface{}
			switch filter.field {
			case Field_AggregateID:
				value = event.AggregateID
			case Field_EditorService:
				value = event.EditorService
			case Field_EventType:
				value = event.Type
			case Field_AggregateType:
				value = event.AggregateType
			case Field_EditorUser:
				value = event.EditorUser
			case Field_ResourceOwner:
				value = event.ResourceOwner
			case Field_LatestSequence:
				value = event.Sequence
			}
			switch filter.operation {
			case Operation_Equals:
				if filter.value == value {
					foundIndex = append(foundIndex, i)
				}
			case Operation_Greater:
				fallthrough
			case Operation_Less:
				// range comparisons are not supported by the in-memory repo
				return nil
			case Operation_In:
				values := filter.Value().([]interface{})
				for _, val := range values {
					if val == value {
						foundIndex = append(foundIndex, i)
						continue events
					}
				}
			}
		}
	}
	return foundIndex
}
// LatestSequence returns the latest sequence found by the search query.
// NOTE(review): not implemented — always returns 0.
func (repo *InMemory) LatestSequence(ctx context.Context, queryFactory *SearchQuery) (uint64, error) {
	return 0, nil
}

View File

@@ -0,0 +1,163 @@
package repository
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"testing"
"time"
"github.com/caos/logging"
"github.com/cockroachdb/cockroach-go/v2/testserver"
)
var (
	// migrationsPath locates the CockroachDB migration scripts inside a
	// GOPATH-based zitadel checkout; requires GOPATH to be set.
	migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach")
	// db is the shared test database handle, initialized in TestMain.
	db *sql.DB
)
// TestMain starts an embedded CockroachDB test server, connects to it,
// applies all migrations and then runs the package tests.
//
// BUG FIX: os.Exit terminates the process without running deferred calls,
// so the original's `defer db.Close(); ts.Stop()` never executed. Cleanup
// is now done explicitly before exiting.
func TestMain(m *testing.M) {
	ts, err := testserver.NewTestServer()
	if err != nil {
		logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db")
	}
	db, err = sql.Open("postgres", ts.PGURL().String())
	if err != nil {
		ts.Stop()
		logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db")
	}
	if err = executeMigrations(); err != nil {
		db.Close()
		ts.Stop()
		logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed")
	}
	code := m.Run()
	db.Close()
	ts.Stop()
	os.Exit(code)
}
// TestInsert exercises the crdbInsert statement three times inside one
// transaction and dumps every resulting row.
// NOTE(review): every Scan/Query error is ignored and t.Fail() is called
// unconditionally at the end — this reads like scratch/debug code used to
// print the table contents, not an assertion-based test.
func TestInsert(t *testing.T) {
	tx, _ := db.Begin()
	var seq Sequence
	var d time.Time
	// first event of the stream: no previous sequence, no check
	row := tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
	err := row.Scan(&seq, &d)
	// second event: claims previous sequence 1 with the concurrency check on
	row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(1), true)
	err = row.Scan(&seq, &d)
	row = tx.QueryRow(crdbInsert, "event.type", "aggregate.type", "aggregate.id", Version("v1"), nil, Data(nil), "editor.user", "editor.service", "resource.owner", Sequence(0), false)
	err = row.Scan(&seq, &d)
	tx.Commit()
	rows, err := db.Query("select * from eventstore.events order by event_sequence")
	defer rows.Close()
	fmt.Println(err)
	fmt.Println(rows.Columns())
	for rows.Next() {
		i := make([]interface{}, 12)
		var id string
		rows.Scan(&id, &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7], &i[8], &i[9], &i[10], &i[11])
		i[0] = id
		fmt.Println(i)
	}
	// deliberately fail so `go test -v` prints the dump above
	t.Fail()
}
// executeMigrations applies every migration script in version order.
// Scripts that already contain "BEGIN;" manage their own transaction and
// run directly; all others are wrapped in a fresh transaction.
func executeMigrations() error {
	files, err := migrationFilePaths()
	if err != nil {
		return err
	}
	sort.Sort(files)
	for _, file := range files {
		migration, err := ioutil.ReadFile(string(file))
		if err != nil {
			return err
		}
		stmtText := string(migration)
		if strings.Contains(stmtText, "BEGIN;") {
			// script manages its own transaction
			if _, err = db.Exec(stmtText); err != nil {
				return fmt.Errorf("exec file: %v || err: %w", file, err)
			}
			continue
		}
		tx, err := db.Begin()
		if err != nil {
			return fmt.Errorf("begin file: %v || err: %w", file, err)
		}
		if _, err = tx.Exec(stmtText); err != nil {
			return fmt.Errorf("exec file: %v || err: %w", file, err)
		}
		if err = tx.Commit(); err != nil {
			return fmt.Errorf("commit file: %v || err: %w", file, err)
		}
	}
	return nil
}
// migrationPaths is a sortable list of migration file paths; ordering is
// by the (major, minor) version encoded in each file name.
type migrationPaths []string

// version is the parsed "V<major>.<minor>" prefix of a migration file.
type version struct {
	major int
	minor int
}

// versionFromPath extracts the numeric version from a migration file path
// of the form ".../V<major>.<minor>__<description>.sql".
// It panics on malformed version numbers (test-only helper).
func versionFromPath(s string) version {
	raw := s[strings.Index(s, "/V")+2 : strings.Index(s, "__")]
	parts := strings.Split(raw, ".")
	parsed := version{}
	if len(parts) >= 1 {
		parsed.major = mustAtoi(parts[0])
	}
	if len(parts) >= 2 {
		parsed.minor = mustAtoi(parts[1])
	}
	return parsed
}

// mustAtoi converts s to an int and panics on failure.
func mustAtoi(s string) int {
	n, err := strconv.Atoi(s)
	if err != nil {
		panic(err)
	}
	return n
}

func (a migrationPaths) Len() int      { return len(a) }
func (a migrationPaths) Swap(i, j int) { a[i], a[j] = a[j], a[i] }

// Less orders migration files by major version, then minor version.
func (a migrationPaths) Less(i, j int) bool {
	vi, vj := versionFromPath(a[i]), versionFromPath(a[j])
	if vi.major != vj.major {
		return vi.major < vj.major
	}
	return vi.minor < vj.minor
}
// migrationFilePaths walks migrationsPath and collects every *.sql file.
// Directories and non-SQL files are skipped; walk errors abort the scan.
func migrationFilePaths() (migrationPaths, error) {
	files := make(migrationPaths, 0)
	err := filepath.Walk(migrationsPath, func(path string, info os.FileInfo, err error) error {
		if err != nil || info.IsDir() || !strings.HasSuffix(info.Name(), ".sql") {
			return err
		}
		files = append(files, path)
		return nil
	})
	return files, err
}

View File

@@ -0,0 +1,199 @@
package repository
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"github.com/caos/logging"
z_errors "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/models"
es_models "github.com/caos/zitadel/internal/eventstore/models"
"github.com/lib/pq"
)
const (
	// selectStmt lists the event columns in the exact order the row
	// scanners (prepareColumns / prepareQuery) expect to Scan them.
	selectStmt = "SELECT" +
		" creation_date" +
		", event_type" +
		", event_sequence" +
		", previous_sequence" +
		", event_data" +
		", editor_service" +
		", editor_user" +
		", resource_owner" +
		", aggregate_type" +
		", aggregate_id" +
		", aggregate_version" +
		" FROM eventstore.events"
)
// buildQuery renders the search query factory into a postgres statement
// plus the requested limit, the bind values and a row scanner.
// Every return value is zero when the factory is invalid.
func buildQuery(queryFactory *models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scanner, dest interface{}) error) {
	searchQuery, err := queryFactory.Build()
	if err != nil {
		logging.Log("SQL-cshKu").WithError(err).Warn("search query factory invalid")
		return "", 0, nil, nil
	}
	query, rowScanner = prepareColumns(searchQuery.Columns)
	var where string
	where, values = prepareCondition(searchQuery.Filters)
	if where == "" || query == "" {
		return "", 0, nil, nil
	}
	query += where

	// MAX(event_sequence) yields a single row; ordering would be pointless.
	if searchQuery.Columns != models.Columns_Max_Sequence {
		query += " ORDER BY event_sequence"
		if searchQuery.Desc {
			query += " DESC"
		}
	}
	if searchQuery.Limit > 0 {
		// the limit becomes the final bind parameter
		values = append(values, searchQuery.Limit)
		query += " LIMIT ?"
	}
	return numberPlaceholder(query, "?", "$"), searchQuery.Limit, values, rowScanner
}
// prepareCondition renders the filters into a WHERE clause with "?"
// placeholders and the matching bind values. Slice-typed values are wrapped
// with pq.Array so the driver can bind them to ANY(...). An unsupported
// filter yields ("", nil).
func prepareCondition(filters []*models.Filter) (clause string, values []interface{}) {
	values = make([]interface{}, len(filters))
	conditions := make([]string, len(filters))
	if len(filters) == 0 {
		return clause, values
	}
	for i, filter := range filters {
		conditions[i] = getCondition(filter)
		if conditions[i] == "" {
			return "", nil
		}
		v := filter.GetValue()
		switch v.(type) {
		case []bool, []float64, []int64, []string, []models.AggregateType, []models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]models.AggregateType, *[]models.EventType:
			v = pq.Array(v)
		}
		values[i] = v
	}
	return " WHERE " + strings.Join(conditions, " AND "), values
}
// scanner abstracts sql.Row.Scan / sql.Rows.Scan so row mapping can be
// exercised without a real database.
type scanner func(dest ...interface{}) error
// prepareColumns maps the requested column set to its SELECT statement and
// a row-scanner that copies one database row into the caller-provided dest.
// It returns ("", nil) for unknown column sets.
func prepareColumns(columns models.Columns) (string, func(s scanner, dest interface{}) error) {
	switch columns {
	case models.Columns_Max_Sequence:
		// caller only wants the highest sequence of the matching events
		return "SELECT MAX(event_sequence) FROM eventstore.events", func(row scanner, dest interface{}) (err error) {
			sequence, ok := dest.(*Sequence)
			if !ok {
				return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
			}
			err = row(sequence)
			// sql.ErrNoRows means an empty store: report sequence 0, not an error
			if err == nil || errors.Is(err, sql.ErrNoRows) {
				return nil
			}
			return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
		}
	case models.Columns_Event:
		return selectStmt, func(row scanner, dest interface{}) (err error) {
			event, ok := dest.(*models.Event)
			if !ok {
				return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
			}
			// previous_sequence and event_data may be SQL NULL, so scan them
			// through the nullable Sequence / Data wrapper types first
			var previousSequence Sequence
			data := make(Data, 0)
			err = row(
				&event.CreationDate,
				&event.Type,
				&event.Sequence,
				&previousSequence,
				&data,
				&event.EditorService,
				&event.EditorUser,
				&event.ResourceOwner,
				&event.AggregateType,
				&event.AggregateID,
				&event.AggregateVersion,
			)
			if err != nil {
				logging.Log("SQL-kn1Sw").WithError(err).Warn("unable to scan row")
				return z_errors.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
			}
			event.PreviousSequence = uint64(previousSequence)
			// detach the payload from the scratch scan buffer
			event.Data = make([]byte, len(data))
			copy(event.Data, data)
			return nil
		}
	default:
		return "", nil
	}
}
// numberPlaceholder rewrites each occurrence of old (e.g. "?") into a
// numbered placeholder new+index (e.g. "$1", "$2", ...), left to right,
// as required by the postgres wire protocol.
func numberPlaceholder(query, old, new string) string {
	for i := 1; ; i++ {
		replaced := strings.Replace(query, old, new+strconv.Itoa(i), 1)
		if replaced == query {
			return query
		}
		query = replaced
	}
}
// getCondition renders one filter as an SQL condition with a "?"
// placeholder. The empty string signals an unsupported field or operation.
func getCondition(filter *es_models.Filter) string {
	col := getField(filter.GetField())
	op := getOperation(filter.GetOperation())
	if col == "" || op == "" {
		return ""
	}
	return fmt.Sprintf(getConditionFormat(filter.GetOperation()), col, op)
}
// getConditionFormat returns the fmt template for a condition: IN-filters
// bind their value through ANY(?), everything else uses a plain ?.
func getConditionFormat(operation es_models.Operation) string {
	switch operation {
	case es_models.Operation_In:
		return "%s %s ANY(?)"
	default:
		return "%s %s ?"
	}
}
// getField maps a search field to its column name in eventstore.events.
// Unknown fields map to the empty string.
func getField(field es_models.Field) string {
	columns := map[es_models.Field]string{
		es_models.Field_AggregateID:    "aggregate_id",
		es_models.Field_AggregateType:  "aggregate_type",
		es_models.Field_LatestSequence: "event_sequence",
		es_models.Field_ResourceOwner:  "resource_owner",
		es_models.Field_EditorService:  "editor_service",
		es_models.Field_EditorUser:     "editor_user",
		es_models.Field_EventType:      "event_type",
	}
	return columns[field]
}
// getOperation maps a search operation to its SQL comparison operator.
// Unknown operations map to the empty string. IN uses "=" because the
// value is bound through ANY(...).
func getOperation(operation es_models.Operation) string {
	if operation == es_models.Operation_Equals || operation == es_models.Operation_In {
		return "="
	}
	if operation == es_models.Operation_Greater {
		return ">"
	}
	if operation == es_models.Operation_Less {
		return "<"
	}
	return ""
}

View File

@@ -0,0 +1,486 @@
package repository
import (
"database/sql"
"reflect"
"testing"
"time"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/models"
es_models "github.com/caos/zitadel/internal/eventstore/models"
"github.com/lib/pq"
)
// Test_numberPlaceholder verifies that "?" placeholders are rewritten into
// numbered "$n" placeholders, and that queries without placeholders pass
// through unchanged.
func Test_numberPlaceholder(t *testing.T) {
	type args struct {
		query string
		old   string
		new   string
	}
	type res struct {
		query string
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "no replaces",
			args: args{
				new:   "$",
				old:   "?",
				query: "SELECT * FROM eventstore.events",
			},
			res: res{
				query: "SELECT * FROM eventstore.events",
			},
		},
		{
			name: "two replaces",
			args: args{
				new:   "$",
				old:   "?",
				query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND LIMIT = ?",
			},
			res: res{
				query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND LIMIT = $2",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := numberPlaceholder(tt.args.query, tt.args.old, tt.args.new); got != tt.res.query {
				t.Errorf("numberPlaceholder() = %v, want %v", got, tt.res.query)
			}
		})
	}
}
// Test_getOperation checks the SQL operator for every supported search
// operation plus an unknown one.
func Test_getOperation(t *testing.T) {
	t.Run("all ops", func(t *testing.T) {
		cases := []struct {
			op       es_models.Operation
			expected string
		}{
			{es_models.Operation_Equals, "="},
			{es_models.Operation_In, "="},
			{es_models.Operation_Greater, ">"},
			{es_models.Operation_Less, "<"},
			{es_models.Operation(-1), ""},
		}
		for _, c := range cases {
			if got := getOperation(c.op); got != c.expected {
				t.Errorf("getOperation() = %v, want %v", got, c.expected)
			}
		}
	})
}
// Test_getField checks the column name for every supported search field
// plus an unknown one.
func Test_getField(t *testing.T) {
	t.Run("all fields", func(t *testing.T) {
		cases := []struct {
			field    es_models.Field
			expected string
		}{
			{es_models.Field_AggregateType, "aggregate_type"},
			{es_models.Field_AggregateID, "aggregate_id"},
			{es_models.Field_LatestSequence, "event_sequence"},
			{es_models.Field_ResourceOwner, "resource_owner"},
			{es_models.Field_EditorService, "editor_service"},
			{es_models.Field_EditorUser, "editor_user"},
			{es_models.Field_EventType, "event_type"},
			{es_models.Field(-1), ""},
		}
		for _, c := range cases {
			if got := getField(c.field); got != c.expected {
				t.Errorf("getField() = %v, want %v", got, c.expected)
			}
		}
	})
}
// Test_getConditionFormat checks the placeholder template for IN versus
// all other operations.
func Test_getConditionFormat(t *testing.T) {
	cases := []struct {
		name      string
		operation es_models.Operation
		want      string
	}{
		{name: "no in operation", operation: es_models.Operation_Equals, want: "%s %s ?"},
		{name: "in operation", operation: es_models.Operation_In, want: "%s %s ANY(?)"},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := getConditionFormat(c.operation); got != c.want {
				t.Errorf("prepareConditionFormat() = %v, want %v", got, c.want)
			}
		})
	}
}
// Test_getCondition renders single filters into SQL conditions, covering
// each operator plus invalid fields/operations (which must yield "").
func Test_getCondition(t *testing.T) {
	type args struct {
		filter *es_models.Filter
	}
	tests := []struct {
		name string
		args args
		want string
	}{
		{
			name: "equals",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateID, "", es_models.Operation_Equals)},
			want: "aggregate_id = ?",
		},
		{
			name: "greater",
			args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 0, es_models.Operation_Greater)},
			want: "event_sequence > ?",
		},
		{
			name: "less",
			args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 5000, es_models.Operation_Less)},
			want: "event_sequence < ?",
		},
		{
			name: "in list",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation_In)},
			want: "aggregate_type = ANY(?)",
		},
		{
			name: "invalid operation",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
			want: "",
		},
		{
			name: "invalid field",
			args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation_Equals)},
			want: "",
		},
		{
			name: "invalid field and operation",
			args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
			want: "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := getCondition(tt.args.filter); got != tt.want {
				t.Errorf("getCondition() = %v, want %v", got, tt.want)
			}
		})
	}
}
// Test_prepareColumns checks the SELECT statement and the row scanner for
// each column set, including wrong dest types and simulated scan errors.
// Database rows are faked through prepareTestScan.
func Test_prepareColumns(t *testing.T) {
	type args struct {
		columns models.Columns
		dest    interface{}
		dbErr   error
	}
	type res struct {
		query    string
		dbRow    []interface{}
		expected interface{}
		dbErr    func(error) bool
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "invalid columns",
			args: args{columns: es_models.Columns(-1)},
			res: res{
				query: "",
				dbErr: func(err error) bool { return err == nil },
			},
		},
		{
			name: "max column",
			args: args{
				columns: es_models.Columns_Max_Sequence,
				dest:    new(Sequence),
			},
			res: res{
				query:    "SELECT MAX(event_sequence) FROM eventstore.events",
				dbRow:    []interface{}{Sequence(5)},
				expected: Sequence(5),
			},
		},
		{
			name: "max sequence wrong dest type",
			args: args{
				columns: es_models.Columns_Max_Sequence,
				dest:    new(uint64),
			},
			res: res{
				query: "SELECT MAX(event_sequence) FROM eventstore.events",
				dbErr: errors.IsErrorInvalidArgument,
			},
		},
		{
			name: "event",
			args: args{
				columns: es_models.Columns_Event,
				dest:    new(models.Event),
			},
			res: res{
				query:    "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
				dbRow:    []interface{}{time.Time{}, models.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", models.AggregateType("user"), "hodor", models.Version("")},
				expected: models.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
			},
		},
		{
			name: "event wrong dest type",
			args: args{
				columns: es_models.Columns_Event,
				dest:    new(uint64),
			},
			res: res{
				query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
				dbErr: errors.IsErrorInvalidArgument,
			},
		},
		{
			name: "event query error",
			args: args{
				columns: es_models.Columns_Event,
				dest:    new(models.Event),
				dbErr:   sql.ErrConnDone,
			},
			res: res{
				query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
				dbErr: errors.IsInternal,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			query, rowScanner := prepareColumns(tt.args.columns)
			if query != tt.res.query {
				t.Errorf("prepareColumns() got = %v, want %v", query, tt.res.query)
			}
			if tt.res.query == "" && rowScanner != nil {
				t.Errorf("row scanner should be nil")
			}
			if rowScanner == nil {
				return
			}
			err := rowScanner(prepareTestScan(tt.args.dbErr, tt.res.dbRow), tt.args.dest)
			if tt.res.dbErr != nil {
				if !tt.res.dbErr(err) {
					t.Errorf("wrong error type in rowScanner got: %v", err)
				}
			} else {
				if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) {
					t.Errorf("unexpected result from rowScanner want: %v got: %v", tt.res.dbRow, tt.args.dest)
				}
			}
		})
	}
}
// prepareTestScan returns a scanner stub that either fails with the given
// error or copies the prepared row values into the scan destinations.
func prepareTestScan(err error, res []interface{}) scanner {
	return func(dests ...interface{}) error {
		if err != nil {
			return err
		}
		if len(res) != len(dests) {
			return errors.ThrowInvalidArgumentf(nil, "SQL-NML1q", "expected len %d got %d", len(res), len(dests))
		}
		for i := range res {
			reflect.ValueOf(dests[i]).Elem().Set(reflect.ValueOf(res[i]))
		}
		return nil
	}
}
// Test_prepareCondition checks WHERE-clause rendering: empty/nil/invalid
// filter lists yield empty clauses, slice values are wrapped via pq.Array,
// and multiple filters join with AND.
func Test_prepareCondition(t *testing.T) {
	type args struct {
		filters []*models.Filter
	}
	type res struct {
		clause string
		values []interface{}
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "nil filters",
			args: args{
				filters: nil,
			},
			res: res{
				clause: "",
				values: nil,
			},
		},
		{
			name: "empty filters",
			args: args{
				filters: []*es_models.Filter{},
			},
			res: res{
				clause: "",
				values: nil,
			},
		},
		{
			name: "invalid condition",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
				},
			},
			res: res{
				clause: "",
				values: nil,
			},
		},
		{
			name: "array as condition value",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
				},
			},
			res: res{
				clause: " WHERE aggregate_type = ANY(?)",
				values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})},
			},
		},
		{
			name: "multiple filters",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
					es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
					es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
				},
			},
			res: res{
				clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?)",
				values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotClause, gotValues := prepareCondition(tt.args.filters)
			if gotClause != tt.res.clause {
				t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause)
			}
			// length comparison only: nil and an empty slice compare equal here
			if len(gotValues) != len(tt.res.values) {
				t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
				return
			}
			for i, value := range gotValues {
				if !reflect.DeepEqual(value, tt.res.values[i]) {
					t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
				}
			}
		})
	}
}
// Test_buildQuery checks the fully rendered statement (numbered
// placeholders, ORDER BY, DESC and LIMIT handling) together with the
// returned limit, bind values and row scanner.
func Test_buildQuery(t *testing.T) {
	type args struct {
		queryFactory *models.SearchQueryFactory
	}
	type res struct {
		query      string
		limit      uint64
		values     []interface{}
		rowScanner bool
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "invalid query factory",
			args: args{
				queryFactory: nil,
			},
			res: res{
				query:      "",
				limit:      0,
				rowScanner: false,
				values:     nil,
			},
		},
		{
			name: "with order by desc",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(),
			},
			res: res{
				query:      "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
				rowScanner: true,
				values:     []interface{}{es_models.AggregateType("user")},
			},
		},
		{
			name: "with limit",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").Limit(5),
			},
			res: res{
				query:      "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
				rowScanner: true,
				values:     []interface{}{es_models.AggregateType("user"), uint64(5)},
				limit:      5,
			},
		},
		{
			name: "with limit and order by desc",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
			},
			res: res{
				query:      "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
				rowScanner: true,
				values:     []interface{}{es_models.AggregateType("user"), uint64(5)},
				limit:      5,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(tt.args.queryFactory)
			if gotQuery != tt.res.query {
				t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query)
			}
			if gotLimit != tt.res.limit {
				t.Errorf("buildQuery() gotLimit = %v, want %v", gotLimit, tt.res.limit)
			}
			if len(gotValues) != len(tt.res.values) {
				t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
				return
			}
			for i, value := range gotValues {
				if !reflect.DeepEqual(value, tt.res.values[i]) {
					t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
				}
			}
			if (tt.res.rowScanner && gotRowScanner == nil) || (!tt.res.rowScanner && gotRowScanner != nil) {
				t.Errorf("rowScanner should be nil==%v got nil==%v", tt.res.rowScanner, gotRowScanner == nil)
			}
		})
	}
}

View File

@@ -0,0 +1,47 @@
package repository
import "database/sql/driver"
// Data represents a byte array that may be null.
// Data implements the sql.Scanner interface.
type Data []byte

// Scan implements the sql.Scanner interface.
//
// BUG FIX: the incoming bytes are now copied. The database/sql Scanner
// contract states that the source buffer is only valid for the duration of
// the call and may be reused by the driver; the original aliased it.
// NOTE(review): a non-[]byte, non-nil value still panics on the type
// assertion, matching the original behavior.
func (data *Data) Scan(value interface{}) error {
	if value == nil {
		*data = nil
		return nil
	}
	src := value.([]byte)
	*data = append(Data(nil), src...)
	return nil
}

// Value implements the driver.Valuer interface.
// Empty data is stored as SQL NULL.
func (data Data) Value() (driver.Value, error) {
	if len(data) == 0 {
		return nil, nil
	}
	return []byte(data), nil
}
// Sequence represents a number that may be null.
// Sequence implements the sql.Scanner interface.
type Sequence uint64

// Scan implements the sql.Scanner interface.
// SQL NULL maps to sequence 0.
func (seq *Sequence) Scan(value interface{}) error {
	if value != nil {
		*seq = Sequence(value.(int64))
		return nil
	}
	*seq = 0
	return nil
}

// Value implements the driver.Valuer interface.
// Sequence 0 is stored as SQL NULL.
func (seq Sequence) Value() (driver.Value, error) {
	if seq != 0 {
		return int64(seq), nil
	}
	return nil, nil
}