mirror of https://github.com/zitadel/zitadel.git

refactor(eventstore): sql

commit 120a8bae85 (parent eb51a429ff)
@@ -8,6 +8,7 @@ import (
	"github.com/caos/zitadel/internal/eventstore/v2/repository"
)

// testEvent implements the Event interface
type testEvent struct {
	description         string
	shouldCheckPrevious bool
@@ -17,6 +18,23 @@ func (e *testEvent) CheckPrevious() bool {
	return e.shouldCheckPrevious
}

func (e *testEvent) EditorService() string {
	return "editorService"
}

func (e *testEvent) EditorUser() string {
	return "editorUser"
}

func (e *testEvent) Type() EventType {
	return "test.event"
}

func (e *testEvent) Data() interface{} {
	return nil
}

func (e *testEvent) PreviousSequence() uint64 {
	return 0
}

func testPushMapper(Event) (*repository.Event, error) {
	return &repository.Event{AggregateID: "aggregateID"}, nil
}
@@ -1,22 +0,0 @@
package repository

//AggregateType is the object name
type AggregateType string

// //Aggregate represents an object
// type Aggregate struct {
// 	//ID id is the unique identifier of the aggregate
// 	// the client must generate it by it's own
// 	ID string
// 	//Type describes the meaning of this aggregate
// 	// it could an object like user
// 	Type AggregateType

// 	//ResourceOwner is the organisation which owns this aggregate
// 	// an aggregate can only be managed by one organisation
// 	// use the ID of the org
// 	ResourceOwner string

// 	//Events describe all the changes made on an aggregate
// 	Events []*Event
// }
@@ -66,3 +66,6 @@ type Event struct {

//EventType is the description of the change
type EventType string

//AggregateType is the object name
type AggregateType string
@@ -1,62 +0,0 @@
package repository

import (
	"context"
	"database/sql"

	"github.com/caos/logging"
	"github.com/caos/zitadel/internal/errors"
	es_models "github.com/caos/zitadel/internal/eventstore/models"
)

type Querier interface {
	Query(query string, args ...interface{}) (*sql.Rows, error)
}

// Filter returns all events matching the given search query
func (db *CRDB) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFactory) (events []*Event, err error) {
	return filter(db.db, searchQuery)
}

func filter(querier Querier, searchQuery *es_models.SearchQueryFactory) (events []*Event, err error) {
	query, limit, values, rowScanner := buildQuery(searchQuery)
	if query == "" {
		return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}

	rows, err := querier.Query(query, values...)
	if err != nil {
		logging.Log("SQL-HP3Uk").WithError(err).Info("query failed")
		return nil, errors.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
	}
	defer rows.Close()

	events = make([]*Event, 0, limit)

	for rows.Next() {
		event := new(Event)
		err := rowScanner(rows.Scan, event)
		if err != nil {
			return nil, err
		}

		events = append(events, event)
	}

	return events, nil
}

// func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) {
// 	query, _, values, rowScanner := buildQuery(queryFactory)
// 	if query == "" {
// 		return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
// 	}
// 	row := db.client.QueryRow(query, values...)
// 	sequence := new(Sequence)
// 	err := rowScanner(row.Scan, sequence)
// 	if err != nil {
// 		logging.Log("SQL-WsxTg").WithError(err).Info("query failed")
// 		return 0, errors.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence")
// 	}
// 	return uint64(*sequence), nil
// }
@@ -1,199 +0,0 @@
package repository

import (
	"database/sql"
	"errors"
	"fmt"
	"strconv"
	"strings"

	"github.com/caos/logging"
	z_errors "github.com/caos/zitadel/internal/errors"
	"github.com/caos/zitadel/internal/eventstore/models"
	es_models "github.com/caos/zitadel/internal/eventstore/models"
	"github.com/lib/pq"
)

const (
	selectStmt = "SELECT" +
		" creation_date" +
		", event_type" +
		", event_sequence" +
		", previous_sequence" +
		", event_data" +
		", editor_service" +
		", editor_user" +
		", resource_owner" +
		", aggregate_type" +
		", aggregate_id" +
		", aggregate_version" +
		" FROM eventstore.events"
)

func buildQuery(queryFactory *models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scanner, dest interface{}) error) {
	searchQuery, err := queryFactory.Build()
	if err != nil {
		logging.Log("SQL-cshKu").WithError(err).Warn("search query factory invalid")
		return "", 0, nil, nil
	}
	query, rowScanner = prepareColumns(searchQuery.Columns)
	where, values := prepareCondition(searchQuery.Filters)
	if where == "" || query == "" {
		return "", 0, nil, nil
	}
	query += where

	if searchQuery.Columns != models.Columns_Max_Sequence {
		query += " ORDER BY event_sequence"
		if searchQuery.Desc {
			query += " DESC"
		}
	}

	if searchQuery.Limit > 0 {
		values = append(values, searchQuery.Limit)
		query += " LIMIT ?"
	}

	query = numberPlaceholder(query, "?", "$")

	return query, searchQuery.Limit, values, rowScanner
}

func prepareCondition(filters []*models.Filter) (clause string, values []interface{}) {
	values = make([]interface{}, len(filters))
	clauses := make([]string, len(filters))

	if len(filters) == 0 {
		return clause, values
	}
	for i, filter := range filters {
		value := filter.GetValue()
		switch value.(type) {
		case []bool, []float64, []int64, []string, []models.AggregateType, []models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]models.AggregateType, *[]models.EventType:
			value = pq.Array(value)
		}

		clauses[i] = getCondition(filter)
		if clauses[i] == "" {
			return "", nil
		}
		values[i] = value
	}
	return " WHERE " + strings.Join(clauses, " AND "), values
}

type scanner func(dest ...interface{}) error

func prepareColumns(columns models.Columns) (string, func(s scanner, dest interface{}) error) {
	switch columns {
	case models.Columns_Max_Sequence:
		return "SELECT MAX(event_sequence) FROM eventstore.events", func(row scanner, dest interface{}) (err error) {
			sequence, ok := dest.(*Sequence)
			if !ok {
				return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
			}
			err = row(sequence)
			if err == nil || errors.Is(err, sql.ErrNoRows) {
				return nil
			}
			return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
		}
	case models.Columns_Event:
		return selectStmt, func(row scanner, dest interface{}) (err error) {
			event, ok := dest.(*models.Event)
			if !ok {
				return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
			}
			var previousSequence Sequence
			data := make(Data, 0)

			err = row(
				&event.CreationDate,
				&event.Type,
				&event.Sequence,
				&previousSequence,
				&data,
				&event.EditorService,
				&event.EditorUser,
				&event.ResourceOwner,
				&event.AggregateType,
				&event.AggregateID,
				&event.AggregateVersion,
			)

			if err != nil {
				logging.Log("SQL-kn1Sw").WithError(err).Warn("unable to scan row")
				return z_errors.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
			}

			event.PreviousSequence = uint64(previousSequence)

			event.Data = make([]byte, len(data))
			copy(event.Data, data)

			return nil
		}
	default:
		return "", nil
	}
}

func numberPlaceholder(query, old, new string) string {
	for i, hasChanged := 1, true; hasChanged; i++ {
		newQuery := strings.Replace(query, old, new+strconv.Itoa(i), 1)
		hasChanged = query != newQuery
		query = newQuery
	}
	return query
}

func getCondition(filter *es_models.Filter) (condition string) {
	field := getField(filter.GetField())
	operation := getOperation(filter.GetOperation())
	if field == "" || operation == "" {
		return ""
	}
	format := getConditionFormat(filter.GetOperation())

	return fmt.Sprintf(format, field, operation)
}

func getConditionFormat(operation es_models.Operation) string {
	if operation == es_models.Operation_In {
		return "%s %s ANY(?)"
	}
	return "%s %s ?"
}

func getField(field es_models.Field) string {
	switch field {
	case es_models.Field_AggregateID:
		return "aggregate_id"
	case es_models.Field_AggregateType:
		return "aggregate_type"
	case es_models.Field_LatestSequence:
		return "event_sequence"
	case es_models.Field_ResourceOwner:
		return "resource_owner"
	case es_models.Field_EditorService:
		return "editor_service"
	case es_models.Field_EditorUser:
		return "editor_user"
	case es_models.Field_EventType:
		return "event_type"
	}
	return ""
}

func getOperation(operation es_models.Operation) string {
	switch operation {
	case es_models.Operation_Equals, es_models.Operation_In:
		return "="
	case es_models.Operation_Greater:
		return ">"
	case es_models.Operation_Less:
		return "<"
	}
	return ""
}
|
@ -6,7 +6,6 @@ import (
|
||||
|
||||
type Repository interface {
|
||||
Health(ctx context.Context) error
|
||||
|
||||
// PushEvents adds all events of the given aggregates to the eventstreams of the aggregates.
|
||||
// This call is transaction save. The transaction will be rolled back if one event fails
|
||||
Push(ctx context.Context, events ...*Event) error
|
||||
|
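The Push contract above (and its CRDB implementation in the next hunk) leans on crdb.ExecuteTx from github.com/cockroachdb/cockroach-go/v2/crdb: every insert runs inside one transaction, which is rolled back if any callback error is returned and retried on serialization conflicts. A rough, self-contained sketch of that pattern follows; the table and INSERT statement are placeholders, since the real crdbInsert statement is not shown in this diff.

package sketch

import (
	"context"
	"database/sql"

	"github.com/cockroachdb/cockroach-go/v2/crdb"
)

// pushInTx inserts all payloads inside a single crdb.ExecuteTx call.
// Returning an error from the callback aborts and rolls back the whole
// transaction, so either every event is stored or none is.
func pushInTx(ctx context.Context, client *sql.DB, payloads [][]byte) error {
	return crdb.ExecuteTx(ctx, client, nil, func(tx *sql.Tx) error {
		for _, p := range payloads {
			// placeholder statement, not the commit's crdbInsert
			if _, err := tx.ExecContext(ctx, "INSERT INTO eventstore.events (event_data) VALUES ($1)", p); err != nil {
				return err
			}
		}
		return nil
	})
}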
@@ -1,12 +1,15 @@
package repository
package sql

import (
	"context"
	"database/sql"
	"errors"
	"regexp"
	"strconv"

	"github.com/caos/logging"
	caos_errs "github.com/caos/zitadel/internal/errors"
	"github.com/caos/zitadel/internal/eventstore/v2/repository"
	"github.com/cockroachdb/cockroach-go/v2/crdb"

	//sql import for cockroach
@@ -99,15 +102,15 @@ const (
)

type CRDB struct {
	db *sql.DB
	client *sql.DB
}

func (db *CRDB) Health(ctx context.Context) error { return db.db.Ping() }
func (db *CRDB) Health(ctx context.Context) error { return db.client.Ping() }

// Push adds all events to the eventstreams of the aggregates.
// This call is transaction save. The transaction will be rolled back if one event fails
func (db *CRDB) Push(ctx context.Context, events ...*Event) error {
	err := crdb.ExecuteTx(ctx, db.db, nil, func(tx *sql.Tx) error {
func (db *CRDB) Push(ctx context.Context, events ...*repository.Event) error {
	err := crdb.ExecuteTx(ctx, db.client, nil, func(tx *sql.Tx) error {
		stmt, err := tx.PrepareContext(ctx, crdbInsert)
		if err != nil {
			tx.Rollback()
@@ -154,74 +157,138 @@ func (db *CRDB) Push(ctx context.Context, events ...*Event) error {
}

// Filter returns all events matching the given search query
// func (db *CRDB) Filter(ctx context.Context, searchQuery *SearchQuery) (events []*Event, err error) {
func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) {
	rows, rowScanner, err := db.query(searchQuery)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

// 	return events, nil
// }
	for rows.Next() {
		event := new(repository.Event)
		err := rowScanner(rows.Scan, event)
		if err != nil {
			return nil, err
		}

//LatestSequence returns the latests sequence found by the the search query
func (db *CRDB) LatestSequence(ctx context.Context, queryFactory *SearchQuery) (uint64, error) {
	return 0, nil
		events = append(events, event)
	}

	return events, nil
}

func (db *CRDB) prepareQuery(columns Columns) (string, func(s scanner, dest interface{}) error) {
	switch columns {
	case Columns_Max_Sequence:
		return "SELECT MAX(event_sequence) FROM eventstore.events", func(scan scanner, dest interface{}) (err error) {
			sequence, ok := dest.(*Sequence)
			if !ok {
				return caos_errs.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
			}
			err = scan(sequence)
			if err == nil || errors.Is(err, sql.ErrNoRows) {
				return nil
			}
			return caos_errs.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
		}
	case Columns_Event:
		return selectStmt, func(row scanner, dest interface{}) (err error) {
			event, ok := dest.(*Event)
			if !ok {
				return caos_errs.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
			}
			var previousSequence Sequence
			data := make(Data, 0)
//LatestSequence returns the latests sequence found by the the search query
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) {
	rows, rowScanner, err := db.query(searchQuery)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

			err = row(
				&event.CreationDate,
				&event.Type,
				&event.Sequence,
				&previousSequence,
				&data,
				&event.EditorService,
				&event.EditorUser,
				&event.ResourceOwner,
				&event.AggregateType,
				&event.AggregateID,
				&event.Version,
			)
	if !rows.Next() {
		return 0, caos_errs.ThrowNotFound(nil, "SQL-cAEzS", "latest sequence not found")
	}

			if err != nil {
				logging.Log("SQL-kn1Sw").WithError(err).Warn("unable to scan row")
				return caos_errs.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
			}
	var seq Sequence
	err = rowScanner(rows.Scan, &seq)
	if err != nil {
		return 0, err
	}

			event.PreviousSequence = uint64(previousSequence)
	return uint64(seq), nil
}

			event.Data = make([]byte, len(data))
			copy(event.Data, data)
func (db *CRDB) query(searchQuery *repository.SearchQuery) (*sql.Rows, rowScan, error) {
	query, values, rowScanner := buildQuery(db, searchQuery)
	if query == "" {
		return nil, nil, caos_errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}

			return nil
		}
	rows, err := db.client.Query(query, values...)
	if err != nil {
		logging.Log("SQL-HP3Uk").WithError(err).Info("query failed")
		return nil, nil, caos_errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
	}
	return rows, rowScanner, nil
}

func (db *CRDB) eventQuery() string {
	return "SELECT" +
		" creation_date" +
		", event_type" +
		", event_sequence" +
		", previous_sequence" +
		", event_data" +
		", editor_service" +
		", editor_user" +
		", resource_owner" +
		", aggregate_type" +
		", aggregate_id" +
		", aggregate_version" +
		" FROM eventstore.events"
}
func (db *CRDB) maxSequenceQuery() string {
	return "SELECT MAX(event_sequence) FROM eventstore.events"
}

func (db *CRDB) columnName(col repository.Field) string {
	switch col {
	case repository.Field_AggregateID:
		return "aggregate_id"
	case repository.Field_AggregateType:
		return "aggregate_type"
	case repository.Field_LatestSequence:
		return "event_sequence"
	case repository.Field_ResourceOwner:
		return "resource_owner"
	case repository.Field_EditorService:
		return "editor_service"
	case repository.Field_EditorUser:
		return "editor_user"
	case repository.Field_EventType:
		return "event_type"
	default:
		return "", nil
		return ""
	}
}

func (db *CRDB) prepareFilter(filters []*Filter) string {
	filter := ""
	// for _, f := range filters{
	// 	f.
	// }
	return filter
func (db *CRDB) conditionFormat(operation repository.Operation) string {
	if operation == repository.Operation_In {
		return "%s %s ANY(?)"
	}
	return "%s %s ?"
}

func (db *CRDB) operation(operation repository.Operation) string {
	switch operation {
	case repository.Operation_Equals, repository.Operation_In:
		return "="
	case repository.Operation_Greater:
		return ">"
	case repository.Operation_Less:
		return "<"
	}
	return ""
}

var (
	placeholder = regexp.MustCompile(`\?`)
)

//placeholder replaces all "?" with postgres placeholders ($<NUMBER>)
func (db *CRDB) placeholder(query string) string {
	occurances := placeholder.FindAllStringIndex(query, -1)
	if len(occurances) == 0 {
		return query
	}
	replaced := query[:occurances[0][0]]

	for i, l := range occurances {
		nextIDX := len(query)
		if i < len(occurances)-1 {
			nextIDX = occurances[i+1][0]
		}
		replaced = replaced + "$" + strconv.Itoa(i+1) + query[l[1]:nextIDX]
	}
	return replaced
}
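The placeholder method above rewrites each "?" into the numbered $1, $2, ... placeholders that CockroachDB/Postgres expect, using a regexp over the query string. A minimal standalone sketch of the same rewriting idea (illustrative only, not code from this commit):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// rewritePlaceholders numbers each "?" in order of appearance,
// producing $1, $2, ... style placeholders.
func rewritePlaceholders(query string) string {
	var b strings.Builder
	n := 0
	for _, r := range query {
		if r == '?' {
			n++
			b.WriteString("$" + strconv.Itoa(n))
			continue
		}
		b.WriteRune(r)
	}
	return b.String()
}

func main() {
	q := "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND aggregate_id = ? LIMIT ?"
	fmt.Println(rewritePlaceholders(q))
	// SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND aggregate_id = $2 LIMIT $3
}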
@@ -1,4 +1,4 @@
package repository
package sql

import (
	"database/sql"
internal/eventstore/v2/repository/sql/query.go (new file, 142 lines)
@@ -0,0 +1,142 @@
package sql

import (
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/caos/logging"
	z_errors "github.com/caos/zitadel/internal/errors"
	"github.com/caos/zitadel/internal/eventstore/v2/repository"
	"github.com/lib/pq"
)

type criteriaer interface {
	columnName(repository.Field) string
	operation(repository.Operation) string
	conditionFormat(repository.Operation) string
	placeholder(query string) string
	eventQuery() string
	maxSequenceQuery() string
}

type rowScan func(scan, interface{}) error
type scan func(dest ...interface{}) error

func buildQuery(criteria criteriaer, searchQuery *repository.SearchQuery) (query string, values []interface{}, rowScanner rowScan) {
	query, rowScanner = prepareColumns(criteria, searchQuery.Columns)
	where, values := prepareCondition(criteria, searchQuery.Filters)
	if where == "" || query == "" {
		return "", nil, nil
	}
	query += where

	if searchQuery.Columns != repository.Columns_Max_Sequence {
		query += " ORDER BY event_sequence"
		if searchQuery.Desc {
			query += " DESC"
		}
	}

	if searchQuery.Limit > 0 {
		values = append(values, searchQuery.Limit)
		query += " LIMIT ?"
	}

	query = criteria.placeholder(query)

	return query, values, rowScanner
}

func prepareColumns(criteria criteriaer, columns repository.Columns) (string, func(s scan, dest interface{}) error) {
	switch columns {
	case repository.Columns_Max_Sequence:
		return criteria.maxSequenceQuery(), maxSequenceRowScanner
	case repository.Columns_Event:
		return criteria.eventQuery(), eventRowScanner
	default:
		return "", nil
	}
}

func maxSequenceRowScanner(row scan, dest interface{}) (err error) {
	sequence, ok := dest.(*Sequence)
	if !ok {
		return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
	}
	err = row(sequence)
	if err == nil || errors.Is(err, sql.ErrNoRows) {
		return nil
	}
	return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
}

func eventRowScanner(row scan, dest interface{}) (err error) {
	event, ok := dest.(*repository.Event)
	if !ok {
		return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
	}
	var previousSequence Sequence
	data := make(Data, 0)

	err = row(
		&event.CreationDate,
		&event.Type,
		&event.Sequence,
		&previousSequence,
		&data,
		&event.EditorService,
		&event.EditorUser,
		&event.ResourceOwner,
		&event.AggregateType,
		&event.AggregateID,
		&event.Version,
	)

	if err != nil {
		logging.Log("SQL-kn1Sw").WithError(err).Warn("unable to scan row")
		return z_errors.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
	}

	event.PreviousSequence = uint64(previousSequence)

	event.Data = make([]byte, len(data))
	copy(event.Data, data)

	return nil
}

func prepareCondition(criteria criteriaer, filters []*repository.Filter) (clause string, values []interface{}) {
	values = make([]interface{}, len(filters))
	clauses := make([]string, len(filters))

	if len(filters) == 0 {
		return clause, values
	}
	for i, filter := range filters {
		value := filter.Value()
		switch value.(type) {
		case []bool, []float64, []int64, []string, []repository.AggregateType, []repository.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]repository.AggregateType, *[]repository.EventType:
			value = pq.Array(value)
		}

		clauses[i] = getCondition(criteria, filter)
		if clauses[i] == "" {
			return "", nil
		}
		values[i] = value
	}
	return " WHERE " + strings.Join(clauses, " AND "), values
}

func getCondition(cond criteriaer, filter *repository.Filter) (condition string) {
	field := cond.columnName(filter.Field())
	operation := cond.operation(filter.Operation())
	if field == "" || operation == "" {
		return ""
	}
	format := cond.conditionFormat(filter.Operation())

	return fmt.Sprintf(format, field, operation)
}
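The new buildQuery composes the statement in a fixed order: column selection, WHERE clause from the filters, ORDER BY, LIMIT, and finally placeholder numbering via criteria.placeholder. A simplified, self-contained sketch of that composition order (the names below are illustrative stand-ins, not the repository's API):

package main

import "fmt"

// buildSimpleQuery mirrors the assembly order used by buildQuery:
// base query, then WHERE, then ORDER BY (optionally DESC), then LIMIT.
func buildSimpleQuery(where string, desc bool, limit uint64) (string, []interface{}) {
	query := "SELECT event_sequence FROM eventstore.events"
	values := []interface{}{}
	if where != "" {
		query += " WHERE " + where
	}
	query += " ORDER BY event_sequence"
	if desc {
		query += " DESC"
	}
	if limit > 0 {
		values = append(values, limit)
		query += " LIMIT ?"
	}
	return query, values
}

func main() {
	q, vals := buildSimpleQuery("aggregate_type = ?", true, 5)
	fmt.Println(q)    // ... WHERE aggregate_type = ? ORDER BY event_sequence DESC LIMIT ?
	fmt.Println(vals) // [5]
}

The "?" placeholders would then be numbered by the dialect-specific placeholder step, as exercised by the Test_buildQuery cases in the next hunk.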
@@ -1,4 +1,4 @@
package repository
package sql

import (
	"database/sql"
@@ -7,8 +7,7 @@ import (
	"time"

	"github.com/caos/zitadel/internal/errors"
	"github.com/caos/zitadel/internal/eventstore/models"
	es_models "github.com/caos/zitadel/internal/eventstore/models"
	"github.com/caos/zitadel/internal/eventstore/v2/repository"
	"github.com/lib/pq"
)

@@ -60,12 +59,12 @@ func Test_numberPlaceholder(t *testing.T) {

func Test_getOperation(t *testing.T) {
	t.Run("all ops", func(t *testing.T) {
		for op, expected := range map[es_models.Operation]string{
			es_models.Operation_Equals: "=",
			es_models.Operation_In: "=",
			es_models.Operation_Greater: ">",
			es_models.Operation_Less: "<",
			es_models.Operation(-1): "",
		for op, expected := range map[repository.Operation]string{
			repository.Operation_Equals: "=",
			repository.Operation_In: "=",
			repository.Operation_Greater: ">",
			repository.Operation_Less: "<",
			repository.Operation(-1): "",
		} {
			if got := getOperation(op); got != expected {
				t.Errorf("getOperation() = %v, want %v", got, expected)
@@ -76,15 +75,15 @@ func Test_getOperation(t *testing.T) {

func Test_getField(t *testing.T) {
	t.Run("all fields", func(t *testing.T) {
		for field, expected := range map[es_models.Field]string{
			es_models.Field_AggregateType: "aggregate_type",
			es_models.Field_AggregateID: "aggregate_id",
			es_models.Field_LatestSequence: "event_sequence",
			es_models.Field_ResourceOwner: "resource_owner",
			es_models.Field_EditorService: "editor_service",
			es_models.Field_EditorUser: "editor_user",
			es_models.Field_EventType: "event_type",
			es_models.Field(-1): "",
		for field, expected := range map[repository.Field]string{
			repository.Field_AggregateType: "aggregate_type",
			repository.Field_AggregateID: "aggregate_id",
			repository.Field_LatestSequence: "event_sequence",
			repository.Field_ResourceOwner: "resource_owner",
			repository.Field_EditorService: "editor_service",
			repository.Field_EditorUser: "editor_user",
			repository.Field_EventType: "event_type",
			repository.Field(-1): "",
		} {
			if got := getField(field); got != expected {
				t.Errorf("getField() = %v, want %v", got, expected)
@@ -95,7 +94,7 @@ func Test_getField(t *testing.T) {

func Test_getConditionFormat(t *testing.T) {
	type args struct {
		operation es_models.Operation
		operation repository.Operation
	}
	tests := []struct {
		name string
@@ -105,14 +104,14 @@ func Test_getConditionFormat(t *testing.T) {
		{
			name: "no in operation",
			args: args{
				operation: es_models.Operation_Equals,
				operation: repository.Operation_Equals,
			},
			want: "%s %s ?",
		},
		{
			name: "in operation",
			args: args{
				operation: es_models.Operation_In,
				operation: repository.Operation_In,
			},
			want: "%s %s ANY(?)",
		},
@@ -128,7 +127,7 @@ func Test_getConditionFormat(t *testing.T) {

func Test_getCondition(t *testing.T) {
	type args struct {
		filter *es_models.Filter
		filter *repository.Filter
	}
	tests := []struct {
		name string
@@ -137,37 +136,37 @@ func Test_getCondition(t *testing.T) {
	}{
		{
			name: "equals",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateID, "", es_models.Operation_Equals)},
			args: args{filter: repository.NewFilter(repository.Field_AggregateID, "", repository.Operation_Equals)},
			want: "aggregate_id = ?",
		},
		{
			name: "greater",
			args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 0, es_models.Operation_Greater)},
			args: args{filter: repository.NewFilter(repository.Field_LatestSequence, 0, repository.Operation_Greater)},
			want: "event_sequence > ?",
		},
		{
			name: "less",
			args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 5000, es_models.Operation_Less)},
			args: args{filter: repository.NewFilter(repository.Field_LatestSequence, 5000, repository.Operation_Less)},
			want: "event_sequence < ?",
		},
		{
			name: "in list",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation_In)},
			args: args{filter: repository.NewFilter(repository.Field_AggregateType, []repository.AggregateType{"movies", "actors"}, repository.Operation_In)},
			want: "aggregate_type = ANY(?)",
		},
		{
			name: "invalid operation",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
			args: args{filter: repository.NewFilter(repository.Field_AggregateType, []repository.AggregateType{"movies", "actors"}, repository.Operation(-1))},
			want: "",
		},
		{
			name: "invalid field",
			args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation_Equals)},
			args: args{filter: repository.NewFilter(repository.Field(-1), []repository.AggregateType{"movies", "actors"}, repository.Operation_Equals)},
			want: "",
		},
		{
			name: "invalid field and operation",
			args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
			args: args{filter: repository.NewFilter(repository.Field(-1), []repository.AggregateType{"movies", "actors"}, repository.Operation(-1))},
			want: "",
		},
	}
@@ -182,7 +181,7 @@ func Test_getCondition(t *testing.T) {

func Test_prepareColumns(t *testing.T) {
	type args struct {
		columns models.Columns
		columns repository.Columns
		dest interface{}
		dbErr error
	}
@@ -199,7 +198,7 @@ func Test_prepareColumns(t *testing.T) {
	}{
		{
			name: "invalid columns",
			args: args{columns: es_models.Columns(-1)},
			args: args{columns: repository.Columns(-1)},
			res: res{
				query: "",
				dbErr: func(err error) bool { return err == nil },
@@ -208,7 +207,7 @@ func Test_prepareColumns(t *testing.T) {
		{
			name: "max column",
			args: args{
				columns: es_models.Columns_Max_Sequence,
				columns: repository.Columns_Max_Sequence,
				dest: new(Sequence),
			},
			res: res{
@@ -220,7 +219,7 @@ func Test_prepareColumns(t *testing.T) {
		{
			name: "max sequence wrong dest type",
			args: args{
				columns: es_models.Columns_Max_Sequence,
				columns: repository.Columns_Max_Sequence,
				dest: new(uint64),
			},
			res: res{
@@ -231,19 +230,19 @@ func Test_prepareColumns(t *testing.T) {
		{
			name: "event",
			args: args{
				columns: es_models.Columns_Event,
				dest: new(models.Event),
				columns: repository.Columns_Event,
				dest: new(repository.Event),
			},
			res: res{
				query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
				dbRow: []interface{}{time.Time{}, models.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", models.AggregateType("user"), "hodor", models.Version("")},
				expected: models.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
				dbRow: []interface{}{time.Time{}, repository.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", repository.AggregateType("user"), "hodor", repository.Version("")},
				expected: repository.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
			},
		},
		{
			name: "event wrong dest type",
			args: args{
				columns: es_models.Columns_Event,
				columns: repository.Columns_Event,
				dest: new(uint64),
			},
			res: res{
@@ -254,8 +253,8 @@ func Test_prepareColumns(t *testing.T) {
		{
			name: "event query error",
			args: args{
				columns: es_models.Columns_Event,
				dest: new(models.Event),
				columns: repository.Columns_Event,
				dest: new(repository.Event),
				dbErr: sql.ErrConnDone,
			},
			res: res{
@@ -290,7 +289,7 @@ func Test_prepareColumns(t *testing.T) {
	}
}

func prepareTestScan(err error, res []interface{}) scanner {
func prepareTestScan(err error, res []interface{}) scan {
	return func(dests ...interface{}) error {
		if err != nil {
			return err
@@ -308,7 +307,7 @@ func prepareTestScan(err error, res []interface{}) scanner {

func Test_prepareCondition(t *testing.T) {
	type args struct {
		filters []*models.Filter
		filters []*repository.Filter
	}
	type res struct {
		clause string
@@ -332,7 +331,7 @@ func Test_prepareCondition(t *testing.T) {
		{
			name: "empty filters",
			args: args{
				filters: []*es_models.Filter{},
				filters: []*repository.Filter{},
			},
			res: res{
				clause: "",
@@ -342,8 +341,8 @@ func Test_prepareCondition(t *testing.T) {
		{
			name: "invalid condition",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
				filters: []*repository.Filter{
					repository.NewFilter(repository.Field_AggregateID, "wrong", repository.Operation(-1)),
				},
			},
			res: res{
@@ -354,27 +353,27 @@ func Test_prepareCondition(t *testing.T) {
		{
			name: "array as condition value",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
				filters: []*repository.Filter{
					repository.NewFilter(repository.Field_AggregateType, []repository.AggregateType{"user", "org"}, repository.Operation_In),
				},
			},
			res: res{
				clause: " WHERE aggregate_type = ANY(?)",
				values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})},
				values: []interface{}{pq.Array([]repository.AggregateType{"user", "org"})},
			},
		},
		{
			name: "multiple filters",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
					es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
					es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
				filters: []*repository.Filter{
					repository.NewFilter(repository.Field_AggregateType, []repository.AggregateType{"user", "org"}, repository.Operation_In),
					repository.NewFilter(repository.Field_AggregateID, "1234", repository.Operation_Equals),
					repository.NewFilter(repository.Field_EventType, []repository.EventType{"user.created", "org.created"}, repository.Operation_In),
				},
			},
			res: res{
				clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?)",
				values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})},
				values: []interface{}{pq.Array([]repository.AggregateType{"user", "org"}), "1234", pq.Array([]repository.EventType{"user.created", "org.created"})},
			},
		},
	}
@@ -399,7 +398,7 @@ func Test_prepareCondition(t *testing.T) {

func Test_buildQuery(t *testing.T) {
	type args struct {
		queryFactory *models.SearchQueryFactory
		queryFactory *repository.SearchQuery
	}
	type res struct {
		query string
@@ -427,35 +426,36 @@ func Test_buildQuery(t *testing.T) {
		{
			name: "with order by desc",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(),
				// NewSearchQueryFactory("user").OrderDesc()
				queryFactory: &repository.SearchQuery{Desc: true},
			},
			res: res{
				query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
				rowScanner: true,
				values: []interface{}{es_models.AggregateType("user")},
				values: []interface{}{repository.AggregateType("user")},
			},
		},
		{
			name: "with limit",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").Limit(5),
				queryFactory: repository.NewSearchQueryFactory("user").Limit(5),
			},
			res: res{
				query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
				rowScanner: true,
				values: []interface{}{es_models.AggregateType("user"), uint64(5)},
				values: []interface{}{repository.AggregateType("user"), uint64(5)},
				limit: 5,
			},
		},
		{
			name: "with limit and order by desc",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
				queryFactory: repository.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
			},
			res: res{
				query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
				rowScanner: true,
				values: []interface{}{es_models.AggregateType("user"), uint64(5)},
				values: []interface{}{repository.AggregateType("user"), uint64(5)},
				limit: 5,
			},
		},
@@ -484,3 +484,5 @@ func Test_buildQuery(t *testing.T) {
		})
	}
}

// func buildQuery(t *testing.T, factory *reposear)
@@ -1,6 +1,8 @@
package repository
package sql

import "database/sql/driver"
import (
	"database/sql/driver"
)

// Data represents a byte array that may be null.
// Data implements the sql.Scanner interface
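The diff is cut off after the comment describing the Data type. For orientation only, a generic sketch of how a nullable byte-array type commonly implements sql.Scanner and driver.Valuer follows; it is an assumption about the pattern, not necessarily the file's actual implementation.

package main

import (
	"database/sql/driver"
	"fmt"
)

// Data is a nullable byte slice.
type Data []byte

// Scan implements sql.Scanner: NULL becomes nil, otherwise the bytes are copied.
func (d *Data) Scan(src interface{}) error {
	if src == nil {
		*d = nil
		return nil
	}
	b, ok := src.([]byte)
	if !ok {
		return fmt.Errorf("cannot scan %T into Data", src)
	}
	*d = append((*d)[:0], b...)
	return nil
}

// Value implements driver.Valuer: a nil Data is stored as NULL.
func (d Data) Value() (driver.Value, error) {
	if d == nil {
		return nil, nil
	}
	return []byte(d), nil
}

func main() {
	var d Data
	if err := d.Scan([]byte(`{"name":"hodor"}`)); err != nil {
		panic(err)
	}
	fmt.Println(string(d))
}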